yakvrz commited on
Commit
0c4c32b
·
0 Parent(s):

Initial import

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +5 -0
  2. ARCHITECTURE.md +90 -0
  3. gradio_app.py +1321 -0
  4. src/depth_anything_3/__init__.py +21 -0
  5. src/depth_anything_3/api.py +416 -0
  6. src/depth_anything_3/app/__init__.py +1 -0
  7. src/depth_anything_3/app/css_and_html.py +594 -0
  8. src/depth_anything_3/app/gradio_app.py +747 -0
  9. src/depth_anything_3/app/modules/__init__.py +45 -0
  10. src/depth_anything_3/app/modules/event_handlers.py +629 -0
  11. src/depth_anything_3/app/modules/file_handlers.py +304 -0
  12. src/depth_anything_3/app/modules/model_inference.py +286 -0
  13. src/depth_anything_3/app/modules/ui_components.py +474 -0
  14. src/depth_anything_3/app/modules/utils.py +211 -0
  15. src/depth_anything_3/app/modules/visualization.py +434 -0
  16. src/depth_anything_3/cfg.py +144 -0
  17. src/depth_anything_3/cli.py +742 -0
  18. src/depth_anything_3/configs/da3-base.yaml +45 -0
  19. src/depth_anything_3/configs/da3-giant.yaml +71 -0
  20. src/depth_anything_3/configs/da3-large.yaml +45 -0
  21. src/depth_anything_3/configs/da3-small.yaml +45 -0
  22. src/depth_anything_3/configs/da3metric-large.yaml +28 -0
  23. src/depth_anything_3/configs/da3mono-large.yaml +28 -0
  24. src/depth_anything_3/configs/da3nested-giant-large.yaml +10 -0
  25. src/depth_anything_3/model/__init__.py +20 -0
  26. src/depth_anything_3/model/cam_dec.py +45 -0
  27. src/depth_anything_3/model/cam_enc.py +80 -0
  28. src/depth_anything_3/model/da3.py +378 -0
  29. src/depth_anything_3/model/dinov2/dinov2.py +64 -0
  30. src/depth_anything_3/model/dinov2/layers/__init__.py +25 -0
  31. src/depth_anything_3/model/dinov2/layers/attention.py +100 -0
  32. src/depth_anything_3/model/dinov2/layers/block.py +143 -0
  33. src/depth_anything_3/model/dinov2/layers/drop_path.py +35 -0
  34. src/depth_anything_3/model/dinov2/layers/layer_scale.py +31 -0
  35. src/depth_anything_3/model/dinov2/layers/mlp.py +40 -0
  36. src/depth_anything_3/model/dinov2/layers/patch_embed.py +94 -0
  37. src/depth_anything_3/model/dinov2/layers/rope.py +200 -0
  38. src/depth_anything_3/model/dinov2/layers/swiglu_ffn.py +62 -0
  39. src/depth_anything_3/model/dinov2/vision_transformer.py +437 -0
  40. src/depth_anything_3/model/dpt.py +457 -0
  41. src/depth_anything_3/model/dualdpt.py +488 -0
  42. src/depth_anything_3/model/gs_adapter.py +200 -0
  43. src/depth_anything_3/model/gsdpt.py +133 -0
  44. src/depth_anything_3/model/utils/attention.py +109 -0
  45. src/depth_anything_3/model/utils/block.py +81 -0
  46. src/depth_anything_3/model/utils/gs_renderer.py +340 -0
  47. src/depth_anything_3/model/utils/head_utils.py +230 -0
  48. src/depth_anything_3/model/utils/transform.py +208 -0
  49. src/depth_anything_3/registry.py +50 -0
  50. src/depth_anything_3/services/__init__.py +24 -0
.gitignore ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ __pycache__/
2
+ .DS_Store
3
+ .python-version
4
+ data/
5
+ *.pyc
ARCHITECTURE.md ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Landing Site Safety Analyzer – Architecture and Calculations
2
+
3
+ This document describes the full computation flow in `gradio_app.py`, from input handling through model inference, safety scoring, and UI composition.
4
+
5
+ ## Data and Models
6
+ - **Inputs**: Images from `data/Image/VISLOC` (dropdown populated via `list_visloc_images()` and a 5% border crop via `crop_nonblack` to eliminate black padding). Supported extensions: jpg/jpeg/png (any case).
7
+ - **Depth model**: Depth Anything 3, loaded lazily and cached (`get_model`). Inference is invoked with `process_res` and `process_res_method="upper_bound_resize"` to resize the long image side to the chosen resolution before predicting.
8
+ - **Segmentation model(s)**: Mask2Former ADE20K (`facebook/mask2former-swin-large-ade-semantic`) for optional water and road masking. Loaded lazily and cached (`get_water_segmenter`, LRU of size 1 — effectively one model id at a time). Masks are cached per `(model_id, source_path, segmentation_max_side)` to avoid recomputation.
9
+
10
+ ## Constants and Assumptions
11
+ - Image crop: 5% on each border before processing.
12
+ - Safety thresholds: user-controlled sliders for flatness (`std_thresh`) and gradient (`grad_thresh`).
13
+ - Altitude and FOV are user-provided (defaults 450 m, 90°).
14
+ - Roof mask mode defaults to “Buildings only” (optionally includes “roof” labels).
15
+ - Flatness detail slider scales the flatness visualization window.
16
+ - Clearance factor dilates hazards to enforce buffer distance.
17
+
18
+ ## Per-Image Processing Pipeline
19
+ 1. **Load and crop** the selected image to RGB, removing 5% border padding.
20
+ 2. **Depth inference**: Run Depth Anything 3 at the input image’s native resolution; get a depth map (`depth_raw`).
21
+ 3. **Plane removal**: Fit a best-fit plane over the depth map via least squares (`remove_global_plane`) and subtract it to obtain a detrended depth (`depth`). This prevents large, tilted planes (e.g., concrete yards) from being mislabeled as sloped.
22
+ 4. **Footprint to pixels**:
23
+ - Compute focal length in pixels: `fx = (W/2) / tan(FOV/2)`, where `W` is the depth map width and `FOV` is user-set.
24
+ - Convert user footprint (meters) to depth pixels: `patch_px = footprint_m * fx / altitude_m`, clamped to the depth map bounds and forced to be odd.
25
+ 5. **Flatness visualization scale**: Compute a separate window size `vis_patch` (odd), scaled by the flatness detail slider and capped to 1/10 of the smallest map dimension, for a visualization-only std map (sharper view).
26
+ 6. **Optional masks**:
27
+ - Water and road masks are computed on the (possibly downscaled) RGB using Mask2Former, resized to depth resolution with nearest-neighbor, and converted to boolean arrays. Cached per input path.
28
+ - Roof mask is depth-based: pixels significantly closer than the median depth (per MAD threshold) are treated as raised structures; morphology smooths the mask.
29
+ 7. **Flat region search (`pick_flat_patch`)**:
30
+ - Normalize depth to [0,1].
31
+ - Compute local mean and variance via reflective padding and torch avg pooling over `patch` to obtain `std_map`.
32
+ - Compute gradient magnitude via `np.gradient`; normalize by the 95th percentile to get `grad_norm`; derive `grad_mask = grad_norm < grad_thresh`.
33
+ - Combine masks: start with `landing_mask = grad_mask`; if a water mask is available, exclude water pixels; return the lowest-variance window inside this mask as a candidate patch (`box`).
34
+ 8. **Safe mask construction**:
35
+ - Initial safe pixels: `(std_map < std_thresh) & (grad_norm < grad_thresh) & landing_mask`.
36
+ - Clearance buffer: build a hazard mask as the inverse of safe pixels, dilate it with an ellipse kernel sized from `clearance_factor * patch_px`, and subtract from safe_mask to enforce distance from hazards.
37
+ - Enforce full-footprint coverage by box filtering the safe mask with a `patch_px` window; keep pixels where coverage >= `coverage_strictness` (user-set slider, clamped to [0, 1]).
38
+ - Remove small components: each connected component must cover at least `min_component_multiplier` footprint areas (`patch_px * patch_px`; the multiplier is clamped to >= 0.1).
39
+ 9. **Landing spot selection**:
40
+ - Prefer centers where the full footprint is safe: find pixels with full coverage, choose the largest connected component, and pick its flattest point (lowest `std_map`).
41
+ - Fallback: center of the flattest window from step 7.
42
+ - Convert the chosen depth-space center to image-space using `scale_x = W_img / W_depth` and draw a square whose side is `patch_px * scale_x` (clamped, min 3 px) on the original image. Also draw a center dot.
43
+ 10. **Visualization layers**:
44
+ - Depth visualization uses the original depth (`depth_raw`) rendered with `visualize_depth` and resized to the input image size.
45
+ - Flatness map (std), gradient magnitude, gradient mask, water mask, road mask are resized to image size for display.
46
+ - Safety heatmap overlay and grayscale score are built from the final `safe_mask`.
47
+
48
+ ## Safety Heatmap Calculation
49
+ - Input: boolean `safe_mask` (post-coverage + component filtering).
50
+ - Convert to `score` in [0,1].
51
+ - Color mapping (per-pixel):
52
+ - Red channel: `red = (1 - score) ** 0.9 * 255`.
53
+ - Green channel: `green = score ** 1.2 * 200` (gamma and cap to prevent overpowering red).
54
+ - Resize to image size with nearest-neighbor for overlays. A grayscale safety score is also produced for debugging/alternative views.
55
+
56
+ ## Overlay Composition (`compose_view`)
57
+ - Start from the selected base view (RGB/Depth/Flatness/Gradient/Mask/Heatmap/Score/Landing spot).
58
+ - Add overlays conditionally with per-layer alpha:
59
+ - Safety heatmap (RGBA alpha from slider).
60
+ - Gradient, flatness maps (alpha from sliders).
61
+ - Water and road masks: convert mask to luminance, modulate alpha by mask pixels and slider, tint red (water) or orange (road).
62
+ - Landing spot overlay is composited last.
63
+ - Output is converted back to RGB.
64
+
65
+ ## Caching and State
66
+ - **Model cache**: depth model loaded once per process; segmentation model loaded once per model id.
67
+ - **Mask cache**: water/road masks cached per `(model_id, source_path, segmentation_max_side)` to avoid recomputation across UI tweaks.
68
+ - **Image state**: `images_state` (Gradio State) holds the latest computed layers; overlay-only UI controls recompute compositions without rerunning inference unless base parameters change.
69
+
70
+ ## User Controls and Their Effects
71
+ - `process_res`: sets inference resize; affects depth resolution and subsequent pixel scaling.
72
+ - `footprint_m`: desired landing square (meters); converted to pixels via FOV/altitude.
73
+ - `altitude_m`, `fov_deg`: camera parameters for footprint sizing.
74
+ - `flatness_detail`: scales the visualization window for the flatness map.
75
+ - `clearance_factor`: scales the hazard dilation kernel.
76
+ - `std_thresh`, `grad_thresh`: safety criteria for flatness and slope.
77
+ - `use_water_mask` / `use_road_mask` (segmentation-based) and `use_roof_mask` (depth-based): enable masking in safety logic; overlays are separately toggled via `*_on` and alpha sliders.
78
+ - `base_view`, overlay toggles, and alphas: affect only display composition, not safety computation (except when needing masks to be computed for display).
79
+ - `model_id`: selects the Depth Anything 3 checkpoint.
80
+
81
+ ## Error Handling
82
+ - Missing/unsupported input paths raise Gradio errors.
83
+ - Water/road masking failures log a warning and fall back to no mask.
84
+ - Coverage/boxFilter failures fall back to looser checks; if no finite std values are available within masks, the flattest region overall is used.
85
+
86
+ ## Outputs
87
+ The pipeline returns a dictionary of PIL Images keyed by view name:
88
+ - RGB, Depth, Flatness map (std), Depth gradient, Gradient mask, Water mask, Road mask, Safety heatmap overlay, Safety score (grayscale), Landing spot overlay.
89
+
90
+ These are then composed into the preview according to UI choices.
gradio_app.py ADDED
@@ -0,0 +1,1321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio demo: depth overlays on VISLOC imagery using Depth Anything 3.
4
+
5
+ Run:
6
+ python gradio_app.py
7
+
8
+ Then open the printed local URL. Requires: gradio, pillow, torch, transformers (for water mask).
9
+ """
10
+
11
+ import cv2
12
+ import functools
13
+ import math
14
+ import os
15
+ from pathlib import Path
16
+
17
+ import gradio as gr
18
+ import numpy as np
19
+ import torch
20
+ from PIL import Image, ImageDraw, ImageFilter
21
+ import matplotlib.cm as cm
22
+
23
+ # Prefer installed package; fall back to local src for dev runs.
24
+ try:
25
+ from depth_anything_3.api import DepthAnything3 # type: ignore
26
+ from depth_anything_3.utils.visualize import visualize_depth # type: ignore
27
+ except ModuleNotFoundError:
28
+ import sys
29
+
30
+ ROOT = Path(__file__).resolve().parent
31
+ sys.path.append(str(ROOT / "src"))
32
+ from depth_anything_3.api import DepthAnything3 # noqa: E402
33
+ from depth_anything_3.utils.visualize import visualize_depth # noqa: E402
34
+
35
# Dataset locations, relative to the repository root.
VISLOC_DIR = Path("data/Image/VISLOC")
HAGDAVS_DIR = Path("data/Image/HAGDAVS")
VIDEO_DIR = Path("data/Video")
# Accepted file extensions for images and videos.
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG")
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
# Default camera assumptions used when the user does not override them.
DEFAULT_ALTITUDE_M = 450.0
ASSUMED_FOV_DEG = 90.0
# Mask2Former ADE20K checkpoint; the same weights serve water and road masking.
WATER_MODEL_ID = "facebook/mask2former-swin-large-ade-semantic"
ROAD_MODEL_ID = "facebook/mask2former-swin-large-ade-semantic"
44
+
45
def crop_nonblack(img: Image.Image, frac: float = 0.05) -> Image.Image:
    """Trim a fixed fraction off every border (crude black-padding removal)."""
    width, height = img.size
    margin_x = int(round(width * frac))
    margin_y = int(round(height * frac))
    crop_box = (margin_x, margin_y, width - margin_x, height - margin_y)
    return img.crop(crop_box)
51
+
52
+
53
@functools.lru_cache(maxsize=1)
def get_water_segmenter(model_id: str):
    """Lazily build the Mask2Former segmentation stack for *model_id*.

    Returns a cached ``(processor, model, device)`` triple. The model is kept
    on CPU deliberately, so it does not compete with the depth model for GPU
    memory. Raises ImportError when `transformers` is not installed.
    """
    try:
        from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
    except ImportError as e:
        raise ImportError("transformers is required for water masking; install with `pip install transformers`") from e
    cpu = torch.device("cpu")
    # Newer transformers accept use_fast at construction; older ones raise TypeError.
    try:
        processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
    except TypeError:
        processor = AutoImageProcessor.from_pretrained(model_id)
    segmenter = Mask2FormerForUniversalSegmentation.from_pretrained(model_id).to(cpu)
    segmenter.eval()
    return processor, segmenter, cpu
68
+
69
+
70
def compute_water_mask(img: Image.Image, model_id: str, max_side: int = 640) -> np.ndarray | None:
    """Segment water-like regions with Mask2Former ADE weights.

    Returns a boolean mask at the original *img* size, or None when inference
    fails at runtime (the failure is logged as a warning).
    """
    processor, model, device = get_water_segmenter(model_id)
    try:
        # Bound the segmentation cost by shrinking the long side to max_side.
        work = img
        longest = max(img.size)
        if longest > max_side:
            ratio = max_side / longest
            small_size = (int(round(img.size[0] * ratio)), int(round(img.size[1] * ratio)))
            work = img.resize(small_size, resample=Image.BILINEAR)
        # Older processors reject use_fast at call time.
        try:
            inputs = processor(images=work, return_tensors="pt", use_fast=True).to(device)
        except TypeError:
            inputs = processor(images=work, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model(**inputs)
        seg = processor.post_process_semantic_segmentation(outputs, target_sizes=[work.size[::-1]])[0]
        if torch.is_tensor(seg):
            seg = seg.cpu()
        # Select ADE classes whose names look water-related.
        keywords = ["water", "sea", "lake", "river", "ocean", "pond"]
        water_ids = {
            class_id
            for class_id, class_name in model.config.id2label.items()
            if any(k in class_name.lower() for k in keywords)
        }
        hits = np.isin(np.array(seg), list(water_ids)).astype(np.uint8) * 255
        # Upsample back to the original image size with nearest-neighbor.
        full_res = Image.fromarray(hits).resize(img.size, resample=Image.NEAREST)
        return np.array(full_res) > 0
    except RuntimeError as e:
        print(f"[WARN] Water masking failed (fallback to no water mask): {e}")
        return None
98
+
99
+
100
def compute_road_mask(img: Image.Image, model_id: str, max_side: int = 640) -> np.ndarray | None:
    """Segment road/highway-like regions with Mask2Former ADE weights.

    Returns a boolean mask at the original *img* size, or None when inference
    fails at runtime (the failure is logged as a warning).
    """
    processor, model, device = get_water_segmenter(model_id)
    try:
        # Bound the segmentation cost by shrinking the long side to max_side.
        work = img
        longest = max(img.size)
        if longest > max_side:
            ratio = max_side / longest
            small_size = (int(round(img.size[0] * ratio)), int(round(img.size[1] * ratio)))
            work = img.resize(small_size, resample=Image.BILINEAR)
        # Older processors reject use_fast at call time.
        try:
            inputs = processor(images=work, return_tensors="pt", use_fast=True).to(device)
        except TypeError:
            inputs = processor(images=work, return_tensors="pt").to(device)
        with torch.inference_mode():
            outputs = model(**inputs)
        seg = processor.post_process_semantic_segmentation(outputs, target_sizes=[work.size[::-1]])[0]
        if torch.is_tensor(seg):
            seg = seg.cpu()
        # Select ADE classes whose names look road-like, excluding green/open
        # spaces whose names merely contain a keyword (e.g. "parkway" traps).
        keywords = ["highway", "road", "street", "runway"]
        blocklist = ["field", "park", "grass", "lawn", "garden", "court", "yard", "green"]
        road_ids = {
            class_id
            for class_id, class_name in model.config.id2label.items()
            if any(k in class_name.lower() for k in keywords)
            and not any(b in class_name.lower() for b in blocklist)
        }
        hits = np.isin(np.array(seg), list(road_ids)).astype(np.uint8) * 255
        # Upsample back to the original image size with nearest-neighbor.
        full_res = Image.fromarray(hits).resize(img.size, resample=Image.NEAREST)
        return np.array(full_res) > 0
    except RuntimeError as e:
        print(f"[WARN] Road masking failed (fallback to no road mask): {e}")
        return None
133
+
134
+
135
def compute_roof_mask_depth(depth: np.ndarray, aggressiveness: float = 1.3, morph_kernel: int = 5) -> np.ndarray:
    """Depth-based roof/structure mask.

    Flags pixels that are significantly closer than the robust scene depth
    (median minus *aggressiveness* times the MAD), i.e. raised surfaces, then
    smooths the result morphologically (best effort).
    """
    values = depth.astype(np.float32)
    center = np.median(values)
    spread = np.median(np.abs(values - center)) + 1e-6  # MAD, guarded against zero
    raised = (values < center - aggressiveness * spread).astype(np.uint8)
    # Force an odd structuring-element size of at least 1.
    size = max(1, int(morph_kernel))
    if size % 2 == 0:
        size += 1
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    try:
        # Opening removes speckle, closing fills pinholes; skip on failure.
        raised = cv2.morphologyEx(raised, cv2.MORPH_OPEN, kernel)
        raised = cv2.morphologyEx(raised, cv2.MORPH_CLOSE, kernel)
    except Exception:
        pass
    return raised > 0
153
+
154
+
155
def remove_global_plane(depth: np.ndarray) -> np.ndarray:
    """Subtract the least-squares best-fit plane from a 2-D depth map.

    Detrending keeps large tilted-but-flat surfaces (e.g. a sloped concrete
    yard) from being scored as rough. Non-2-D input is returned unchanged, as
    is the input when the fit fails.
    """
    if depth.ndim != 2:
        return depth
    rows, cols = depth.shape
    ys, xs = np.mgrid[0:rows, 0:cols].astype(np.float32)
    # Design matrix columns: x, y, 1 — solving z = a*x + b*y + c.
    design = np.column_stack((xs.ravel(), ys.ravel(), np.ones(rows * cols, dtype=np.float32)))
    target = depth.astype(np.float32).reshape(-1, 1)
    try:
        solution = np.linalg.lstsq(design, target, rcond=None)[0]
    except np.linalg.LinAlgError:
        return depth
    return depth - (design @ solution).reshape(rows, cols)
169
+
170
+
171
+ def pick_flat_patch(
172
+ depth: np.ndarray,
173
+ patch: int = 96,
174
+ std_thresh: float = 0.03,
175
+ grad_thresh: float = 0.35,
176
+ water_mask: np.ndarray | None = None,
177
+ ):
178
+ """Find a low-variance depth window as a proxy for flat landing area."""
179
+ depth = depth.astype(np.float32)
180
+ if depth.ndim != 2:
181
+ raise ValueError("Depth map must be 2D (H, W)")
182
+
183
+ patch = max(3, min(patch, min(depth.shape)))
184
+ if patch % 2 == 0:
185
+ patch += 1 # keeps pooling output same size
186
+ depth_norm = (depth - depth.min()) / (depth.ptp() + 1e-6)
187
+
188
+ # Efficient box std via torch avg pooling
189
+ import torch.nn.functional as F
190
+
191
+ def box_mean(arr, k):
192
+ pad = k // 2
193
+ t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
194
+ # Reflective padding avoids dark/bright rims in the std map
195
+ t = F.pad(t, (pad, pad, pad, pad), mode="reflect")
196
+ mean = F.avg_pool2d(t, kernel_size=k, stride=1, padding=0, count_include_pad=False)
197
+ return mean.squeeze(0).squeeze(0).numpy()
198
+
199
+ mean = box_mean(depth_norm, patch)
200
+ mean_sq = box_mean(depth_norm * depth_norm, patch)
201
+ var = np.maximum(mean_sq - mean * mean, 0.0)
202
+ std_map = np.sqrt(var)
203
+
204
+ # Gradient mask to down-weight slopes/edges
205
+ dy, dx = np.gradient(depth_norm)
206
+ grad = np.sqrt(dx * dx + dy * dy)
207
+ grad_ref = np.percentile(grad, 95) + 1e-6
208
+ grad_norm = np.clip(grad / grad_ref, 0.0, 1.0)
209
+ grad_mask = grad_norm < grad_thresh
210
+
211
+ landing_mask = grad_mask
212
+ if water_mask is not None and water_mask.shape == grad_mask.shape:
213
+ landing_mask = landing_mask & (~water_mask)
214
+
215
+ masked_std = np.where(landing_mask, std_map, np.inf)
216
+ if not np.isfinite(masked_std).any():
217
+ masked_std = std_map # fallback: just take the flattest spot
218
+ y, x = np.unravel_index(np.argmin(masked_std), masked_std.shape)
219
+ half = patch // 2
220
+ y0, y1 = max(y - half, 0), min(y + half, depth.shape[0] - 1)
221
+ x0, x1 = max(x - half, 0), min(x + half, depth.shape[1] - 1)
222
+ return (x0, y0, x1, y1), std_map, grad_norm, grad_mask, landing_mask
223
+
224
+
225
def make_safety_heatmap(
    rgb: Image.Image,
    safe_mask: np.ndarray,
):
    """Render *safe_mask* as a red/green heatmap plus a grayscale score image.

    Both outputs are resized (nearest-neighbor) to the size of *rgb*.
    """
    score = np.clip(safe_mask.astype(np.float32), 0.0, 1.0)

    # Red marks unsafe, green marks safe. The green channel is gamma-curved
    # and capped at 200 so it does not overpower red when blended on the
    # base image.
    channels = np.zeros((*score.shape, 3), dtype=np.uint8)
    channels[..., 0] = np.power(1.0 - score, 0.9) * 255.0
    channels[..., 1] = np.power(score, 1.2) * 200.0
    heat_img = Image.fromarray(channels).resize(rgb.size, resample=Image.NEAREST)
    gray = (score * 255).astype(np.uint8)
    score_gray = Image.fromarray(gray).resize(rgb.size, resample=Image.NEAREST)
    return heat_img, score_gray
242
+
243
+
244
@functools.lru_cache(maxsize=1)
def get_model(model_id: str = "depth-anything/DA3METRIC-LARGE"):
    """Load the Depth Anything 3 checkpoint once per process and cache it.

    Returns the eval-mode model together with its device (CUDA when
    available, otherwise CPU).
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    net = DepthAnything3.from_pretrained(model_id).to(device)
    net.eval()
    return net, device
251
+
252
+
253
@functools.lru_cache(maxsize=1)
def list_visloc_images() -> list[Path]:
    """Return sorted VISLOC image paths from data/Image/VISLOC.

    The suffix match is case-insensitive (consistent with ``list_videos``),
    so mixed-case names such as ``photo.Jpg`` are now included; all files the
    previous exact-case check accepted are still accepted. Results are cached
    for the process lifetime, so files added afterwards are not picked up.
    """
    if not VISLOC_DIR.exists():
        return []
    files = [p for p in VISLOC_DIR.iterdir() if p.suffix.lower() in IMAGE_EXTS]
    return sorted(files)
260
+
261
+
262
@functools.lru_cache(maxsize=1)
def list_hagdavs_images() -> list[Path]:
    """Return sorted HAGDAVS image paths from data/Image/HAGDAVS.

    The suffix match is case-insensitive (consistent with ``list_videos``),
    so mixed-case names such as ``photo.Jpg`` are now included; all files the
    previous exact-case check accepted are still accepted. Results are cached
    for the process lifetime, so files added afterwards are not picked up.
    """
    if not HAGDAVS_DIR.exists():
        return []
    files = [p for p in HAGDAVS_DIR.iterdir() if p.suffix.lower() in IMAGE_EXTS]
    return sorted(files)
269
+
270
+
271
@functools.lru_cache(maxsize=1)
def list_videos() -> list[Path]:
    """Return sorted video paths from data/Video (case-insensitive suffix match; cached)."""
    if not VIDEO_DIR.exists():
        return []
    return sorted(p for p in VIDEO_DIR.iterdir() if p.suffix.lower() in VIDEO_EXTS)
277
+
278
+
279
@functools.lru_cache(maxsize=1)
def list_all_data_inputs() -> list[str]:
    """Collect VISLOC image files (as path strings) for the input dropdown."""
    return [str(path) for path in list_visloc_images()]
283
+
284
+
285
+ # Simple cache for water/road masks keyed by (model_id, path)
286
+ WATER_MASK_CACHE: dict[tuple[str, str], np.ndarray] = {}
287
+ ROAD_MASK_CACHE: dict[tuple[str, str], np.ndarray] = {}
288
+
289
+
290
+ def run_on_image(
291
+ image: Image.Image,
292
+ footprint_m: float,
293
+ std_thresh: float,
294
+ grad_thresh: float,
295
+ use_water_mask: bool,
296
+ use_road_mask: bool,
297
+ use_roof_mask: bool,
298
+ altitude_m: float,
299
+ fov_deg: float,
300
+ flatness_detail: float,
301
+ clearance_factor: float,
302
+ process_res_cap: int,
303
+ roof_aggressiveness: float,
304
+ roof_morph_frac: float,
305
+ segmentation_max_side: int,
306
+ depth_smoothing_base: float,
307
+ coverage_strictness: float,
308
+ min_component_multiplier: float,
309
+ model_id: str,
310
+ source_path: str | None = None,
311
+ ) -> dict:
312
+ rgb_np = np.array(image)
313
+
314
+ model, device = get_model(model_id)
315
+ # Fixed upper-bound resolution (cap) while avoiding upscaling small images.
316
+ process_res = min(max(image.size), int(process_res_cap))
317
+ with torch.inference_mode():
318
+ pred = model.inference(
319
+ image=[rgb_np],
320
+ process_res=process_res,
321
+ process_res_method="upper_bound_resize",
322
+ export_dir=None,
323
+ )
324
+ depth_raw = np.array(pred.depth[0])
325
+ depth = remove_global_plane(depth_raw)
326
+ # Smooth depth for resolution-invariant flatness/gradient (higher res -> slightly more smoothing)
327
+ res_scale = max(0.5, min(2.5, process_res / 1024))
328
+ sigma = max(0.0, depth_smoothing_base) * res_scale
329
+ k = max(3, int(round(sigma * 3)) * 2 + 1)
330
+ try:
331
+ depth = cv2.GaussianBlur(depth, (k, k), sigmaX=sigma, sigmaY=sigma)
332
+ except Exception:
333
+ pass
334
+ # Convert landing footprint (meters) to pixels at current processed resolution
335
+ fov = max(10.0, min(170.0, float(fov_deg)))
336
+ altitude = max(1.0, float(altitude_m))
337
+ fx = (depth.shape[1] / 2.0) / math.tan(math.radians(fov) / 2.0)
338
+ patch_px = footprint_m * fx / altitude
339
+ patch_px = max(3, min(int(round(patch_px)), min(depth.shape) - 1))
340
+ if patch_px % 2 == 0:
341
+ patch_px += 1 # keep pooling symmetric
342
+
343
+ # For visualization, compute a flatness map with a smaller, sharper window (decoupled from footprint)
344
+ depth_norm = (depth - depth.min()) / (depth.ptp() + 1e-6)
345
+ vis_patch = max(
346
+ 5,
347
+ min(
348
+ int(max(1.0, flatness_detail) * patch_px),
349
+ min(depth.shape) // 10,
350
+ min(depth.shape) - 1,
351
+ ),
352
+ )
353
+ if vis_patch % 2 == 0:
354
+ vis_patch += 1
355
+ import torch.nn.functional as F
356
+
357
+ def box_mean_np(arr: np.ndarray, k: int):
358
+ pad = k // 2
359
+ t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
360
+ t = F.pad(t, (pad, pad, pad, pad), mode="reflect")
361
+ mean = F.avg_pool2d(t, kernel_size=k, stride=1, padding=0, count_include_pad=False)
362
+ return mean.squeeze(0).squeeze(0).numpy()
363
+
364
+ std_map_vis = np.sqrt(
365
+ np.maximum(box_mean_np(depth_norm * depth_norm, vis_patch) - box_mean_np(depth_norm, vis_patch) ** 2, 0.0)
366
+ )
367
+
368
+ # Optional water mask (resized to depth resolution)
369
+ water_mask_resized = None
370
+ water_mask_img = None
371
+ if use_water_mask:
372
+ cache_key = (WATER_MODEL_ID, source_path or "", int(segmentation_max_side))
373
+ if cache_key in WATER_MASK_CACHE:
374
+ water_mask_img = WATER_MASK_CACHE[cache_key]
375
+ else:
376
+ water_mask_img = compute_water_mask(image, WATER_MODEL_ID, max_side=segmentation_max_side)
377
+ if source_path is not None and water_mask_img is not None:
378
+ WATER_MASK_CACHE[cache_key] = water_mask_img
379
+ if water_mask_img is not None:
380
+ water_mask_resized = (
381
+ np.array(water_mask_img)
382
+ if isinstance(water_mask_img, np.ndarray)
383
+ else np.array(water_mask_img)
384
+ )
385
+ water_mask_resized = (
386
+ Image.fromarray(water_mask_resized.astype(np.uint8) * 255)
387
+ .resize((depth.shape[1], depth.shape[0]), resample=Image.NEAREST)
388
+ )
389
+ water_mask_resized = np.array(water_mask_resized) > 0
390
+
391
+ road_mask_resized = None
392
+ road_mask_img = None
393
+ if use_road_mask:
394
+ cache_key_r = (ROAD_MODEL_ID, source_path or "", int(segmentation_max_side))
395
+ if cache_key_r in ROAD_MASK_CACHE:
396
+ road_mask_img = ROAD_MASK_CACHE[cache_key_r]
397
+ else:
398
+ road_mask_img = compute_road_mask(image, ROAD_MODEL_ID, max_side=segmentation_max_side)
399
+ if source_path is not None and road_mask_img is not None:
400
+ ROAD_MASK_CACHE[cache_key_r] = road_mask_img
401
+ if road_mask_img is not None:
402
+ road_mask_resized = (
403
+ np.array(road_mask_img)
404
+ if isinstance(road_mask_img, np.ndarray)
405
+ else np.array(road_mask_img)
406
+ )
407
+ road_mask_resized = (
408
+ Image.fromarray(road_mask_resized.astype(np.uint8) * 255)
409
+ .resize((depth.shape[1], depth.shape[0]), resample=Image.NEAREST)
410
+ )
411
+ road_mask_resized = np.array(road_mask_resized) > 0
412
+ roof_mask_resized = None
413
+ if use_roof_mask:
414
+ # Depth-based elevation mask: closer-than-median surfaces are treated as roofs/structures.
415
+ aggressiveness = max(0.5, min(3.0, roof_aggressiveness))
416
+ morph_k = max(3, int(round(patch_px * roof_morph_frac)))
417
+ roof_mask_resized = compute_roof_mask_depth(depth, aggressiveness=aggressiveness, morph_kernel=morph_k)
418
+
419
+ box, std_map, grad_norm, grad_mask, landing_mask = pick_flat_patch(
420
+ depth,
421
+ patch=patch_px,
422
+ std_thresh=std_thresh,
423
+ grad_thresh=grad_thresh,
424
+ water_mask=water_mask_resized,
425
+ )
426
+ if road_mask_resized is not None:
427
+ landing_mask = landing_mask & (~road_mask_resized)
428
+ if roof_mask_resized is not None:
429
+ landing_mask = landing_mask & (~roof_mask_resized)
430
+ safe_mask = (std_map < std_thresh) & (grad_norm < grad_thresh) & landing_mask
431
+ # Clearance: dilate hazards to enforce buffer around unsafe regions
432
+ try:
433
+ clearance_px = max(1, int(round(clearance_factor * patch_px)))
434
+ if clearance_px % 2 == 0:
435
+ clearance_px += 1
436
+ kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clearance_px, clearance_px))
437
+ hazard = (~safe_mask).astype(np.uint8)
438
+ buffered = cv2.dilate(hazard, kernel, iterations=1).astype(bool)
439
+ safe_mask = safe_mask & (~buffered)
440
+ except Exception:
441
+ pass
442
+ # Strict footprint coverage: a center is safe only if the full footprint is safe
443
+ try:
444
+ coverage = cv2.boxFilter(
445
+ safe_mask.astype(np.float32),
446
+ ddepth=-1,
447
+ ksize=(patch_px, patch_px),
448
+ normalize=True,
449
+ anchor=(patch_px // 2, patch_px // 2),
450
+ )
451
+ safe_mask = coverage >= max(0.0, min(1.0, coverage_strictness))
452
+ except Exception:
453
+ pass
454
+
455
+ # Drop tiny components: require at least one footprint area
456
+ area_thresh = max(1, int(patch_px * patch_px * max(0.1, min_component_multiplier)))
457
+ num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(safe_mask.astype(np.uint8), connectivity=8)
458
+ if num_labels > 1:
459
+ keep = np.zeros_like(labels, dtype=bool)
460
+ for i in range(1, num_labels):
461
+ if stats[i, cv2.CC_STAT_AREA] >= area_thresh:
462
+ keep |= labels == i
463
+ safe_mask = keep
464
+
465
+ # Recommended landing spot overlay (scaled to input image size)
466
+ # Prefer centers where the full footprint is safe; fall back to best flat spot
467
+ safe_fit = safe_mask.astype(np.float32)
468
+ try:
469
+ coverage = cv2.boxFilter(
470
+ safe_fit.astype(np.float32),
471
+ ddepth=-1,
472
+ ksize=(patch_px, patch_px),
473
+ normalize=True,
474
+ anchor=(patch_px // 2, patch_px // 2),
475
+ )
476
+ valid_centers = coverage >= 1.0
477
+ except Exception:
478
+ valid_centers = safe_fit > 0.5
479
+
480
+ if valid_centers.any():
481
+ cc_mask = valid_centers.astype(np.uint8)
482
+ num_c, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(cc_mask, connectivity=8)
483
+ target_mask = valid_centers
484
+ if num_c > 1:
485
+ # Pick largest safe component by area (skip background)
486
+ areas = stats_c[1:, cv2.CC_STAT_AREA]
487
+ largest_idx = 1 + int(np.argmax(areas))
488
+ target_mask = labels_c == largest_idx
489
+ cand = np.where(target_mask)
490
+ std_cand = std_map[cand]
491
+ idx = np.argmin(std_cand)
492
+ cy, cx = cand[0][idx], cand[1][idx]
493
+ else:
494
+ y0, x0, y1, x1 = box[1], box[0], box[3], box[2]
495
+ cy, cx = (y0 + y1) // 2, (x0 + x1) // 2
496
+
497
+ half = patch_px // 2
498
+ x0 = max(int(cx - half), 0)
499
+ x1 = min(int(cx + half), depth.shape[1] - 1)
500
+ y0 = max(int(cy - half), 0)
501
+ y1 = min(int(cy + half), depth.shape[0] - 1)
502
+
503
+ scale_x = image.width / depth.shape[1]
504
+ scale_y = image.height / depth.shape[0]
505
+ # Draw a box whose side length matches the footprint in input-image pixels
506
+ side_img = max(3, int(round(patch_px * scale_x)))
507
+ cx_img = int(round(cx * scale_x))
508
+ cy_img = int(round(cy * scale_y))
509
+ half_img = side_img // 2
510
+ bx0 = max(cx_img - half_img, 0)
511
+ bx1 = min(cx_img + half_img, image.width - 1)
512
+ by0 = max(cy_img - half_img, 0)
513
+ by1 = min(cy_img + half_img, image.height - 1)
514
+ spot_overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
515
+ draw = ImageDraw.Draw(spot_overlay)
516
+ draw.rectangle((bx0, by0, bx1, by1), outline=(0, 255, 0, 255), width=4)
517
+ cx, cy = (bx0 + bx1) // 2, (by0 + by1) // 2
518
+ draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill=(0, 255, 0, 255))
519
+
520
+ depth_vis = Image.fromarray(visualize_depth(depth_raw, cmap="Spectral")).resize(
521
+ image.size, resample=Image.BILINEAR
522
+ )
523
+ flatness_img = Image.fromarray((std_map_vis / (std_map_vis.max() + 1e-6) * 255).astype(np.uint8)).resize(
524
+ image.size, resample=Image.NEAREST
525
+ )
526
+ grad_img = Image.fromarray((grad_norm * 255).astype(np.uint8)).resize(
527
+ image.size, resample=Image.BILINEAR
528
+ )
529
+ grad_mask_img = Image.fromarray(((grad_norm < grad_thresh).astype(np.uint8) * 255)).resize(
530
+ image.size, resample=Image.NEAREST
531
+ )
532
+ water_mask_view = None
533
+ if use_water_mask and water_mask_img is not None:
534
+ water_mask_view = Image.fromarray((np.array(water_mask_img).astype(np.uint8) * 255))
535
+ water_mask_view = water_mask_view.resize(image.size, resample=Image.NEAREST)
536
+ road_mask_view = None
537
+ if use_road_mask and road_mask_img is not None:
538
+ road_mask_view = Image.fromarray((np.array(road_mask_img).astype(np.uint8) * 255))
539
+ road_mask_view = road_mask_view.resize(image.size, resample=Image.NEAREST)
540
+ roof_mask_view = None
541
+ if use_roof_mask and roof_mask_resized is not None:
542
+ roof_mask_view = Image.fromarray((roof_mask_resized.astype(np.uint8) * 255))
543
+ roof_mask_view = roof_mask_view.resize(image.size, resample=Image.NEAREST)
544
+
545
+ heat_overlay, heat_gray = make_safety_heatmap(image, safe_mask)
546
+
547
+ images = {
548
+ "RGB": image,
549
+ "Depth": depth_vis,
550
+ "Flatness map (std)": flatness_img,
551
+ "Depth gradient": grad_img,
552
+ "Gradient mask": grad_mask_img,
553
+ "Water mask": water_mask_view if water_mask_view is not None else Image.new("L", image.size, 0),
554
+ "Road mask": road_mask_view if road_mask_view is not None else Image.new("L", image.size, 0),
555
+ "Roof mask": roof_mask_view if roof_mask_view is not None else Image.new("L", image.size, 0),
556
+ "Safety heatmap overlay": heat_overlay,
557
+ "Safety score": heat_gray,
558
+ "Landing spot overlay": spot_overlay,
559
+ }
560
+ return images
561
+
562
+
563
def process_image(
    input_path: str,
    footprint_m: float,
    std_thresh: float,
    grad_thresh: float,
    use_water_mask: bool,
    use_road_mask: bool,
    use_roof_mask: bool,
    altitude_m: float,
    fov_deg: float,
    flatness_detail: float,
    clearance_factor: float,
    process_res_cap: int = 1024,
    roof_aggressiveness: float = 1.3,
    roof_morph_frac: float = 0.15,
    segmentation_max_side: int = 640,
    depth_smoothing_base: float = 0.8,
    coverage_strictness: float = 0.999,
    min_component_multiplier: float = 1.0,
    model_id: str = "depth-anything/DA3MONO-LARGE",
    source_path: str | None = None,
) -> dict:
    """Validate an image path, load/crop the image, and run the landing analysis.

    Bug fix: the advanced tuning parameters (``process_res_cap`` through
    ``min_component_multiplier``) and ``model_id`` now carry defaults that mirror
    the UI slider/dropdown defaults. The Gradio callbacks (``process_any`` and
    ``update_preview_ui`` in ``build_ui``) invoke this function WITHOUT those
    arguments, which previously raised ``TypeError`` because they were required.

    Args:
        input_path: Filesystem path to the image to analyze.
        footprint_m: Required clear landing area side length in meters.
        std_thresh: Flatness (local std) threshold for safe regions.
        grad_thresh: Depth-gradient threshold for safe regions.
        use_water_mask / use_road_mask / use_roof_mask: Hazard-mask toggles.
        altitude_m: Camera altitude used to convert meters to pixels.
        fov_deg: Horizontal field of view used for footprint sizing.
        flatness_detail: Relative window scale for the flatness visualization.
        clearance_factor: Hazard dilation relative to the footprint.
        process_res_cap: Upper bound on the longest side fed to the depth model.
        roof_aggressiveness: MAD multiplier for the depth-based rooftop mask.
        roof_morph_frac: Rooftop morphology kernel as a fraction of footprint px.
        segmentation_max_side: Longest side used for water/road segmentation.
        depth_smoothing_base: Base Gaussian sigma multiplier for depth smoothing.
        coverage_strictness: Minimum safe fraction of a footprint per center.
        min_component_multiplier: Minimum safe-component area in footprint^2 units.
        model_id: DepthAnything3 checkpoint identifier.
        source_path: Optional original path recorded for caching/bookkeeping.

    Returns:
        Dict of named PIL images produced by ``run_on_image``.

    Raises:
        gr.Error: If the path does not exist or has an unsupported extension.
    """
    path = Path(input_path)
    if not path.exists():
        raise gr.Error(f"Input path not found: {path}")
    if path.suffix.lower() not in IMAGE_EXTS:
        raise gr.Error(f"Unsupported image type for path: {path}")
    # Crop black letterbox borders before analysis so they don't skew depth stats.
    image = crop_nonblack(Image.open(path).convert("RGB"))
    return run_on_image(
        image=image,
        footprint_m=footprint_m,
        std_thresh=std_thresh,
        grad_thresh=grad_thresh,
        use_water_mask=use_water_mask,
        use_road_mask=use_road_mask,
        use_roof_mask=use_roof_mask,
        altitude_m=altitude_m,
        fov_deg=fov_deg,
        flatness_detail=flatness_detail,
        clearance_factor=clearance_factor,
        process_res_cap=process_res_cap,
        roof_aggressiveness=roof_aggressiveness,
        roof_morph_frac=roof_morph_frac,
        segmentation_max_side=segmentation_max_side,
        depth_smoothing_base=depth_smoothing_base,
        coverage_strictness=coverage_strictness,
        min_component_multiplier=min_component_multiplier,
        model_id=model_id,
        source_path=str(path),
    )
613
+
614
+
615
def compose_view(
    images_dict: dict,
    base_view: str,
    heat_on: bool,
    heat_alpha: float,
    grad_on: bool,
    grad_alpha: float,
    flat_on: bool,
    flat_alpha: float,
    water_on: bool,
    water_alpha: float,
    water_enabled: bool,
    spot_on: bool,
    road_on: bool,
    road_alpha: float,
    road_enabled: bool,
    roof_on: bool,
    roof_alpha: float,
    roof_enabled: bool,
) -> Image.Image:
    """Composite the chosen base view with every enabled overlay layer.

    Layers are blended in a fixed order (heatmap, gradient, flatness, water,
    road, roof, landing spot), so later layers draw on top of earlier ones.
    Each opacity is clamped to [0, 1] and scaled to an 8-bit alpha.
    """
    if not images_dict:
        raise gr.Error("Run inference first, then select a view.")
    if base_view not in images_dict:
        raise gr.Error(f"Unknown view: {base_view}")

    base = images_dict.get(base_view)
    if base is None:
        raise gr.Error(f"No image for view: {base_view}")

    def _alpha255(value: float) -> int:
        # Clamp to [0, 1] before scaling to 0..255.
        return int(min(max(value, 0.0), 1.0) * 255)

    canvas = base.convert("RGBA")

    # Uniform-alpha layers: the whole layer image is blended at one opacity.
    uniform_layers = (
        (heat_on, "Safety heatmap overlay", heat_alpha),
        (grad_on, "Depth gradient", grad_alpha),
        (flat_on, "Flatness map (std)", flat_alpha),
    )
    for enabled, key, alpha_val in uniform_layers:
        if not (enabled and key in images_dict):
            continue
        layer = images_dict.get(key)
        if layer is None:
            continue
        rgba = layer.convert("RGBA")
        rgba.putalpha(_alpha255(alpha_val))
        canvas = Image.alpha_composite(canvas, rgba)

    # Mask layers: a solid color tint whose per-pixel alpha follows the mask.
    mask_layers = (
        (water_on and water_enabled, "Water mask", (255, 0, 0, 0), water_alpha),
        (road_on and road_enabled, "Road mask", (255, 165, 0, 0), road_alpha),  # orange
        (roof_on and roof_enabled, "Roof mask", (255, 0, 255, 0), roof_alpha),  # magenta tint for roofs
    )
    for enabled, key, color, alpha_val in mask_layers:
        if not (enabled and key in images_dict):
            continue
        mask_img = images_dict.get(key)
        if mask_img is None:
            continue
        gray = mask_img.convert("L")
        tint = Image.new("RGBA", mask_img.size, color)
        alpha = _alpha255(alpha_val)
        tint.putalpha(Image.eval(gray, lambda px: int(px * (alpha / 255.0))))
        canvas = Image.alpha_composite(canvas, tint)

    # The landing-spot annotation already carries its own alpha; draw it last.
    if spot_on and "Landing spot overlay" in images_dict:
        spot = images_dict.get("Landing spot overlay")
        if spot is not None:
            canvas = Image.alpha_composite(canvas, spot.convert("RGBA"))

    return canvas.convert("RGB")
701
+
702
+
703
def build_ui():
    """Construct and return the Gradio Blocks app for the landing-site analyzer.

    Layout: a three-column row (inputs / preview / overlay controls) inside a
    single Blocks context, plus three callbacks:
      - ``process_any``: full inference + compose, bound to the Run button.
      - ``update_overlays_only``: recompose from cached layers when an overlay
        control changes (no model re-run).
      - ``update_preview_ui``: re-run inference and recompose when a
        model-affecting input changes.

    NOTE(review): the advanced sliders (process_res_cap, depth_smoothing_base,
    coverage_strictness, min_component_multiplier, segmentation_max_side,
    roof_aggressiveness, roof_morph_frac) are created but NOT listed in any
    event's ``inputs``, so their values never reach ``process_image`` — it is
    always called with those arguments omitted. Confirm whether they should be
    wired into ``run_btn.click`` / ``update_preview_ui``.
    """
    with gr.Blocks(title="Landing Site Safety Analyzer (VISLOC)") as demo:
        gr.Markdown(
            "## Landing Site Safety Analyzer\n"
            "Run DepthAnything3 on VISLOC images under `data/Image/VISLOC` to evaluate landing zones: depth, safety heatmap, gradients, flatness, and water masks. Toggle layers, footprint, and opacity to assess safety."
        )
        with gr.Row():
            # --- Left column: input selection and analysis parameters ---
            with gr.Column(scale=1, min_width=320):
                gr.Markdown("### Input")
                all_choices = list_all_data_inputs()
                input_path = gr.Dropdown(
                    label="Input file",
                    choices=all_choices,
                    value=all_choices[0] if all_choices else "",
                    info="Pick any VISLOC image under data/Image/VISLOC/.",
                )
                footprint_m = gr.Slider(
                    label="Landing footprint (meters)",
                    value=10,
                    minimum=1,
                    maximum=150,
                    step=1,
                    info="Side length (meters) of the clear area required for landing (assumes ~450m altitude, 90° FOV).",
                )
                std_thresh = gr.Slider(
                    label="Flatness threshold",
                    value=0.01,
                    minimum=0.001,
                    maximum=0.08,
                    step=0.001,
                    info="Lower values favor flatter regions when computing the heatmap.",
                )
                grad_thresh = gr.Slider(
                    label="Gradient threshold",
                    value=0.1,
                    minimum=0.02,
                    maximum=1.0,
                    step=0.01,
                    info="Lower values suppress sloped/edgy areas in the heatmap.",
                )
                flatness_detail = gr.Slider(
                    label="Flatness detail (relative)",
                    value=1.0,
                    minimum=0.5,
                    maximum=2.5,
                    step=0.1,
                    info="Scales the window for the flatness visualization; lower = finer detail.",
                )
                clearance_factor = gr.Slider(
                    label="Clearance factor",
                    value=0.5,
                    minimum=0.0,
                    maximum=2.0,
                    step=0.05,
                    info="How much to dilate unsafe regions relative to the footprint to enforce buffer distance.",
                )
                # Advanced tuning sliders — see NOTE(review) in the docstring:
                # these are currently not wired into any event inputs.
                process_res_cap = gr.Slider(
                    label="Processing resolution cap",
                    value=1024,
                    minimum=512,
                    maximum=2048,
                    step=32,
                    info="Upper bound on the longest side fed to the depth model; avoids oversized, noisy inference.",
                )
                depth_smoothing_base = gr.Slider(
                    label="Depth smoothing base",
                    value=0.8,
                    minimum=0.0,
                    maximum=2.0,
                    step=0.05,
                    info="Base Gaussian sigma multiplier for depth smoothing (scaled by resolution).",
                )
                coverage_strictness = gr.Slider(
                    label="Coverage strictness",
                    value=0.999,
                    minimum=0.8,
                    maximum=1.0,
                    step=0.001,
                    info="Minimum fraction of a footprint that must be safe to count a center as safe.",
                )
                min_component_multiplier = gr.Slider(
                    label="Min safe area (x footprint)",
                    value=1.0,
                    minimum=0.1,
                    maximum=5.0,
                    step=0.1,
                    info="Minimum safe component area in multiples of footprint^2.",
                )
                segmentation_max_side = gr.Slider(
                    label="Segmentation max side",
                    value=640,
                    minimum=256,
                    maximum=1024,
                    step=32,
                    info="Resize longest image side to this for water/road segmentation.",
                )
                with gr.Accordion("Camera settings", open=False):
                    altitude_m = gr.Slider(
                        label="Camera altitude (m)",
                        value=450,
                        minimum=10,
                        maximum=1500,
                        step=5,
                        info="Altitude used to convert footprint meters to pixels.",
                    )
                    fov_deg = gr.Slider(
                        label="Camera FOV (deg)",
                        value=90,
                        minimum=30,
                        maximum=150,
                        step=1,
                        info="Horizontal field of view used for footprint sizing.",
                    )
                model_id = gr.Dropdown(
                    label="Model",
                    value="depth-anything/DA3MONO-LARGE",
                    choices=[
                        "depth-anything/DA3MONO-LARGE",
                        "depth-anything/DA3METRIC-LARGE",
                        "depth-anything/DA3-BASE",
                        "depth-anything/DA3NESTED-GIANT-LARGE",
                    ],
                    info="Which pretrained DepthAnything3 checkpoint to use.",
                )
                with gr.Accordion("Masking", open=True):
                    with gr.Row():
                        use_water_mask = gr.Checkbox(
                            label="Exclude water (segmentation)", value=True, info="Apply water segmentation to down-weight water regions."
                        )
                        use_road_mask = gr.Checkbox(
                            label="Exclude roads (segmentation)", value=True, info="Apply road segmentation to avoid roads/highways."
                        )
                        use_roof_mask = gr.Checkbox(
                            label="Exclude rooftops (depth)", value=True, info="Use depth (closer-than-median) to avoid rooftops/raised structures."
                        )
                    roof_aggressiveness = gr.Slider(
                        label="Rooftop aggressiveness (MAD multiplier)",
                        value=1.3,
                        minimum=0.5,
                        maximum=3.0,
                        step=0.05,
                        info="Higher = more aggressive exclusion of raised areas in the depth-based rooftop mask.",
                    )
                    roof_morph_frac = gr.Slider(
                        label="Rooftop morph kernel (fraction of footprint px)",
                        value=0.15,
                        minimum=0.05,
                        maximum=0.5,
                        step=0.01,
                        info="Controls smoothing/merging of rooftop mask relative to footprint size.",
                    )
                with gr.Row():
                    run_btn = gr.Button("Run", variant="primary", scale=1)
                    stop_btn = gr.Button("Stop", variant="stop", scale=1)
                # Cache of the per-layer images from the last inference run,
                # so overlay toggles can recompose without re-running the model.
                images_state = gr.State({})
            # --- Middle column: the composited preview with click-to-zoom JS ---
            with gr.Column(scale=3):
                gr.Markdown("### Preview")
                main_view = gr.Image(
                    label="Preview",
                    height=800,
                    elem_id="main-preview",
                    show_fullscreen_button=False,
                )
                # Inline CSS/JS that adds a fullscreen zoom overlay on click.
                # Gradio replaces the <img> element on every update, so the
                # script polls and re-binds rather than binding once.
                gr.HTML(
                    """
                    <style>
                    #main-preview img,
                    #main-preview canvas { cursor: zoom-in; }
                    #main-preview-zoom-overlay {
                        position: fixed;
                        inset: 0;
                        z-index: 1000;
                        display: none;
                        align-items: center;
                        justify-content: center;
                        background: rgba(0, 0, 0, 0.85);
                    }
                    #main-preview-zoom-overlay img {
                        max-width: 95vw;
                        max-height: 95vh;
                        box-shadow: 0 0 24px rgba(0, 0, 0, 0.6);
                    }
                    </style>
                    <div id="main-preview-zoom-overlay"></div>
                    <script>
                    (() => {
                        const containerId = "main-preview";
                        const overlayId = "main-preview-zoom-overlay";

                        const ensureOverlay = () => {
                            let overlay = document.getElementById(overlayId);
                            if (!overlay) {
                                overlay = document.createElement("div");
                                overlay.id = overlayId;
                                document.body.appendChild(overlay);
                            }
                            overlay.onclick = () => {
                                overlay.style.display = "none";
                                overlay.innerHTML = "";
                            };
                            return overlay;
                        };

                        const getMedia = (container) => {
                            if (!container) return null;
                            const img = container.querySelector("img");
                            if (img) return { type: "img", el: img, getSrc: () => img.currentSrc || img.src };
                            const canvas = container.querySelector("canvas");
                            if (canvas) return { type: "canvas", el: canvas, getSrc: () => canvas.toDataURL("image/png") };
                            return null;
                        };

                        const bind = () => {
                            const container = document.getElementById(containerId);
                            if (!container || container.dataset.zoomBound) return;
                            container.dataset.zoomBound = "1";
                            container.addEventListener("click", (ev) => {
                                const media = getMedia(container);
                                if (!media) return;
                                const src = media.getSrc();
                                if (!src) return;
                                const overlay = ensureOverlay();
                                overlay.innerHTML = "";
                                const zoomed = document.createElement("img");
                                zoomed.src = src;
                                overlay.appendChild(zoomed);
                                overlay.style.display = "flex";
                                ev.stopPropagation();
                            });
                        };

                        // Poll because Gradio swaps the image element on updates.
                        const interval = setInterval(() => {
                            const media = getMedia(document.getElementById(containerId));
                            if (media && media.el && !media.el.dataset.cursorSet) {
                                media.el.dataset.cursorSet = "1";
                                media.el.style.cursor = "zoom-in";
                            }
                            bind();
                        }, 500);
                        window.addEventListener("beforeunload", () => clearInterval(interval));
                    })();
                    </script>
                    """,
                    elem_id="main-preview-zoom-helper",
                )
            # --- Right column: overlay visibility/opacity controls ---
            with gr.Column(scale=1, min_width=260):
                gr.Markdown("### Overlays")
                base_view = gr.Dropdown(
                    label="Base view",
                    value="RGB",
                    choices=[
                        "RGB",
                        "Depth",
                        "Flatness map (std)",
                        "Depth gradient",
                        "Gradient mask",
                        "Water mask",
                        "Safety score",
                        "Safety heatmap overlay",
                    ],
                )
                heat_on = gr.Checkbox(label="Heatmap", value=True, info="Show the safety heatmap overlay.")
                heat_alpha = gr.Slider(
                    label="Heatmap alpha", value=0.15, minimum=0.0, maximum=1.0, step=0.05, info="Heatmap opacity."
                )
                grad_on = gr.Checkbox(label="Depth gradient", value=False, info="Overlay the depth gradient magnitude.")
                grad_alpha = gr.Slider(
                    label="Gradient alpha", value=0.35, minimum=0.0, maximum=1.0, step=0.05, info="Gradient overlay opacity."
                )
                flat_on = gr.Checkbox(label="Flatness map", value=False, info="Overlay per-pixel flatness (std).")
                flat_alpha = gr.Slider(
                    label="Flatness alpha", value=0.25, minimum=0.0, maximum=1.0, step=0.05, info="Flatness overlay opacity."
                )
                spot_on = gr.Checkbox(label="Show landing spot", value=True, info="Overlay the recommended landing box.")
                with gr.Accordion("Mask overlays", open=True):
                    water_on = gr.Checkbox(label="Water mask overlay", value=False, info="Overlay detected water regions.")
                    water_alpha = gr.Slider(
                        label="Water mask alpha",
                        value=0.5,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        info="Water overlay opacity.",
                    )
                    road_on = gr.Checkbox(label="Road mask overlay", value=False, info="Overlay detected road regions.")
                    road_alpha = gr.Slider(
                        label="Road mask alpha",
                        value=0.5,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        info="Road overlay opacity.",
                    )
                    roof_on = gr.Checkbox(label="Roof mask overlay", value=False, info="Overlay detected roof regions.")
                    roof_alpha = gr.Slider(
                        label="Roof mask alpha",
                        value=0.5,
                        minimum=0.0,
                        maximum=1.0,
                        step=0.05,
                        info="Roof overlay opacity.",
                    )

        def process_any(
            input_path,
            footprint_m,
            std_thresh,
            grad_thresh,
            use_water_mask,
            use_road_mask,
            use_roof_mask,
            altitude_m,
            fov_deg,
            flatness_detail,
            clearance_factor,
            model_id,
            base_view,
            heat_on,
            heat_alpha,
            grad_on,
            grad_alpha,
            flat_on,
            flat_alpha,
            water_on,
            water_alpha,
            spot_on,
            road_on,
            road_alpha,
            roof_on,
            roof_alpha,
        ):
            # Generator handler for the Run button: runs inference on the
            # selected image, composes the current view, and yields both the
            # layer cache and the composited preview.
            if not input_path:
                raise gr.Error("Select an input image first.")
            path = Path(input_path)
            if not path.exists():
                raise gr.Error(f"Input not found: {path}")
            if path.suffix.lower() in IMAGE_EXTS:
                # NOTE(review): process_image is called without the advanced
                # tuning args (process_res_cap, roof_aggressiveness, ...), so
                # those must have defaults in its signature.
                imgs = process_image(
                    input_path=str(path),
                    footprint_m=footprint_m,
                    std_thresh=std_thresh,
                    grad_thresh=grad_thresh,
                    use_water_mask=use_water_mask,
                    use_road_mask=use_road_mask,
                    use_roof_mask=use_roof_mask,
                    altitude_m=altitude_m,
                    fov_deg=fov_deg,
                    flatness_detail=flatness_detail,
                    clearance_factor=clearance_factor,
                    model_id=model_id,
                    source_path=str(path),
                )
                composed = compose_view(
                    imgs,
                    base_view,
                    heat_on,
                    heat_alpha,
                    grad_on,
                    grad_alpha,
                    flat_on,
                    flat_alpha,
                    water_on,
                    water_alpha,
                    water_enabled=use_water_mask,
                    road_on=road_on,
                    road_alpha=road_alpha,
                    road_enabled=use_road_mask,
                    roof_on=roof_on,
                    roof_alpha=roof_alpha,
                    roof_enabled=use_roof_mask,
                    spot_on=spot_on,
                )
                yield imgs, composed
            else:
                raise gr.Error(f"Unsupported input type for path: {path} (images only)")

        # Keep a handle on the click event so the Stop button can cancel it.
        run_event = run_btn.click(
            fn=process_any,
            inputs=[
                input_path,
                footprint_m,
                std_thresh,
                grad_thresh,
                use_water_mask,
                use_road_mask,
                use_roof_mask,
                altitude_m,
                fov_deg,
                flatness_detail,
                clearance_factor,
                model_id,
                base_view,
                heat_on,
                heat_alpha,
                grad_on,
                grad_alpha,
                flat_on,
                flat_alpha,
                water_on,
                water_alpha,
                spot_on,
                road_on,
                road_alpha,
                roof_on,
                roof_alpha,
            ],
            outputs=[images_state, main_view],
        )
        stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[run_event])

        def update_preview_ui(
            images_state_val,
            input_path_val,
            footprint_m_val,
            std_thresh_val,
            grad_thresh_val,
            use_water_mask_val,
            use_road_mask_val,
            use_roof_mask_val,
            altitude_m_val,
            fov_deg_val,
            flatness_detail_val,
            clearance_factor_val,
            model_id_val,
            base_view_val,
            heat_on_val,
            heat_alpha_val,
            grad_on_val,
            grad_alpha_val,
            flat_on_val,
            flat_alpha_val,
            water_on_val,
            water_alpha_val,
            spot_on_val,
            road_on_val,
            road_alpha_val,
            roof_on_val,
            roof_alpha_val,
        ):
            # Re-run inference (best effort) when a model-affecting setting
            # changes, then recompose; falls back to the cached layers on error.
            path = Path(str(input_path_val))
            imgs_val = images_state_val
            # If current input is an image, re-run processing to reflect new settings
            if path.exists() and path.suffix.lower() in IMAGE_EXTS:
                try:
                    imgs_val = process_image(
                        input_path=str(path),
                        footprint_m=footprint_m_val,
                        std_thresh=std_thresh_val,
                        grad_thresh=grad_thresh_val,
                        use_water_mask=use_water_mask_val,
                        use_road_mask=use_road_mask_val,
                        use_roof_mask=use_roof_mask_val,
                        altitude_m=altitude_m_val,
                        fov_deg=fov_deg_val,
                        flatness_detail=flatness_detail_val,
                        clearance_factor=clearance_factor_val,
                        model_id=model_id_val,
                    )
                except Exception:
                    # Best-effort: keep the previous layers if inference fails.
                    imgs_val = images_state_val
            if not imgs_val:
                return images_state_val, gr.update()
            # Positional call matches compose_view's parameter order:
            # ..., water_alpha, water_enabled, spot_on, road_on, road_alpha,
            # road_enabled, roof_on, roof_alpha, roof_enabled.
            composed = compose_view(
                imgs_val,
                base_view_val,
                heat_on_val,
                heat_alpha_val,
                grad_on_val,
                grad_alpha_val,
                flat_on_val,
                flat_alpha_val,
                water_on_val,
                water_alpha_val,
                use_water_mask_val,
                spot_on_val,
                road_on_val,
                road_alpha_val,
                use_road_mask_val,
                roof_on_val,
                roof_alpha_val,
                use_roof_mask_val,
            )
            return imgs_val, composed

        # Input order here must match update_overlays_only's parameter order.
        overlay_inputs = [
            images_state,
            base_view,
            heat_on,
            heat_alpha,
            grad_on,
            grad_alpha,
            flat_on,
            flat_alpha,
            water_on,
            water_alpha,
            spot_on,
            use_water_mask,
            road_on,
            road_alpha,
            use_road_mask,
            roof_on,
            roof_alpha,
            use_roof_mask,
        ]

        def update_overlays_only(
            images_state_val,
            base_view_val,
            heat_on_val,
            heat_alpha_val,
            grad_on_val,
            grad_alpha_val,
            flat_on_val,
            flat_alpha_val,
            water_on_val,
            water_alpha_val,
            spot_on_val,
            use_water_mask_val,
            road_on_val,
            road_alpha_val,
            use_road_mask_val,
            roof_on_val,
            roof_alpha_val,
            use_roof_mask_val,
        ):
            # Cheap path: recompose cached layers without re-running the model.
            if not images_state_val:
                return images_state_val, gr.update()
            return images_state_val, compose_view(
                images_state_val,
                base_view_val,
                heat_on_val,
                heat_alpha_val,
                grad_on_val,
                grad_alpha_val,
                flat_on_val,
                flat_alpha_val,
                water_on_val,
                water_alpha_val,
                use_water_mask_val,
                spot_on_val,
                road_on_val,
                road_alpha_val,
                use_road_mask_val,
                roof_on_val,
                roof_alpha_val,
                use_roof_mask_val,
            )

        # Overlay-only controls trigger the cheap recompose path.
        base_view.change(fn=update_overlays_only, inputs=overlay_inputs, outputs=[images_state, main_view])
        for control in (
            heat_on,
            heat_alpha,
            grad_on,
            grad_alpha,
            flat_on,
            flat_alpha,
            water_on,
            water_alpha,
            spot_on,
            use_water_mask,
            road_on,
            road_alpha,
            use_road_mask,
            roof_on,
            roof_alpha,
            use_roof_mask,
        ):
            control.change(fn=update_overlays_only, inputs=overlay_inputs, outputs=[images_state, main_view])

        # Input order here must match update_preview_ui's parameter order.
        model_inputs = [
            images_state,
            input_path,
            footprint_m,
            std_thresh,
            grad_thresh,
            use_water_mask,
            use_road_mask,
            use_roof_mask,
            altitude_m,
            fov_deg,
            flatness_detail,
            clearance_factor,
            model_id,
            base_view,
            heat_on,
            heat_alpha,
            grad_on,
            grad_alpha,
            flat_on,
            flat_alpha,
            water_on,
            water_alpha,
            spot_on,
            road_on,
            road_alpha,
            roof_on,
            roof_alpha,
        ]
        # Model-affecting controls trigger a full re-run.
        for control in (
            input_path,
            footprint_m,
            std_thresh,
            grad_thresh,
            use_water_mask,
            use_road_mask,
            use_roof_mask,
            altitude_m,
            fov_deg,
            flatness_detail,
            clearance_factor,
            model_id,
        ):
            control.change(fn=update_preview_ui, inputs=model_inputs, outputs=[images_state, main_view])
    return demo
1317
+
1318
+
1319
if __name__ == "__main__":
    # Build the Gradio app and serve it with request queueing enabled so the
    # Stop button can cancel in-flight runs.
    app = build_ui()
    app.queue().launch()
src/depth_anything_3/__init__.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Depth Anything 3 Python package entrypoint.
16
+ """
17
+
18
+ from depth_anything_3.api import DepthAnything3
19
+
20
+ __all__ = ["DepthAnything3"]
21
+ __version__ = "0.1.0"
src/depth_anything_3/api.py ADDED
@@ -0,0 +1,416 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """
15
+ Depth Anything 3 API module.
16
+
17
+ This module provides the main API for Depth Anything 3, including model loading,
18
+ inference, and export capabilities. It supports both single and nested model architectures.
19
+ """
20
+
21
+ from __future__ import annotations
22
+
23
+ import time
24
+ from typing import Optional, Sequence
25
+ import numpy as np
26
+ import torch
27
+ import torch.nn as nn
28
+ from huggingface_hub import PyTorchModelHubMixin
29
+ from PIL import Image
30
+
31
+ from depth_anything_3.cfg import create_object, load_config
32
+ from depth_anything_3.registry import MODEL_REGISTRY
33
+ from depth_anything_3.specs import Prediction
34
+ from depth_anything_3.utils.export import export
35
+ from depth_anything_3.utils.geometry import affine_inverse
36
+ from depth_anything_3.utils.io.input_processor import InputProcessor
37
+ from depth_anything_3.utils.io.output_processor import OutputProcessor
38
+ from depth_anything_3.utils.logger import logger
39
+ from depth_anything_3.utils.pose_align import align_poses_umeyama
40
+
41
# Disable cuDNN autotuning — presumably because inference sees varying input
# shapes, where benchmark mode would repeatedly re-tune kernels; TODO confirm.
torch.backends.cudnn.benchmark = False
# logger.info("CUDNN Benchmark Disabled")

# Canonical file names used by the Hugging Face Hub checkpoint format.
SAFETENSORS_NAME = "model.safetensors"
CONFIG_NAME = "config.json"
46
+
47
+
48
class DepthAnything3(nn.Module, PyTorchModelHubMixin):
    """
    Depth Anything 3 main API class.

    This class provides a high-level interface for depth estimation using Depth Anything 3.
    It supports both single and nested model architectures with metric scaling capabilities.

    Features:
    - Hugging Face Hub integration via PyTorchModelHubMixin
    - Support for multiple model presets (vitb, vitg, nested variants)
    - Automatic mixed precision inference
    - Export capabilities for various formats (GLB, PLY, NPZ, etc.)
    - Camera pose estimation and metric depth scaling

    Usage:
        # Load from Hugging Face Hub
        model = DepthAnything3.from_pretrained("huggingface/model-name")

        # Or create with specific preset
        model = DepthAnything3(preset="vitg")

        # Run inference
        prediction = model.inference(images, export_dir="output", export_format="glb")
    """

    _commit_hash: str | None = None  # Set by the Hub mixin when loading from the Hub

    def __init__(self, model_name: str = "da3-large", **kwargs):
        """
        Initialize DepthAnything3 with specified preset.

        Args:
            model_name: The name of the model preset to use.
                Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
            **kwargs: Additional keyword arguments (accepted for mixin compatibility;
                currently unused).
        """
        super().__init__()
        self.model_name = model_name

        # Build the underlying network from its registered YAML config.
        self.config = load_config(MODEL_REGISTRY[self.model_name])
        self.model = create_object(self.config)
        self.model.eval()

        # Input/output pre- and post-processing helpers.
        self.input_processor = InputProcessor()
        self.output_processor = OutputProcessor()

        # Device cache; lazily resolved in _get_model_device() (may be set by user).
        self.device = None

    @torch.inference_mode()
    def forward(
        self,
        image: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
        export_feat_layers: list[int] | None = None,
        infer_gs: bool = False,
    ) -> dict[str, torch.Tensor]:
        """
        Forward pass through the model.

        Args:
            image: Input batch with shape ``(B, N, 3, H, W)`` on the model device.
            extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``.
            intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``.
            export_feat_layers: Layer indices to return intermediate features for.
            infer_gs: Whether to also run the 3D Gaussian branch.

        Returns:
            Dictionary containing model predictions
        """
        # Prefer bf16 where the GPU supports it; fall back to fp16 otherwise.
        autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
        # Fix: the nested `torch.no_grad()` was redundant — @torch.inference_mode()
        # already disables autograd (and is stricter).
        with torch.autocast(device_type=image.device.type, dtype=autocast_dtype):
            return self.model(image, extrinsics, intrinsics, export_feat_layers, infer_gs)

    def inference(
        self,
        image: list[np.ndarray | Image.Image | str],
        extrinsics: np.ndarray | None = None,
        intrinsics: np.ndarray | None = None,
        align_to_input_ext_scale: bool = True,
        infer_gs: bool = False,
        render_exts: np.ndarray | None = None,
        render_ixts: np.ndarray | None = None,
        render_hw: tuple[int, int] | None = None,
        process_res: int = 504,
        process_res_method: str = "upper_bound_resize",
        export_dir: str | None = None,
        export_format: str = "mini_npz",
        export_feat_layers: Sequence[int] | None = None,
        # GLB export parameters
        conf_thresh_percentile: float = 40.0,
        num_max_points: int = 1_000_000,
        show_cameras: bool = True,
        # Feat_vis export parameters
        feat_vis_fps: int = 15,
        # Other export parameters, e.g., gs_ply, gs_video
        export_kwargs: Optional[dict] = None,
    ) -> Prediction:
        """
        Run inference on input images.

        Args:
            image: List of input images (numpy arrays, PIL Images, or file paths)
            extrinsics: Camera extrinsics (N, 4, 4)
            intrinsics: Camera intrinsics (N, 3, 3)
            align_to_input_ext_scale: whether to align the input pose scale to the prediction
            infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports)
            render_exts: Optional render extrinsics for Gaussian video export
            render_ixts: Optional render intrinsics for Gaussian video export
            render_hw: Optional render resolution for Gaussian video export
            process_res: Processing resolution
            process_res_method: Resize method for processing
            export_dir: Directory to export results
            export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video)
            export_feat_layers: Layer indices to export intermediate features from
            conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0)  # noqa: E501
            num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000)
            show_cameras: [GLB] Show camera wireframes in the exported scene (default: True)
            feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15)
            export_kwargs: additional arguments to export functions.

        Returns:
            Prediction object containing depth maps and camera parameters
        """
        if "gs" in export_format:
            assert infer_gs, "must set `infer_gs=True` to perform gs-related export."

        # Preprocess images
        imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs(
            image, extrinsics, intrinsics, process_res, process_res_method
        )

        # Prepare tensors for model
        imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics)

        # Normalize extrinsics (clone so the caller-aligned copy stays untouched)
        ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None)

        # Run model forward pass
        export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else []

        raw_output = self._run_model_forward(imgs, ex_t_norm, in_t, export_feat_layers, infer_gs)

        # Convert raw output to prediction
        prediction = self._convert_to_prediction(raw_output)

        # Align prediction to input extrinsics/intrinsics
        prediction = self._align_to_input_extrinsics_intrinsics(
            extrinsics, intrinsics, prediction, align_to_input_ext_scale
        )

        # Add processed images for visualization
        prediction = self._add_processed_images(prediction, imgs_cpu)

        # Export if requested
        if export_dir is not None:
            export_kwargs = {} if export_kwargs is None else dict(export_kwargs)
            if "gs" in export_format:
                # Ensure a gs_video export accompanies any gs export when the
                # Gaussian branch was run.
                if infer_gs and "gs_video" not in export_format:
                    export_format = f"{export_format}-gs_video"
                if "gs_video" in export_format:
                    if "gs_video" not in export_kwargs:
                        export_kwargs["gs_video"] = {}
                    export_kwargs["gs_video"].update(
                        {
                            "extrinsics": render_exts,
                            "intrinsics": render_ixts,
                            "out_image_hw": render_hw,
                        }
                    )
            # Add GLB export parameters
            if "glb" in export_format:
                if "glb" not in export_kwargs:
                    export_kwargs["glb"] = {}
                export_kwargs["glb"].update(
                    {
                        "conf_thresh_percentile": conf_thresh_percentile,
                        "num_max_points": num_max_points,
                        "show_cameras": show_cameras,
                    }
                )
            # Add Feat_vis export parameters
            if "feat_vis" in export_format:
                if "feat_vis" not in export_kwargs:
                    export_kwargs["feat_vis"] = {}
                export_kwargs["feat_vis"].update(
                    {
                        "fps": feat_vis_fps,
                    }
                )
            self._export_results(prediction, export_format, export_dir, **export_kwargs)

        return prediction

    def _preprocess_inputs(
        self,
        image: list[np.ndarray | Image.Image | str],
        extrinsics: np.ndarray | None = None,
        intrinsics: np.ndarray | None = None,
        process_res: int = 504,
        process_res_method: str = "upper_bound_resize",
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Preprocess input images using input processor.

        Returns:
            Tuple of (processed image tensor, extrinsics, intrinsics); the camera
            tensors are None when not provided.
        """
        start_time = time.time()
        imgs_cpu, extrinsics, intrinsics = self.input_processor(
            image,
            # Copy so the processor cannot mutate caller-owned arrays.
            extrinsics.copy() if extrinsics is not None else None,
            intrinsics.copy() if intrinsics is not None else None,
            process_res,
            process_res_method,
        )
        end_time = time.time()
        # Fix: logger.info takes a single message string (optionally with %-style
        # args), not a print-style argument list — the original call passed extra
        # positionals against a message with no placeholders.
        logger.info(
            f"Processed Images Done taking {end_time - start_time} seconds. "
            f"Shape: {tuple(imgs_cpu.shape)}"
        )
        return imgs_cpu, extrinsics, intrinsics

    def _prepare_model_inputs(
        self,
        imgs_cpu: torch.Tensor,
        extrinsics: torch.Tensor | None,
        intrinsics: torch.Tensor | None,
    ) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
        """Prepare tensors for model input (adds the batch dim, moves to device).

        Fix: annotations used ``torch.tensor`` (the factory function) instead of
        the ``torch.Tensor`` type, and did not mark the return as a tuple.
        """
        device = self._get_model_device()

        # Move images to model device; [None] adds the batch dimension.
        imgs = imgs_cpu.to(device, non_blocking=True)[None].float()

        # Convert camera parameters to batched float tensors on the model device.
        ex_t = (
            extrinsics.to(device, non_blocking=True)[None].float()
            if extrinsics is not None
            else None
        )
        in_t = (
            intrinsics.to(device, non_blocking=True)[None].float()
            if intrinsics is not None
            else None
        )

        return imgs, ex_t, in_t

    def _normalize_extrinsics(self, ex_t: torch.Tensor | None) -> torch.Tensor | None:
        """Normalize extrinsics: first camera becomes the origin, translations are
        scaled so the median camera distance is 1 (clamped away from zero)."""
        if ex_t is None:
            return None
        # Re-express all poses relative to the first view.
        transform = affine_inverse(ex_t[:, :1])
        ex_t_norm = ex_t @ transform
        c2ws = affine_inverse(ex_t_norm)
        translations = c2ws[..., :3, 3]
        dists = translations.norm(dim=-1)
        median_dist = torch.median(dists)
        # Avoid division blow-up for near-coincident cameras.
        median_dist = torch.clamp(median_dist, min=1e-1)
        ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
        return ex_t_norm

    def _align_to_input_extrinsics_intrinsics(
        self,
        extrinsics: torch.Tensor | None,
        intrinsics: torch.Tensor | None,
        prediction: Prediction,
        align_to_input_ext_scale: bool = True,
        ransac_view_thresh: int = 10,
    ) -> Prediction:
        """Align the predicted poses/depth to the user-provided extrinsics.

        Uses a (RANSAC-robustified, for >= ransac_view_thresh views) Umeyama
        alignment between predicted and input poses.
        """
        if extrinsics is None or prediction.extrinsics is None:
            return prediction
        if intrinsics is not None:
            prediction.intrinsics = intrinsics.numpy()
        _, _, scale, aligned_extrinsics = align_poses_umeyama(
            prediction.extrinsics,
            extrinsics.numpy(),
            ransac=len(extrinsics) >= ransac_view_thresh,
            return_aligned=True,
            random_state=42,
        )
        if align_to_input_ext_scale:
            # Adopt the input poses verbatim and rescale depth to match them.
            prediction.extrinsics = extrinsics[..., :3, :].numpy()
            prediction.depth /= scale
        else:
            prediction.extrinsics = aligned_extrinsics
        return prediction

    def _run_model_forward(
        self,
        imgs: torch.Tensor,
        ex_t: torch.Tensor | None,
        in_t: torch.Tensor | None,
        export_feat_layers: Sequence[int] | None = None,
        infer_gs: bool = False,
    ) -> dict[str, torch.Tensor]:
        """Run model forward pass (with CUDA syncs so the timing is accurate)."""
        device = imgs.device
        need_sync = device.type == "cuda"
        if need_sync:
            torch.cuda.synchronize(device)
        start_time = time.time()
        feat_layers = list(export_feat_layers) if export_feat_layers is not None else None
        output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs)
        if need_sync:
            torch.cuda.synchronize(device)
        end_time = time.time()
        logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds")
        return output

    def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction:
        """Convert raw model output to Prediction object."""
        start_time = time.time()
        output = self.output_processor(raw_output)
        end_time = time.time()
        logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds")
        return output

    def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction:
        """Add processed images to prediction for visualization."""
        # Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize
        processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy()  # (N, H, W, 3)

        # Undo ImageNet normalization back to uint8 RGB.
        mean = np.array([0.485, 0.456, 0.406])
        std = np.array([0.229, 0.224, 0.225])
        processed_imgs = processed_imgs * std + mean
        processed_imgs = np.clip(processed_imgs, 0, 1)
        processed_imgs = (processed_imgs * 255).astype(np.uint8)

        prediction.processed_images = processed_imgs
        return prediction

    def _export_results(
        self, prediction: Prediction, export_format: str, export_dir: str, **kwargs
    ) -> None:
        """Export results to specified format and directory."""
        start_time = time.time()
        export(prediction, export_format, export_dir, **kwargs)
        end_time = time.time()
        logger.info(f"Export Results Done. Time: {end_time - start_time} seconds")

    def _get_model_device(self) -> torch.device:
        """
        Get the device where the model is located.

        The result is cached on ``self.device`` after the first lookup.

        Returns:
            Device where the model parameters (or buffers) are located

        Raises:
            ValueError: If no tensors are found in the model
        """
        if self.device is not None:
            return self.device

        # Find device from parameters
        for param in self.parameters():
            self.device = param.device
            return param.device

        # Find device from buffers
        for buffer in self.buffers():
            self.device = buffer.device
            return buffer.device

        raise ValueError("No tensor found in model")
src/depth_anything_3/app/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Package marker for app modules.
src/depth_anything_3/app/css_and_html.py ADDED
@@ -0,0 +1,594 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E501
2
+
3
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ CSS and HTML content for the Depth Anything 3 Gradio application.
19
+ This module contains all the CSS styles and HTML content blocks
20
+ used in the Gradio interface.
21
+ """
22
+
23
# CSS styles injected into the Gradio interface (passed to gr.Blocks(css=...)).
# Covers: Font Awesome import, colored-icon helpers, link buttons, the
# dark/light "tech" theme, animated keyword highlights, and layout tweaks
# for navigation rows and clickable thumbnails.
GRADIO_CSS = """
/* Add Font Awesome CDN with all styles including brands and colors */
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css');

/* Add custom styles for colored icons */
.fa-color-blue {
    color: #3b82f6;
}

.fa-color-purple {
    color: #8b5cf6;
}

.fa-color-cyan {
    color: #06b6d4;
}

.fa-color-green {
    color: #10b981;
}

.fa-color-yellow {
    color: #f59e0b;
}

.fa-color-red {
    color: #ef4444;
}

.link-btn {
    display: inline-flex;
    align-items: center;
    gap: 8px;
    text-decoration: none;
    padding: 12px 24px;
    border-radius: 50px;
    font-weight: 500;
    transition: all 0.3s ease;
}

/* Dark mode tech theme */
@media (prefers-color-scheme: dark) {
    html, body {
        background: #1e293b;
        color: #ffffff;
    }

    .gradio-container {
        background: #1e293b;
        color: #ffffff;
    }

    .link-btn {
        background: rgba(255, 255, 255, 0.2);
        color: white;
        backdrop-filter: blur(10px);
        border: 1px solid rgba(255, 255, 255, 0.3);
    }

    .link-btn:hover {
        background: rgba(255, 255, 255, 0.3);
        transform: translateY(-2px);
        box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
    }

    .tech-bg {
        background: linear-gradient(135deg, #0f172a, #1e293b); /* Darker colors */
        position: relative;
        overflow: hidden;
    }

    .tech-bg::before {
        content: '';
        position: absolute;
        top: 0;
        left: 0;
        right: 0;
        bottom: 0;
        background:
            radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
            radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
            radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.1) 0%, transparent 50%); /* Reduced opacity */
        animation: techPulse 8s ease-in-out infinite;
    }

    .gradio-container .panel,
    .gradio-container .block,
    .gradio-container .form {
        background: rgba(0, 0, 0, 0.3);
        border: 1px solid rgba(59, 130, 246, 0.2);
        border-radius: 10px;
    }

    .gradio-container * {
        color: #ffffff;
    }

    .gradio-container label {
        color: #e0e0e0;
    }

    .gradio-container .markdown {
        color: #e0e0e0;
    }
}

/* Light mode tech theme */
@media (prefers-color-scheme: light) {
    html, body {
        background: #ffffff;
        color: #1e293b;
    }

    .gradio-container {
        background: #ffffff;
        color: #1e293b;
    }

    .tech-bg {
        background: linear-gradient(135deg, #ffffff, #f1f5f9);
        position: relative;
        overflow: hidden;
    }

    .link-btn {
        background: rgba(59, 130, 246, 0.15);
        color: var(--body-text-color);
        border: 1px solid rgba(59, 130, 246, 0.3);
    }

    .link-btn:hover {
        background: rgba(59, 130, 246, 0.25);
        transform: translateY(-2px);
        box-shadow: 0 8px 25px rgba(59, 130, 246, 0.2);
    }

    .tech-bg::before {
        content: '';
        position: absolute;
        top: 0;
        left: 0;
        right: 0;
        bottom: 0;
        background:
            radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.1) 0%, transparent 50%),
            radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.1) 0%, transparent 50%),
            radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.08) 0%, transparent 50%);
        animation: techPulse 8s ease-in-out infinite;
    }

    .gradio-container .panel,
    .gradio-container .block,
    .gradio-container .form {
        background: rgba(255, 255, 255, 0.8);
        border: 1px solid rgba(59, 130, 246, 0.3);
        border-radius: 10px;
        box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
    }

    .gradio-container * {
        color: #1e293b;
    }

    .gradio-container label {
        color: #334155;
    }

    .gradio-container .markdown {
        color: #334155;
    }
}




@keyframes techPulse {
    0%, 100% { opacity: 0.5; }
    50% { opacity: 0.8; }
}

/* Custom log with tech gradient */
.custom-log * {
    font-style: italic;
    font-size: 22px !important;
    background: linear-gradient(135deg, #3b82f6, #8b5cf6);
    background-size: 400% 400%;
    -webkit-background-clip: text;
    background-clip: text;
    font-weight: bold !important;
    color: transparent !important;
    text-align: center !important;
    animation: techGradient 3s ease infinite;
}

@keyframes techGradient {
    0% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
    100% { background-position: 0% 50%; }
}

@keyframes metricPulse {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}

@keyframes pointcloudPulse {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}

@keyframes camerasPulse {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}

@keyframes gaussiansPulse {
    0%, 100% { background-position: 0% 50%; }
    50% { background-position: 100% 50%; }
}

/* Special colors for key terms - Global styles */
.metric-text {
    background: linear-gradient(45deg, #ff6b6b, #ff8e53, #ff6b6b);
    background-size: 200% 200%;
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
    animation: metricPulse 2s ease-in-out infinite;
    font-weight: 700;
    text-shadow: 0 0 10px rgba(255, 107, 107, 0.5);
}

.pointcloud-text {
    background: linear-gradient(45deg, #4ecdc4, #44a08d, #4ecdc4);
    background-size: 200% 200%;
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
    animation: pointcloudPulse 2.5s ease-in-out infinite;
    font-weight: 700;
    text-shadow: 0 0 10px rgba(78, 205, 196, 0.5);
}

.cameras-text {
    background: linear-gradient(45deg, #667eea, #764ba2, #667eea);
    background-size: 200% 200%;
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
    animation: camerasPulse 3s ease-in-out infinite;
    font-weight: 700;
    text-shadow: 0 0 10px rgba(102, 126, 234, 0.5);
}

.gaussians-text {
    background: linear-gradient(45deg, #f093fb, #f5576c, #f093fb);
    background-size: 200% 200%;
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
    animation: gaussiansPulse 2.2s ease-in-out infinite;
    font-weight: 700;
    text-shadow: 0 0 10px rgba(240, 147, 251, 0.5);
}

.example-log * {
    font-style: italic;
    font-size: 16px !important;
    background: linear-gradient(135deg, #3b82f6, #8b5cf6);
    -webkit-background-clip: text;
    background-clip: text;
    color: transparent !important;
}

#my_radio .wrap {
    display: flex;
    flex-wrap: nowrap;
    justify-content: center;
    align-items: center;
}

#my_radio .wrap label {
    display: flex;
    width: 50%;
    justify-content: center;
    align-items: center;
    margin: 0;
    padding: 10px 0;
    box-sizing: border-box;
}

/* Align navigation buttons with dropdown bottom */
.navigation-row {
    display: flex !important;
    align-items: flex-end !important;
    gap: 8px !important;
}

.navigation-row > div:nth-child(1),
.navigation-row > div:nth-child(3) {
    align-self: flex-end !important;
}

.navigation-row > div:nth-child(2) {
    flex: 1 !important;
}

/* Make thumbnails clickable with pointer cursor */
.clickable-thumbnail img {
    cursor: pointer !important;
}

.clickable-thumbnail:hover img {
    cursor: pointer !important;
    opacity: 0.8;
    transition: opacity 0.3s ease;
}

/* Make thumbnail containers narrower horizontally */
.clickable-thumbnail {
    padding: 5px 2px !important;
    margin: 0 2px !important;
}

.clickable-thumbnail .image-container {
    margin: 0 !important;
    padding: 0 !important;
}

.scene-info {
    text-align: center !important;
    padding: 5px 2px !important;
    margin: 0 !important;
}
"""
359
+
360
+
361
def get_header_html(logo_base64=None):
    """
    Generate the main header HTML with logo and title.

    Args:
        logo_base64 (str, optional): Base64 encoded logo image.
            NOTE(review): currently unused — the returned HTML never embeds it.
            Confirm whether the logo should be interpolated or the parameter dropped.

    Returns:
        str: HTML string for the header
    """
    # NOTE(review): the arXiv link below (2406.09414) appears to be the
    # Depth Anything V2 paper ID, not the Depth Anything 3 paper — verify.
    return """
    <div class="tech-bg" style="text-align: center; margin-bottom: 5px; padding: 40px 20px; border-radius: 15px; position: relative; overflow: hidden;">
        <div style="position: relative; z-index: 2;">
            <h1 style="margin: 0; font-size: 3.5em; font-weight: 700;
                       background: linear-gradient(135deg, #3b82f6, #8b5cf6);
                       background-size: 400% 400%;
                       -webkit-background-clip: text;
                       background-clip: text;
                       color: transparent;
                       animation: techGradient 3s ease infinite;
                       text-shadow: 0 0 30px rgba(59, 130, 246, 0.5);
                       letter-spacing: 2px;">
                Depth Anything 3
            </h1>
            <p style="margin: 15px 0 0 0; font-size: 2.16em; font-weight: 300;" class="header-subtitle">
                Recovering the Visual Space from Any Views
            </p>
            <div style="margin-top: 20px;">
                <!-- Revert buttons to original inline styles -->
                <a href="https://depth-anything-3.github.io" target="_blank" class="link-btn">
                    <i class="fas fa-globe" style="margin-right: 8px;"></i> Project Page
                </a>
                <a href="https://arxiv.org/abs/2406.09414" target="_blank" class="link-btn">
                    <i class="fas fa-file-pdf" style="margin-right: 8px;"></i> Paper
                </a>
                <a href="https://github.com/ByteDance-Seed/Depth-Anything-3" target="_blank" class="link-btn">
                    <i class="fab fa-github" style="margin-right: 8px;"></i> Code
                </a>
            </div>
        </div>
    </div>

    <style>
    /* Ensure tech-bg class is properly applied in dark mode */
    @media (prefers-color-scheme: dark) {
        .header-subtitle {
            color: #cbd5e1;
        }
        /* Increase priority to ensure background color is properly applied */
        .tech-bg {
            background: linear-gradient(135deg, #0f172a, #1e293b) !important;
        }
    }

    @media (prefers-color-scheme: light) {
        .header-subtitle {
            color: #475569;
        }
        /* Also add explicit background color for light mode */
        .tech-bg {
            background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%) !important;
        }
    }
    </style>
    """
426
+
427
+
428
+ def get_description_html():
429
+ """
430
+ Generate the main description and getting started HTML.
431
+
432
+ Returns:
433
+ str: HTML string for the description
434
+ """
435
+ return """
436
+ <div class="description-container" style="padding: 25px; border-radius: 15px; margin: 0 0 20px 0;">
437
+ <h2 class="description-title" style="margin-top: 0; font-size: 1.6em; text-align: center;">
438
+ <i class="fas fa-bullseye fa-color-red" style="margin-right: 8px;"></i> What This Demo Does
439
+ </h2>
440
+ <div class="description-content" style="padding: 20px; border-radius: 10px; margin: 15px 0; text-align: center;">
441
+ <p class="description-main" style="line-height: 1.6; margin: 0; font-size: 1.45em;">
442
+ <strong>Upload images or videos</strong> → <strong>Get <span class="metric-text">Metric</span> <span class="pointcloud-text">Point Clouds</span>, <span class="cameras-text">Cameras</span> and <span class="gaussians-text">Novel Views</span></strong> → <strong>Explore in 3D</strong>
443
+ </p>
444
+ </div>
445
+
446
+ <div style="text-align: center; margin-top: 15px;">
447
+ <p class="description-tip" style="font-style: italic; margin: 0;">
448
+ <i class="fas fa-lightbulb fa-color-yellow" style="margin-right: 8px;"></i> <strong>Tip:</strong> Landscape-oriented images or videos are preferred for best 3D recovering.
449
+ </p>
450
+ </div>
451
+ </div>
452
+
453
+ <style>
454
+ @media (prefers-color-scheme: dark) {
455
+ .description-container {
456
+ background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
457
+ border: 1px solid rgba(59, 130, 246, 0.2);
458
+ }
459
+ .description-title { color: #3b82f6; }
460
+ .description-content { background: rgba(0, 0, 0, 0.3); }
461
+ .description-main { color: #e0e0e0; }
462
+ .description-text { color: #cbd5e1; }
463
+ .description-tip { color: #cbd5e1; }
464
+ }
465
+
466
+ @media (prefers-color-scheme: light) {
467
+ .description-container {
468
+ background: linear-gradient(135deg, rgba(59, 130, 246, 0.05) 0%, rgba(139, 92, 246, 0.05) 100%);
469
+ border: 1px solid rgba(59, 130, 246, 0.3);
470
+ }
471
+ .description-title { color: #3b82f6; }
472
+ .description-content { background: transparent; }
473
+ .description-main { color: #1e293b; }
474
+ .description-text { color: #475569; }
475
+ .description-tip { color: #475569; }
476
+ }
477
+ </style>
478
+ """
479
+
480
+
481
+ def get_acknowledgements_html():
482
+ """
483
+ Generate the acknowledgements section HTML.
484
+
485
+ Returns:
486
+ str: HTML string for the acknowledgements
487
+ """
488
+ return """
489
+ <div style="background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
490
+ padding: 25px; border-radius: 15px; margin: 20px 0; border: 1px solid rgba(59, 130, 246, 0.2);">
491
+ <h3 style="color: #3b82f6; margin-top: 0; text-align: center; font-size: 1.4em;">
492
+ <i class="fas fa-trophy fa-color-yellow" style="margin-right: 8px;"></i> Research Credits & Acknowledgments
493
+ </h3>
494
+
495
+ <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 15px 0;">
496
+ <!-- Original Research Section (Left) -->
497
+ <div style="text-align: center;">
498
+ <h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-flask fa-color-green" style="margin-right: 8px;"></i> Original Research</h4>
499
+ <p style="color: #e0e0e0; margin: 5px 0;">
500
+ <a href="https://depth-anything-3.github.io" target="_blank"
501
+ style="color: #3b82f6; text-decoration: none; font-weight: 600;">
502
+ Depth Anything 3
503
+ </a>
504
+ </p>
505
+ </div>
506
+
507
+ <!-- Previous Versions Section (Right) -->
508
+ <div style="text-align: center;">
509
+ <h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-history fa-color-blue" style="margin-right: 8px;"></i> Previous Versions</h4>
510
+ <div style="display: flex; flex-direction: row; gap: 15px; justify-content: center; align-items: center;">
511
+ <p style="color: #e0e0e0; margin: 0;">
512
+ <a href="https://huggingface.co/spaces/LiheYoung/Depth-Anything" target="_blank"
513
+ style="color: #3b82f6; text-decoration: none; font-weight: 600;">
514
+ Depth-Anything
515
+ </a>
516
+ </p>
517
+ <span style="color: #e0e0e0;">•</span>
518
+ <p style="color: #e0e0e0; margin: 0;">
519
+ <a href="https://huggingface.co/spaces/depth-anything/Depth-Anything-V2" target="_blank"
520
+ style="color: #3b82f6; text-decoration: none; font-weight: 600;">
521
+ Depth-Anything-V2
522
+ </a>
523
+ </p>
524
+ </div>
525
+ </div>
526
+ </div>
527
+
528
+ <!-- HF Demo Adapted from - Centered at the bottom of the whole block -->
529
+ <div style="margin-top: 20px; padding-top: 15px; border-top: 1px solid rgba(59, 130, 246, 0.3); text-align: center;">
530
+ <p style="color: #a0a0a0; font-size: 0.9em; margin: 0;">
531
+ <i class="fas fa-code-branch fa-color-gray" style="margin-right: 5px;"></i> HF demo adapted from <a href="https://huggingface.co/spaces/facebook/map-anything" target="_blank" style="color: inherit; text-decoration: none;">Map Anything</a>
532
+ </p>
533
+ </div>
534
+ </div>
535
+ """
536
+
537
+
538
+ def get_gradio_theme():
539
+ """
540
+ Get the configured Gradio theme with adaptive tech colors.
541
+
542
+ Returns:
543
+ gr.themes.Base: Configured Gradio theme
544
+ """
545
+ import gradio as gr
546
+
547
+ return gr.themes.Base(
548
+ primary_hue=gr.themes.Color(
549
+ c50="#eff6ff",
550
+ c100="#dbeafe",
551
+ c200="#bfdbfe",
552
+ c300="#93c5fd",
553
+ c400="#60a5fa",
554
+ c500="#3b82f6",
555
+ c600="#2563eb",
556
+ c700="#1d4ed8",
557
+ c800="#1e40af",
558
+ c900="#1e3a8a",
559
+ c950="#172554",
560
+ ),
561
+ secondary_hue=gr.themes.Color(
562
+ c50="#f5f3ff",
563
+ c100="#ede9fe",
564
+ c200="#ddd6fe",
565
+ c300="#c4b5fd",
566
+ c400="#a78bfa",
567
+ c500="#8b5cf6",
568
+ c600="#7c3aed",
569
+ c700="#6d28d9",
570
+ c800="#5b21b6",
571
+ c900="#4c1d95",
572
+ c950="#2e1065",
573
+ ),
574
+ neutral_hue=gr.themes.Color(
575
+ c50="#f8fafc",
576
+ c100="#f1f5f9",
577
+ c200="#e2e8f0",
578
+ c300="#cbd5e1",
579
+ c400="#94a3b8",
580
+ c500="#64748b",
581
+ c600="#475569",
582
+ c700="#334155",
583
+ c800="#1e293b",
584
+ c900="#0f172a",
585
+ c950="#020617",
586
+ ),
587
+ )
588
+
589
+
590
+ # Measure tab instructions HTML
591
+ MEASURE_INSTRUCTIONS_HTML = """
592
+ ### Click points on the image to compute distance.
593
+ > <i class="fas fa-triangle-exclamation fa-color-red" style="margin-right: 5px;"></i> Metric scale estimation is difficult on aerial/drone images.
594
+ """
src/depth_anything_3/app/gradio_app.py ADDED
@@ -0,0 +1,747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Refactored Gradio App for Depth Anything 3.
17
+
18
+ This is the main application file that orchestrates all components.
19
+ The original functionality has been split into modular components for better maintainability.
20
+ """
21
+
22
+ import argparse
23
+ import os
24
+ from typing import Any, Dict, List
25
+ import gradio as gr
26
+
27
+ from depth_anything_3.app.css_and_html import GRADIO_CSS, get_gradio_theme
28
+ from depth_anything_3.app.modules.event_handlers import EventHandlers
29
+ from depth_anything_3.app.modules.ui_components import UIComponents
30
+
31
# Set environment variables
# expandable_segments lets PyTorch's CUDA caching allocator grow segments
# instead of failing on fragmentation during large multi-image batches.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
33
+
34
+
35
+ class DepthAnything3App:
36
+ """
37
+ Main application class for Depth Anything 3 Gradio app.
38
+ """
39
+
40
+ def __init__(self, model_dir: str = None, workspace_dir: str = None, gallery_dir: str = None):
41
+ """
42
+ Initialize the application.
43
+
44
+ Args:
45
+ model_dir: Path to the model directory
46
+ workspace_dir: Path to the workspace directory
47
+ gallery_dir: Path to the gallery directory
48
+ """
49
+ self.model_dir = model_dir
50
+ self.workspace_dir = workspace_dir
51
+ self.gallery_dir = gallery_dir
52
+
53
+ # Set environment variables for directories
54
+ if self.model_dir:
55
+ os.environ["DA3_MODEL_DIR"] = self.model_dir
56
+ if self.workspace_dir:
57
+ os.environ["DA3_WORKSPACE_DIR"] = self.workspace_dir
58
+ if self.gallery_dir:
59
+ os.environ["DA3_GALLERY_DIR"] = self.gallery_dir
60
+
61
+ self.event_handlers = EventHandlers()
62
+ self.ui_components = UIComponents()
63
+
64
    def cache_examples(
        self,
        show_cam: bool = True,
        filter_black_bg: bool = False,
        filter_white_bg: bool = False,
        save_percentage: float = 20.0,
        num_max_points: int = 1000,
        cache_gs_tag: str = "",
        gs_trj_mode: str = "smooth",
        gs_video_quality: str = "low",
    ) -> None:
        """
        Pre-cache all example scenes at startup.

        Runs load + reconstruction for every scene under
        ``<workspace_dir>/examples`` so first user clicks are fast.
        Progress and per-scene success/failure are printed to stdout.

        Args:
            show_cam: Whether to show camera in visualization
            filter_black_bg: Whether to filter black background
            filter_white_bg: Whether to filter white background
            save_percentage: Filter percentage for point cloud
            num_max_points: Maximum number of points
            cache_gs_tag: Tag to match scene names for high-res+3DGS caching (e.g., "dl3dv")
            gs_trj_mode: Trajectory mode for 3DGS
            gs_video_quality: Video quality for 3DGS
        """
        # Lazy import keeps module import light when caching is not used.
        from depth_anything_3.app.modules.utils import get_scene_info

        examples_dir = os.path.join(self.workspace_dir, "examples")
        if not os.path.exists(examples_dir):
            print(f"Examples directory not found: {examples_dir}")
            return

        scenes = get_scene_info(examples_dir)
        if not scenes:
            print("No example scenes found to cache.")
            return

        print(f"\n{'='*60}")
        print(f"Caching {len(scenes)} example scenes...")
        print(f"{'='*60}\n")

        for i, scene in enumerate(scenes, 1):
            scene_name = scene["name"]

            # Check if scene name matches the gs tag for high-res+3DGS caching.
            # NOTE(review): with cache_gs_tag == "" this evaluates to "" (falsy
            # str), which is later passed where a bool is expected — works via
            # truthiness, but worth confirming downstream tolerates it.
            use_high_res_gs = cache_gs_tag and cache_gs_tag.lower() in scene_name.lower()

            if use_high_res_gs:
                print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (HIGH-RES + 3DGS)")
                print(f" - Number of images: {scene['num_images']}")
                print(f" - Matched tag: '{cache_gs_tag}' - using high_res + 3DGS")
            else:
                print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (LOW-RES)")
                print(f" - Number of images: {scene['num_images']}")

            try:
                # Load example scene; only the target directory is needed here.
                _, target_dir, _, _, _, _, _, _, _ = self.event_handlers.load_example_scene(
                    scene_name
                )

                if target_dir and target_dir != "None":
                    # Run reconstruction with appropriate settings
                    print(" - Running reconstruction...")
                    result = self.event_handlers.gradio_demo(
                        target_dir=target_dir,
                        show_cam=show_cam,
                        filter_black_bg=filter_black_bg,
                        filter_white_bg=filter_white_bg,
                        process_res_method="high_res" if use_high_res_gs else "low_res",
                        selected_first_frame="",
                        save_percentage=save_percentage,
                        num_max_points=num_max_points,
                        infer_gs=use_high_res_gs,
                        gs_trj_mode=gs_trj_mode,
                        gs_video_quality=gs_video_quality,
                    )

                    # result[0] is the reconstruction output; result[1] the log.
                    if result[0] is not None:  # reconstruction_output
                        print(f" ✓ Scene '{scene_name}' cached successfully")
                    else:
                        print(f" ✗ Scene '{scene_name}' caching failed: {result[1]}")
                else:
                    print(f" ✗ Scene '{scene_name}' loading failed")

            except Exception as e:
                # Best-effort: a failed scene must not abort caching the rest.
                print(f" ✗ Error caching scene '{scene_name}': {str(e)}")

            print()

        print("=" * 60)
        print("Example scene caching completed!")
        print("=" * 60 + "\n")
157
+
158
    def create_app(self) -> gr.Blocks:
        """
        Create and configure the Gradio application.

        Builds the full Blocks UI (upload panel, result tabs, inference and
        display controls, example-scene grid) and wires all event handlers.

        Returns:
            Configured Gradio Blocks interface
        """

        # Initialize theme
        def get_theme():
            return get_gradio_theme()

        # NOTE(review): layout nesting below was reconstructed from a
        # formatting-mangled source — confirm against the upstream file.
        with gr.Blocks(theme=get_theme(), css=GRADIO_CSS) as demo:
            # State variables for the tabbed interface. is_example and
            # target_dir_output are hidden Textboxes so their string values can
            # participate in event input/output lists.
            is_example = gr.Textbox(label="is_example", visible=False, value="None")
            processed_data_state = gr.State(value=None)
            measure_points_state = gr.State(value=[])
            selected_first_frame_state = gr.State(value="")
            selected_image_index_state = gr.State(value=0)  # Track selected image index
            # current_view_index = gr.State(value=0)  # noqa: F841 Track current view index

            # Header and description
            self.ui_components.create_header_section()
            self.ui_components.create_description_section()

            target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

            # Main content area
            with gr.Row():
                with gr.Column(scale=2):
                    # Upload section
                    (
                        input_video,
                        s_time_interval,
                        input_images,
                        image_gallery,
                        select_first_frame_btn,
                    ) = self.ui_components.create_upload_section()

                with gr.Column(scale=4):
                    with gr.Column():
                        # gr.Markdown("**Metric 3D Reconstruction (Point Cloud and Camera Poses)**")
                        # Reconstruction control section (buttons) - moved below tabs

                        log_output = gr.Markdown(
                            "Please upload a video or images, then click Reconstruct.",
                            elem_classes=["custom-log"],
                        )

                        # Tabbed interface
                        with gr.Tabs():
                            with gr.Tab("Point Cloud & Cameras"):
                                reconstruction_output = (
                                    self.ui_components.create_3d_viewer_section()
                                )

                            with gr.Tab("Metric Depth"):
                                (
                                    prev_measure_btn,
                                    measure_view_selector,
                                    next_measure_btn,
                                    measure_image,
                                    measure_depth_image,
                                    measure_text,
                                ) = self.ui_components.create_measure_section()

                            with gr.Tab("3DGS Rendered Novel Views"):
                                gs_video, gs_info = self.ui_components.create_nvs_video()

                        # Inference control section (before inference)
                        (process_res_method_dropdown, infer_gs) = (
                            self.ui_components.create_inference_control_section()
                        )

                        # Display control section - includes 3DGS options, buttons, and Visualization Options  # noqa: E501
                        (
                            show_cam,
                            filter_black_bg,
                            filter_white_bg,
                            save_percentage,
                            num_max_points,
                            gs_trj_mode,
                            gs_video_quality,
                            submit_btn,
                            clear_btn,
                        ) = self.ui_components.create_display_control_section()

            # bind visibility of gs_trj_mode to infer_gs: the three 3DGS
            # controls show when enabled, the info placeholder when disabled.
            infer_gs.change(
                fn=lambda checked: (
                    gr.update(visible=checked),
                    gr.update(visible=checked),
                    gr.update(visible=checked),
                    gr.update(visible=(not checked)),
                ),
                inputs=infer_gs,
                outputs=[gs_trj_mode, gs_video_quality, gs_video, gs_info],
            )

            # Example scenes section
            gr.Markdown("## Example Scenes")

            scenes = self.ui_components.create_example_scenes_section()
            scene_components = self.ui_components.create_example_scene_grid(scenes)

            # Set up event handlers
            self._setup_event_handlers(
                demo,
                is_example,
                processed_data_state,
                measure_points_state,
                target_dir_output,
                input_video,
                input_images,
                s_time_interval,
                image_gallery,
                reconstruction_output,
                log_output,
                show_cam,
                filter_black_bg,
                filter_white_bg,
                process_res_method_dropdown,
                save_percentage,
                submit_btn,
                clear_btn,
                num_max_points,
                infer_gs,
                select_first_frame_btn,
                selected_first_frame_state,
                selected_image_index_state,
                measure_view_selector,
                measure_image,
                measure_depth_image,
                measure_text,
                prev_measure_btn,
                next_measure_btn,
                scenes,
                scene_components,
                gs_video,
                gs_info,
                gs_trj_mode,
                gs_video_quality,
            )

            # Acknowledgements
            self.ui_components.create_acknowledgements_section()

        return demo
306
+
307
    def _setup_event_handlers(
        self,
        demo: gr.Blocks,
        is_example: gr.Textbox,
        processed_data_state: gr.State,
        measure_points_state: gr.State,
        target_dir_output: gr.Textbox,
        input_video: gr.Video,
        input_images: gr.File,
        s_time_interval: gr.Slider,
        image_gallery: gr.Gallery,
        reconstruction_output: gr.Model3D,
        log_output: gr.Markdown,
        show_cam: gr.Checkbox,
        filter_black_bg: gr.Checkbox,
        filter_white_bg: gr.Checkbox,
        process_res_method_dropdown: gr.Dropdown,
        save_percentage: gr.Slider,
        submit_btn: gr.Button,
        clear_btn: gr.ClearButton,
        num_max_points: gr.Slider,
        infer_gs: gr.Checkbox,
        select_first_frame_btn: gr.Button,
        selected_first_frame_state: gr.State,
        selected_image_index_state: gr.State,
        measure_view_selector: gr.Dropdown,
        measure_image: gr.Image,
        measure_depth_image: gr.Image,
        measure_text: gr.Markdown,
        prev_measure_btn: gr.Button,
        next_measure_btn: gr.Button,
        scenes: List[Dict[str, Any]],
        scene_components: List[gr.Image],
        gs_video: gr.Video,
        gs_info: gr.Markdown,
        gs_trj_mode: gr.Dropdown,
        gs_video_quality: gr.Dropdown,
    ) -> None:
        """
        Set up all event handlers for the application.

        Args:
            demo: Gradio Blocks interface (currently unused in the body,
                kept for interface stability)
            All other arguments: Gradio components to connect
        """
        # Configure clear button: these components are reset when clicked.
        clear_btn.add(
            [
                input_video,
                input_images,
                reconstruction_output,
                log_output,
                target_dir_output,
                image_gallery,
                gs_video,
            ]
        )

        # Main reconstruction button chain: clear viewer -> show interim log
        # -> run reconstruction -> finally mark the result as non-example.
        submit_btn.click(
            fn=self.event_handlers.clear_fields, inputs=[], outputs=[reconstruction_output]
        ).then(fn=self.event_handlers.update_log, inputs=[], outputs=[log_output]).then(
            fn=self.event_handlers.gradio_demo,
            inputs=[
                target_dir_output,
                show_cam,
                filter_black_bg,
                filter_white_bg,
                process_res_method_dropdown,
                selected_first_frame_state,
                save_percentage,
                # pass num_max_points
                num_max_points,
                infer_gs,
                gs_trj_mode,
                gs_video_quality,
            ],
            outputs=[
                reconstruction_output,
                log_output,
                processed_data_state,
                measure_image,
                measure_depth_image,
                measure_text,
                measure_view_selector,
                gs_video,
                gs_video,  # gs_video visibility
                gs_info,  # gs_info visibility
            ],
        ).then(
            fn=lambda: "False",
            inputs=[],
            outputs=[is_example],  # set is_example to "False"
        )

        # Real-time visualization updates
        self._setup_visualization_handlers(
            show_cam,
            filter_black_bg,
            filter_white_bg,
            process_res_method_dropdown,
            target_dir_output,
            is_example,
            reconstruction_output,
            log_output,
        )

        # File upload handlers: both video and image uploads funnel into the
        # same handler, which extracts/collects frames into target_dir.
        input_video.change(
            fn=self.event_handlers.handle_uploads,
            inputs=[input_video, input_images, s_time_interval],
            outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
        )
        input_images.change(
            fn=self.event_handlers.handle_uploads,
            inputs=[input_video, input_images, s_time_interval],
            outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
        )

        # Image gallery click handler (for selecting first frame)
        def handle_image_selection(evt: gr.SelectData):
            if evt is None or evt.index is None:
                return "No image selected", 0
            selected_index = evt.index
            return f"Selected image {selected_index} as potential first frame", selected_index

        image_gallery.select(
            fn=handle_image_selection,
            outputs=[log_output, selected_image_index_state],
        )

        # Select first frame handler
        select_first_frame_btn.click(
            fn=self.event_handlers.select_first_frame,
            inputs=[image_gallery, selected_image_index_state],
            outputs=[image_gallery, log_output, selected_first_frame_state],
        )

        # Navigation handlers for the measure tab (prev/next/dropdown).
        self._setup_navigation_handlers(
            prev_measure_btn,
            next_measure_btn,
            measure_view_selector,
            measure_image,
            measure_depth_image,
            measure_points_state,
            processed_data_state,
        )

        # Measurement handler: clicking the image adds a point and, with two
        # points, reports a distance via the measure() callback.
        measure_image.select(
            fn=self.event_handlers.measure,
            inputs=[processed_data_state, measure_points_state, measure_view_selector],
            outputs=[measure_image, measure_depth_image, measure_points_state, measure_text],
        )

        # Example scene handlers
        self._setup_example_scene_handlers(
            scenes,
            scene_components,
            reconstruction_output,
            target_dir_output,
            image_gallery,
            log_output,
            is_example,
            processed_data_state,
            measure_view_selector,
            measure_image,
            measure_depth_image,
            gs_video,
            gs_info,
        )
479
+
480
+ def _setup_visualization_handlers(
481
+ self,
482
+ show_cam: gr.Checkbox,
483
+ filter_black_bg: gr.Checkbox,
484
+ filter_white_bg: gr.Checkbox,
485
+ process_res_method_dropdown: gr.Dropdown,
486
+ target_dir_output: gr.Textbox,
487
+ is_example: gr.Textbox,
488
+ reconstruction_output: gr.Model3D,
489
+ log_output: gr.Markdown,
490
+ ) -> None:
491
+ """Set up visualization update handlers."""
492
+ # Common inputs for visualization updates
493
+ viz_inputs = [
494
+ target_dir_output,
495
+ show_cam,
496
+ is_example,
497
+ filter_black_bg,
498
+ filter_white_bg,
499
+ process_res_method_dropdown,
500
+ ]
501
+
502
+ # Set up change handlers for all visualization controls
503
+ for component in [show_cam, filter_black_bg, filter_white_bg]:
504
+ component.change(
505
+ fn=self.event_handlers.update_visualization,
506
+ inputs=viz_inputs,
507
+ outputs=[reconstruction_output, log_output],
508
+ )
509
+
510
+ def _setup_navigation_handlers(
511
+ self,
512
+ prev_measure_btn: gr.Button,
513
+ next_measure_btn: gr.Button,
514
+ measure_view_selector: gr.Dropdown,
515
+ measure_image: gr.Image,
516
+ measure_depth_image: gr.Image,
517
+ measure_points_state: gr.State,
518
+ processed_data_state: gr.State,
519
+ ) -> None:
520
+ """Set up navigation handlers for measure tab."""
521
+ # Measure tab navigation
522
+ prev_measure_btn.click(
523
+ fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
524
+ processed_data, current_selector, -1
525
+ ),
526
+ inputs=[processed_data_state, measure_view_selector],
527
+ outputs=[
528
+ measure_view_selector,
529
+ measure_image,
530
+ measure_depth_image,
531
+ measure_points_state,
532
+ ],
533
+ )
534
+
535
+ next_measure_btn.click(
536
+ fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
537
+ processed_data, current_selector, 1
538
+ ),
539
+ inputs=[processed_data_state, measure_view_selector],
540
+ outputs=[
541
+ measure_view_selector,
542
+ measure_image,
543
+ measure_depth_image,
544
+ measure_points_state,
545
+ ],
546
+ )
547
+
548
+ measure_view_selector.change(
549
+ fn=lambda processed_data, selector_value: (
550
+ self.event_handlers.update_measure_view(
551
+ processed_data, int(selector_value.split()[1]) - 1
552
+ )
553
+ if selector_value
554
+ else (None, None, [])
555
+ ),
556
+ inputs=[processed_data_state, measure_view_selector],
557
+ outputs=[measure_image, measure_depth_image, measure_points_state],
558
+ )
559
+
560
+ def _setup_example_scene_handlers(
561
+ self,
562
+ scenes: List[Dict[str, Any]],
563
+ scene_components: List[gr.Image],
564
+ reconstruction_output: gr.Model3D,
565
+ target_dir_output: gr.Textbox,
566
+ image_gallery: gr.Gallery,
567
+ log_output: gr.Markdown,
568
+ is_example: gr.Textbox,
569
+ processed_data_state: gr.State,
570
+ measure_view_selector: gr.Dropdown,
571
+ measure_image: gr.Image,
572
+ measure_depth_image: gr.Image,
573
+ gs_video: gr.Video,
574
+ gs_info: gr.Markdown,
575
+ ) -> None:
576
+ """Set up example scene handlers."""
577
+
578
+ def load_and_update_measure(name):
579
+ result = self.event_handlers.load_example_scene(name)
580
+ # result = (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
581
+
582
+ # Update measure view if processed_data is available
583
+ measure_img = None
584
+ measure_depth = None
585
+ if result[4] is not None: # processed_data exists
586
+ measure_img, measure_depth, _ = (
587
+ self.event_handlers.visualization_handler.update_measure_view(result[4], 0)
588
+ )
589
+
590
+ return result + ("True", measure_img, measure_depth)
591
+
592
+ for i, scene in enumerate(scenes):
593
+ if i < len(scene_components):
594
+ scene_components[i].select(
595
+ fn=lambda name=scene["name"]: load_and_update_measure(name),
596
+ outputs=[
597
+ reconstruction_output,
598
+ target_dir_output,
599
+ image_gallery,
600
+ log_output,
601
+ processed_data_state,
602
+ measure_view_selector,
603
+ gs_video,
604
+ gs_video, # gs_video_visibility
605
+ gs_info, # gs_info_visibility
606
+ is_example,
607
+ measure_image,
608
+ measure_depth_image,
609
+ ],
610
+ )
611
+
612
+ def launch(self, host: str = "127.0.0.1", port: int = 7860, **kwargs) -> None:
613
+ """
614
+ Launch the application.
615
+
616
+ Args:
617
+ host: Host address to bind to
618
+ port: Port number to bind to
619
+ **kwargs: Additional arguments for demo.launch()
620
+ """
621
+ demo = self.create_app()
622
+ demo.queue(max_size=20).launch(
623
+ show_error=True, ssr_mode=False, server_name=host, server_port=port, **kwargs
624
+ )
625
+
626
+
627
def main():
    """Main function to run the application.

    Parses CLI options, creates the workspace/gallery directories,
    optionally pre-caches example scenes, then launches the Gradio app.
    """
    parser = argparse.ArgumentParser(
        description="Depth Anything 3 Gradio Application",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage
  python gradio_app.py --help
  python gradio_app.py --host 0.0.0.0 --port 8080
  python gradio_app.py --model-dir /path/to/model --workspace-dir /path/to/workspace

  # Cache examples at startup (all low-res)
  python gradio_app.py --cache-examples

  # Cache with selective high-res+3DGS for scenes matching tag
  python gradio_app.py --cache-examples --cache-gs-tag dl3dv
  # This will use high-res + 3DGS for scenes containing "dl3dv" in their name,
  # and low-res only for other scenes
""",
    )

    # Server configuration
    parser.add_argument(
        "--host", default="127.0.0.1", help="Host address to bind to (default: 127.0.0.1)"
    )
    parser.add_argument(
        "--port", type=int, default=7860, help="Port number to bind to (default: 7860)"
    )

    # Directory configuration
    parser.add_argument(
        "--model-dir",
        default="depth-anything/DA3NESTED-GIANT-LARGE",
        help="Path to the model directory (default: depth-anything/DA3NESTED-GIANT-LARGE)",
    )
    parser.add_argument(
        "--workspace-dir",
        default="workspace/gradio",  # noqa: E501
        help="Path to the workspace directory (default: workspace/gradio)",  # noqa: E501
    )
    parser.add_argument(
        "--gallery-dir",
        default="workspace/gallery",
        help="Path to the gallery directory (default: workspace/gallery)",  # noqa: E501
    )

    # Additional Gradio options
    parser.add_argument("--share", action="store_true", help="Create a public link for the app")
    parser.add_argument("--debug", action="store_true", help="Enable debug mode")

    # Example caching options
    parser.add_argument(
        "--cache-examples",
        action="store_true",
        help="Pre-cache all example scenes at startup for faster loading",
    )
    parser.add_argument(
        "--cache-gs-tag",
        type=str,
        default="",
        help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.",  # noqa: E501
    )

    args = parser.parse_args()

    # Create directories if they don't exist
    os.makedirs(args.workspace_dir, exist_ok=True)
    os.makedirs(args.gallery_dir, exist_ok=True)

    # Initialize and launch the application
    app = DepthAnything3App(
        model_dir=args.model_dir, workspace_dir=args.workspace_dir, gallery_dir=args.gallery_dir
    )

    # Prepare launch arguments
    launch_kwargs = {"share": args.share, "debug": args.debug}

    # Echo effective configuration for easier debugging of deployments.
    print("Starting Depth Anything 3 Gradio App...")
    print(f"Host: {args.host}")
    print(f"Port: {args.port}")
    print(f"Model Directory: {args.model_dir}")
    print(f"Workspace Directory: {args.workspace_dir}")
    print(f"Gallery Directory: {args.gallery_dir}")
    print(f"Share: {args.share}")
    print(f"Debug: {args.debug}")
    print(f"Cache Examples: {args.cache_examples}")
    if args.cache_examples:
        if args.cache_gs_tag:
            print(
                f"Cache GS Tag: '{args.cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)"  # noqa: E501
            )  # noqa: E501
        else:
            print("Cache GS Tag: None (all scenes will use low-res only)")

    # Pre-cache examples if requested (runs full reconstruction per scene,
    # so startup can take a while).
    if args.cache_examples:
        print("\n" + "=" * 60)
        print("Pre-caching mode enabled")
        if args.cache_gs_tag:
            print(f"Scenes containing '{args.cache_gs_tag}' will use HIGH-RES + 3DGS")
            print("Other scenes will use LOW-RES only")
        else:
            print("All scenes will use LOW-RES only")
        print("=" * 60)
        app.cache_examples(
            show_cam=True,
            filter_black_bg=False,
            filter_white_bg=False,
            save_percentage=5.0,
            num_max_points=1000,
            cache_gs_tag=args.cache_gs_tag,
            gs_trj_mode="smooth",
            gs_video_quality="low",
        )

    app.launch(host=args.host, port=args.port, **launch_kwargs)
744
+
745
+
746
+ if __name__ == "__main__":
747
+ main()
src/depth_anything_3/app/modules/__init__.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Modules package for Depth Anything 3 Gradio app.
17
+
18
+ This package contains all the modular components for the Gradio application.
19
+ """
20
+
21
+ from depth_anything_3.app.modules.event_handlers import EventHandlers
22
+ from depth_anything_3.app.modules.file_handlers import FileHandler
23
+ from depth_anything_3.app.modules.model_inference import ModelInference
24
+ from depth_anything_3.app.modules.ui_components import UIComponents
25
+ from depth_anything_3.app.modules.utils import (
26
+ cleanup_memory,
27
+ create_depth_visualization,
28
+ get_logo_base64,
29
+ get_scene_info,
30
+ save_to_gallery_func,
31
+ )
32
+ from depth_anything_3.app.modules.visualization import VisualizationHandler
33
+
34
+ __all__ = [
35
+ "ModelInference",
36
+ "FileHandler",
37
+ "VisualizationHandler",
38
+ "EventHandlers",
39
+ "UIComponents",
40
+ "create_depth_visualization",
41
+ "save_to_gallery_func",
42
+ "get_scene_info",
43
+ "cleanup_memory",
44
+ "get_logo_base64",
45
+ ]
src/depth_anything_3/app/modules/event_handlers.py ADDED
@@ -0,0 +1,629 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Event handling module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles all event callbacks and user interactions.
19
+ """
20
+
21
+ import os
22
+ import time
23
+ from glob import glob
24
+ from typing import Any, Dict, List, Optional, Tuple
25
+ import gradio as gr
26
+ import numpy as np
27
+ import torch
28
+
29
+ from depth_anything_3.app.modules.file_handlers import FileHandler
30
+ from depth_anything_3.app.modules.model_inference import ModelInference
31
+ from depth_anything_3.app.modules.utils import cleanup_memory
32
+ from depth_anything_3.app.modules.visualization import VisualizationHandler
33
+
34
+
35
+ class EventHandlers:
36
+ """
37
+ Handles all event callbacks and user interactions for the Gradio app.
38
+ """
39
+
40
+ def __init__(self):
41
+ """Initialize the event handlers."""
42
+ self.model_inference = ModelInference()
43
+ self.file_handler = FileHandler()
44
+ self.visualization_handler = VisualizationHandler()
45
+
46
    def clear_fields(self) -> None:
        """
        Return None so the Gradio output bound to this callback is reset.

        In the submit chain this is wired with outputs=[reconstruction_output],
        so the single None clears the 3D viewer before a new reconstruction.
        """
        return None
51
+
52
+ def update_log(self) -> str:
53
+ """
54
+ Display a quick log message while waiting.
55
+ """
56
+ return "Loading and Reconstructing..."
57
+
58
    def save_current_visualization(
        self,
        target_dir: str,
        save_percentage: float,
        show_cam: bool,
        filter_black_bg: bool,
        filter_white_bg: bool,
        processed_data: Optional[Dict],
        scene_name: str = "",
    ) -> str:
        """
        Save current visualization results to gallery with specified save percentage.

        Args:
            target_dir: Directory containing results
            save_percentage: Percentage of points to save (0-100)
            show_cam: Whether to show cameras (echoed in the status message)
            filter_black_bg: Whether to filter black background
            filter_white_bg: Whether to filter white background
            processed_data: Processed data from reconstruction
            scene_name: Optional user-supplied prefix for the gallery entry

        Returns:
            Human-readable status message for the UI (success or error text).
        """
        # Guard: require both a results directory and processed data.
        if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
            return "No reconstruction available. Please run 'Reconstruct' first."

        if processed_data is None:
            return "No processed data available. Please run 'Reconstruct' first."

        try:
            # Add debug information
            print("[DEBUG] save_current_visualization called with:")
            print(f" target_dir: {target_dir}")
            print(f" save_percentage: {save_percentage}")
            print(f" show_cam: {show_cam}")
            print(f" filter_black_bg: {filter_black_bg}")
            print(f" filter_white_bg: {filter_white_bg}")
            print(f" processed_data: {processed_data is not None}")

            # Import the gallery save function
            # Create gallery name with user input or auto-generated
            import datetime

            from .utils import save_to_gallery_func

            # Gallery entries are named <scene|save>_<timestamp>_pct<NN> so
            # they sort chronologically and record the save percentage.
            timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
            if scene_name and scene_name.strip():
                gallery_name = f"{scene_name.strip()}_{timestamp}_pct{save_percentage:.0f}"
            else:
                gallery_name = f"save_{timestamp}_pct{save_percentage:.0f}"

            print(f"[DEBUG] Saving to gallery with name: {gallery_name}")

            # Save entire process folder to gallery
            success, message = save_to_gallery_func(
                target_dir=target_dir, processed_data=processed_data, gallery_name=gallery_name
            )

            if success:
                print(f"[DEBUG] Gallery save completed successfully: {message}")
                return (
                    "Successfully saved to gallery!\n"
                    f"Gallery name: {gallery_name}\n"
                    f"Save percentage: {save_percentage}%\n"
                    f"Show cameras: {show_cam}\n"
                    f"Filter black bg: {filter_black_bg}\n"
                    f"Filter white bg: {filter_white_bg}\n\n"
                    f"{message}"
                )
            else:
                print(f"[DEBUG] Gallery save failed: {message}")
                return f"Failed to save to gallery: {message}"

        except Exception as e:
            # UI boundary: surface any failure as a status string rather
            # than letting the exception propagate into Gradio.
            return f"Error saving visualization: {str(e)}"
134
+
135
+ def gradio_demo(
136
+ self,
137
+ target_dir: str,
138
+ show_cam: bool = True,
139
+ filter_black_bg: bool = False,
140
+ filter_white_bg: bool = False,
141
+ process_res_method: str = "upper_bound_resize",
142
+ selected_first_frame: str = "",
143
+ save_percentage: float = 30.0,
144
+ num_max_points: int = 1_000_000,
145
+ infer_gs: bool = False,
146
+ gs_trj_mode: str = "extend",
147
+ gs_video_quality: str = "high",
148
+ ) -> Tuple[
149
+ Optional[str],
150
+ str,
151
+ Optional[Dict],
152
+ Optional[np.ndarray],
153
+ Optional[np.ndarray],
154
+ str,
155
+ gr.Dropdown,
156
+ Optional[str], # gs video path
157
+ gr.update, # gs video visibility update
158
+ gr.update, # gs info visibility update
159
+ ]:
160
+ """
161
+ Perform reconstruction using the already-created target_dir/images.
162
+
163
+ Args:
164
+ target_dir: Directory containing images
165
+ show_cam: Whether to show camera
166
+ filter_black_bg: Whether to filter black background
167
+ filter_white_bg: Whether to filter white background
168
+ process_res_method: Method for resizing input images
169
+ selected_first_frame: Selected first frame filename
170
+ infer_gs: Whether to infer 3D Gaussian Splatting
171
+
172
+ Returns:
173
+ Tuple of reconstruction results
174
+ """
175
+ if not os.path.isdir(target_dir) or target_dir == "None":
176
+ return (
177
+ None,
178
+ "No valid target directory found. Please upload first.",
179
+ None,
180
+ None,
181
+ None,
182
+ "",
183
+ None,
184
+ None,
185
+ gr.update(visible=False), # gs_video
186
+ gr.update(visible=True), # gs_info
187
+ )
188
+
189
+ start_time = time.time()
190
+ cleanup_memory()
191
+
192
+ # Get image files for logging
193
+ target_dir_images = os.path.join(target_dir, "images")
194
+ all_files = (
195
+ sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
196
+ )
197
+
198
+ print("Running DepthAnything3 model...")
199
+ print(f"Selected first frame: {selected_first_frame}")
200
+
201
+ # Validate selected_first_frame against current image list
202
+ if selected_first_frame and target_dir_images:
203
+ current_files = (
204
+ sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
205
+ )
206
+ if selected_first_frame not in current_files:
207
+ print(
208
+ f"Selected first frame '{selected_first_frame}' not found in "
209
+ "current images. Using default order."
210
+ )
211
+ selected_first_frame = "" # Reset to use default order
212
+
213
+ with torch.no_grad():
214
+ prediction, processed_data = self.model_inference.run_inference(
215
+ target_dir,
216
+ process_res_method=process_res_method,
217
+ show_camera=show_cam,
218
+ selected_first_frame=selected_first_frame,
219
+ save_percentage=save_percentage,
220
+ num_max_points=int(num_max_points * 1000), # Convert K to actual count
221
+ infer_gs=infer_gs,
222
+ gs_trj_mode=gs_trj_mode,
223
+ gs_video_quality=gs_video_quality,
224
+ )
225
+
226
+ # The GLB file is already generated by the API
227
+ glbfile = os.path.join(target_dir, "scene.glb")
228
+
229
+ # Handle 3DGS video based on infer_gs flag
230
+ gsvideo_path = None
231
+ gs_video_visible = False
232
+ gs_info_visible = True
233
+
234
+ if infer_gs:
235
+ try:
236
+ gsvideo_path = sorted(glob(os.path.join(target_dir, "gs_video", "*.mp4")))[-1]
237
+ gs_video_visible = True
238
+ gs_info_visible = False
239
+ except IndexError:
240
+ gsvideo_path = None
241
+ print("3DGS video not found, but infer_gs was enabled")
242
+
243
+ # Cleanup
244
+ cleanup_memory()
245
+
246
+ end_time = time.time()
247
+ print(f"Total time: {end_time - start_time:.2f} seconds")
248
+ log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
249
+
250
+ # Populate visualization tabs with processed data
251
+ depth_vis, measure_img, measure_depth_vis, measure_pts = (
252
+ self.visualization_handler.populate_visualization_tabs(processed_data)
253
+ )
254
+
255
+ # Update view selectors based on available views
256
+ depth_selector, measure_selector = self.visualization_handler.update_view_selectors(
257
+ processed_data
258
+ )
259
+
260
+ return (
261
+ glbfile,
262
+ log_msg,
263
+ processed_data,
264
+ measure_img, # measure_image
265
+ measure_depth_vis, # measure_depth_image
266
+ "", # measure_text (empty initially)
267
+ measure_selector, # measure_view_selector
268
+ gsvideo_path,
269
+ gr.update(visible=gs_video_visible), # gs_video visibility
270
+ gr.update(visible=gs_info_visible), # gs_info visibility
271
+ )
272
+
273
    def update_visualization(
        self,
        target_dir: str,
        show_cam: bool,
        is_example: str,
        filter_black_bg: bool = False,
        filter_white_bg: bool = False,
        process_res_method: str = "upper_bound_resize",
    ) -> Tuple[gr.update, str]:
        """
        Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
        and return it for the 3D viewer.

        Args:
            target_dir: Directory containing results
            show_cam: Whether to show camera
            is_example: Whether this is an example scene (string "True"/"False")
            filter_black_bg: Whether to filter black background
            filter_white_bg: Whether to filter white background
            process_res_method: Method for resizing input images

        Returns:
            Tuple of (glb_file, log_message)

        NOTE(review): show_cam / filter flags / process_res_method are accepted
        but never used below — the GLB is only reused from disk, never
        regenerated with the new parameters. Confirm whether that is intended.
        """
        if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
            return (
                gr.update(),
                "No reconstruction available. Please click the Reconstruct button first.",
            )

        # Check if GLB exists (could be cached example or reconstructed scene)
        glbfile = os.path.join(target_dir, "scene.glb")
        if os.path.exists(glbfile):
            return (
                glbfile,
                (
                    "Visualization loaded from cache."
                    if is_example == "True"
                    else "Visualization updated."
                ),
            )

        # If no GLB but it's an example that hasn't been reconstructed yet
        if is_example == "True":
            return (
                gr.update(),
                "No reconstruction available. Please click the Reconstruct button first.",
            )

        # For non-examples, check predictions.npz
        predictions_path = os.path.join(target_dir, "predictions.npz")
        if not os.path.exists(predictions_path):
            error_message = (
                f"No reconstruction available at {predictions_path}. "
                "Please run 'Reconstruct' first."
            )
            return gr.update(), error_message

        # Load verifies the npz is readable; the dict itself is unused (F841).
        loaded = np.load(predictions_path, allow_pickle=True)
        predictions = {key: loaded[key] for key in loaded.keys()}  # noqa: F841

        # NOTE(review): this branch is only reached when os.path.exists(glbfile)
        # was False above, so the viewer receives a path to a file that does not
        # exist on disk. Looks like the GLB should be regenerated here — verify.
        return (
            glbfile,
            "Visualization updated.",
        )
338
+
339
+ def handle_uploads(
340
+ self,
341
+ input_video: Optional[str],
342
+ input_images: Optional[List],
343
+ s_time_interval: float = 10.0,
344
+ ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
345
+ """
346
+ Handle file uploads and update gallery.
347
+
348
+ Args:
349
+ input_video: Path to input video file
350
+ input_images: List of input image files
351
+ s_time_interval: Sampling FPS (frames per second) for frame extraction
352
+
353
+ Returns:
354
+ Tuple of (reconstruction_output, target_dir, image_paths, log_message)
355
+ """
356
+ return self.file_handler.update_gallery_on_upload(
357
+ input_video, input_images, s_time_interval
358
+ )
359
+
360
    def load_example_scene(self, scene_name: str, examples_dir: str = None) -> Tuple[
        Optional[str],
        Optional[str],
        Optional[List],
        str,
        Optional[Dict],
        gr.Dropdown,
        Optional[str],
        gr.update,
        gr.update,
    ]:
        """
        Load a scene from examples directory.

        Args:
            scene_name: Name of the scene to load
            examples_dir: Path to examples directory (if None, uses workspace_dir/examples)

        Returns:
            Tuple of (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis)  # noqa: E501
        """
        if examples_dir is None:
            # Get workspace directory from environment variable
            workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
            examples_dir = os.path.join(workspace_dir, "examples")

        reconstruction_output, target_dir, image_paths, log_message = (
            self.file_handler.load_example_scene(scene_name, examples_dir)
        )

        # Try to load cached processed data if available.
        # Defaults cover the "no cache" case: empty selector, no 3DGS video.
        processed_data = None
        measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1")
        gs_video_path = None
        gs_video_visible = False
        gs_info_visible = True

        if target_dir and target_dir != "None":
            predictions_path = os.path.join(target_dir, "predictions.npz")
            if os.path.exists(predictions_path):
                try:
                    # Load predictions from cache
                    loaded = np.load(predictions_path, allow_pickle=True)
                    predictions = {key: loaded[key] for key in loaded.keys()}

                    # Reconstruct processed_data structure
                    num_images = len(predictions.get("images", []))
                    processed_data = {}

                    for i in range(num_images):
                        processed_data[i] = {
                            "image": predictions["images"][i] if "images" in predictions else None,
                            "depth": predictions["depths"][i] if "depths" in predictions else None,
                            # Path convention must match what the reconstruction
                            # step writes under depth_vis/.
                            "depth_image": os.path.join(
                                target_dir, "depth_vis", f"{i:04d}.jpg"  # Fixed: use .jpg not .png
                            ),
                            "intrinsics": (
                                predictions["intrinsics"][i]
                                if "intrinsics" in predictions
                                and i < len(predictions["intrinsics"])
                                else None
                            ),
                            "mask": None,
                        }

                    # Update measure view selector
                    choices = [f"View {i + 1}" for i in range(num_images)]
                    measure_view_selector = gr.Dropdown(choices=choices, value=choices[0])

                except Exception as e:
                    # Cache corruption is non-fatal; the scene can be re-reconstructed.
                    print(f"Error loading cached data: {e}")

            # Check for cached 3DGS video
            gs_video_dir = os.path.join(target_dir, "gs_video")
            if os.path.exists(gs_video_dir):
                try:
                    from glob import glob

                    # Most recent video (by filename sort) wins.
                    gs_videos = sorted(glob(os.path.join(gs_video_dir, "*.mp4")))
                    if gs_videos:
                        gs_video_path = gs_videos[-1]
                        gs_video_visible = True
                        gs_info_visible = False
                        print(f"Loaded cached 3DGS video: {gs_video_path}")
                except Exception as e:
                    print(f"Error loading cached 3DGS video: {e}")

        return (
            reconstruction_output,
            target_dir,
            image_paths,
            log_message,
            processed_data,
            measure_view_selector,
            gs_video_path,
            gr.update(visible=gs_video_visible),
            gr.update(visible=gs_info_visible),
        )
458
+
459
+ def navigate_depth_view(
460
+ self,
461
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
462
+ current_selector: str,
463
+ direction: int,
464
+ ) -> Tuple[str, Optional[str]]:
465
+ """
466
+ Navigate depth view.
467
+
468
+ Args:
469
+ processed_data: Processed data dictionary
470
+ current_selector: Current selector value
471
+ direction: Direction to navigate
472
+
473
+ Returns:
474
+ Tuple of (new_selector_value, depth_vis)
475
+ """
476
+ return self.visualization_handler.navigate_depth_view(
477
+ processed_data, current_selector, direction
478
+ )
479
+
480
+ def update_depth_view(
481
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
482
+ ) -> Optional[str]:
483
+ """
484
+ Update depth view for a specific view index.
485
+
486
+ Args:
487
+ processed_data: Processed data dictionary
488
+ view_index: Index of the view to update
489
+
490
+ Returns:
491
+ Path to depth visualization image or None
492
+ """
493
+ return self.visualization_handler.update_depth_view(processed_data, view_index)
494
+
495
+ def navigate_measure_view(
496
+ self,
497
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
498
+ current_selector: str,
499
+ direction: int,
500
+ ) -> Tuple[str, Optional[np.ndarray], Optional[np.ndarray], List]:
501
+ """
502
+ Navigate measure view.
503
+
504
+ Args:
505
+ processed_data: Processed data dictionary
506
+ current_selector: Current selector value
507
+ direction: Direction to navigate
508
+
509
+ Returns:
510
+ Tuple of (new_selector_value, measure_image, depth_right_half, measure_points)
511
+ """
512
+ return self.visualization_handler.navigate_measure_view(
513
+ processed_data, current_selector, direction
514
+ )
515
+
516
+ def update_measure_view(
517
+ self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
518
+ ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
519
+ """
520
+ Update measure view for a specific view index.
521
+
522
+ Args:
523
+ processed_data: Processed data dictionary
524
+ view_index: Index of the view to update
525
+
526
+ Returns:
527
+ Tuple of (measure_image, depth_right_half, measure_points)
528
+ """
529
+ return self.visualization_handler.update_measure_view(processed_data, view_index)
530
+
531
+ def measure(
532
+ self,
533
+ processed_data: Optional[Dict[int, Dict[str, Any]]],
534
+ measure_points: List,
535
+ current_view_selector: str,
536
+ event: gr.SelectData,
537
+ ) -> List:
538
+ """
539
+ Handle measurement on images.
540
+
541
+ Args:
542
+ processed_data: Processed data dictionary
543
+ measure_points: List of current measure points
544
+ current_view_selector: Current view selector value
545
+ event: Gradio select event
546
+
547
+ Returns:
548
+ List of [image, depth_right_half, measure_points, text]
549
+ """
550
+ return self.visualization_handler.measure(
551
+ processed_data, measure_points, current_view_selector, event
552
+ )
553
+
554
+ def select_first_frame(
555
+ self, image_gallery: List, selected_index: int = 0
556
+ ) -> Tuple[List, str, str]:
557
+ """
558
+ Select the first frame from the image gallery.
559
+
560
+ Args:
561
+ image_gallery: List of images in the gallery
562
+ selected_index: Index of the selected image (default: 0)
563
+
564
+ Returns:
565
+ Tuple of (updated_image_gallery, log_message, selected_frame_path)
566
+ """
567
+ try:
568
+ if not image_gallery or len(image_gallery) == 0:
569
+ return image_gallery, "No images available to select as first frame.", ""
570
+
571
+ # Handle None or invalid selected_index
572
+ if (
573
+ selected_index is None
574
+ or selected_index < 0
575
+ or selected_index >= len(image_gallery)
576
+ ):
577
+ selected_index = 0
578
+ print(f"Invalid selected_index: {selected_index}, using default: 0")
579
+
580
+ # Get the selected image based on index
581
+ selected_image = image_gallery[selected_index]
582
+ print(f"Selected image index: {selected_index}")
583
+ print(f"Total images: {len(image_gallery)}")
584
+
585
+ # Extract the file path from the selected image
586
+ selected_frame_path = ""
587
+ print(f"Selected image type: {type(selected_image)}")
588
+ print(f"Selected image: {selected_image}")
589
+
590
+ if isinstance(selected_image, tuple):
591
+ # Gradio Gallery returns tuple (path, None)
592
+ selected_frame_path = selected_image[0]
593
+ elif isinstance(selected_image, str):
594
+ selected_frame_path = selected_image
595
+ elif hasattr(selected_image, "name"):
596
+ selected_frame_path = selected_image.name
597
+ elif isinstance(selected_image, dict):
598
+ if "name" in selected_image:
599
+ selected_frame_path = selected_image["name"]
600
+ elif "path" in selected_image:
601
+ selected_frame_path = selected_image["path"]
602
+ elif "src" in selected_image:
603
+ selected_frame_path = selected_image["src"]
604
+ else:
605
+ # Try to convert to string
606
+ selected_frame_path = str(selected_image)
607
+
608
+ print(f"Extracted path: {selected_frame_path}")
609
+
610
+ # Extract filename from the path for matching
611
+ import os
612
+
613
+ selected_filename = os.path.basename(selected_frame_path)
614
+ print(f"Selected filename: {selected_filename}")
615
+
616
+ # Move the selected image to the front
617
+ updated_gallery = [selected_image] + [
618
+ img for img in image_gallery if img != selected_image
619
+ ]
620
+
621
+ log_message = (
622
+ f"Selected frame: {selected_filename}. "
623
+ f"Moved to first position. Total frames: {len(updated_gallery)}"
624
+ )
625
+ return updated_gallery, log_message, selected_filename
626
+
627
+ except Exception as e:
628
+ print(f"Error selecting first frame: {e}")
629
+ return image_gallery, f"Error selecting first frame: {e}", ""
src/depth_anything_3/app/modules/file_handlers.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ File handling module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles file uploads, video processing, and file operations.
19
+ """
20
+
21
+ import os
22
+ import shutil
23
+ import time
24
+ from datetime import datetime
25
+ from typing import List, Optional, Tuple
26
+ import cv2
27
+ from PIL import Image
28
+ from pillow_heif import register_heif_opener
29
+
30
+ register_heif_opener()
31
+
32
+
33
class FileHandler:
    """
    Handles file uploads and processing for the Gradio app.
    """

    def __init__(self):
        """Initialize the file handler (stateless; all methods work on paths)."""

    def handle_uploads(
        self,
        input_video: Optional[str],
        input_images: Optional[List],
        s_time_interval: float = 10.0,
    ) -> Tuple[str, List[str]]:
        """
        Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
        images or extracted frames from video into it.

        Args:
            input_video: Path to input video file
            input_images: List of input image files
            s_time_interval: Sampling FPS (frames per second) for frame extraction

        Returns:
            Tuple of (target_dir, image_paths)
        """
        start_time = time.time()

        # Get workspace directory from environment variable or use default.
        # exist_ok avoids a race if two sessions create the dirs concurrently.
        workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
        os.makedirs(workspace_dir, exist_ok=True)

        # Create input_images subdirectory
        input_images_dir = os.path.join(workspace_dir, "input_images")
        os.makedirs(input_images_dir, exist_ok=True)

        # Create a unique folder name within input_images
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
        target_dir = os.path.join(input_images_dir, f"session_{timestamp}")
        target_dir_images = os.path.join(target_dir, "images")

        # Clean up if somehow that folder already exists
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
        os.makedirs(target_dir)
        os.makedirs(target_dir_images)

        image_paths = []

        # Handle images
        if input_images is not None:
            image_paths.extend(self._process_images(input_images, target_dir_images))

        # Handle video
        if input_video is not None:
            image_paths.extend(
                self._process_video(input_video, target_dir_images, s_time_interval)
            )

        # Sort final images for gallery
        image_paths = sorted(image_paths)

        end_time = time.time()
        print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
        return target_dir, image_paths

    def _process_images(self, input_images: List, target_dir_images: str) -> List[str]:
        """
        Process uploaded images.

        HEIC/HEIF inputs are converted to JPEG for gallery compatibility;
        everything else is copied as-is.

        Args:
            input_images: List of input image files (paths or dicts with "name")
            target_dir_images: Target directory for images

        Returns:
            List of processed image paths
        """
        image_paths = []

        for file_data in input_images:
            # Gradio may hand us either a plain path or a {"name": path} dict.
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data

            # Check if the file is a HEIC image
            file_ext = os.path.splitext(file_path)[1].lower()
            if file_ext in [".heic", ".heif"]:
                # Convert HEIC to JPEG for better gallery compatibility
                try:
                    with Image.open(file_path) as img:
                        # Convert to RGB if necessary (HEIC can have different color modes)
                        if img.mode not in ("RGB", "L"):
                            img = img.convert("RGB")

                        # Create JPEG filename
                        base_name = os.path.splitext(os.path.basename(file_path))[0]
                        dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")

                        # Save as JPEG with high quality
                        img.save(dst_path, "JPEG", quality=95)
                        image_paths.append(dst_path)
                        print(
                            f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> "
                            f"{os.path.basename(dst_path)}"
                        )
                except Exception as e:
                    print(f"Error converting HEIC file {file_path}: {e}")
                    # Fall back to copying as is
                    dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                    shutil.copy(file_path, dst_path)
                    image_paths.append(dst_path)
            else:
                # Regular image files - copy as is
                dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                shutil.copy(file_path, dst_path)
                image_paths.append(dst_path)

        return image_paths

    def _process_video(
        self, input_video: str, target_dir_images: str, s_time_interval: float
    ) -> List[str]:
        """
        Process video file and extract frames.

        Args:
            input_video: Path to input video file (or {"name": path} dict)
            target_dir_images: Target directory for extracted frames
            s_time_interval: Sampling FPS (frames per second) for frame extraction

        Returns:
            List of extracted frame paths
        """
        image_paths = []

        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # Convert target sampling FPS into a keep-every-Nth-frame interval;
            # max(1, ...) guards against fps <= s_time_interval (or fps == 0).
            frame_interval = max(1, int(fps / s_time_interval))

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # Bug fix: release the capture handle (was previously leaked).
            vs.release()

        return image_paths

    def update_gallery_on_upload(
        self,
        input_video: Optional[str],
        input_images: Optional[List],
        s_time_interval: float = 10.0,
    ) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
        """
        Handle file uploads and update gallery.

        Args:
            input_video: Path to input video file
            input_images: List of input image files
            s_time_interval: Sampling FPS (frames per second) for frame extraction

        Returns:
            Tuple of (reconstruction_output, target_dir, image_paths, log_message)
        """
        if not input_video and not input_images:
            return None, None, None, None

        target_dir, image_paths = self.handle_uploads(input_video, input_images, s_time_interval)
        return (
            None,
            target_dir,
            image_paths,
            "Upload complete. Click 'Reconstruct' to begin 3D processing.",
        )

    def load_example_scene(
        self, scene_name: str, examples_dir: str = "examples"
    ) -> Tuple[Optional[str], Optional[str], Optional[List], str]:
        """
        Load a scene from examples directory.

        Args:
            scene_name: Name of the scene to load
            examples_dir: Path to examples directory

        Returns:
            Tuple of (reconstruction_output, target_dir, image_paths, log_message)
        """
        from depth_anything_3.app.modules.utils import get_scene_info

        scenes = get_scene_info(examples_dir)

        # Find the selected scene
        selected_scene = None
        for scene in scenes:
            if scene["name"] == scene_name:
                selected_scene = scene
                break

        if selected_scene is None:
            return None, None, None, "Scene not found"

        # Use fixed directory name for examples (not timestamp-based) so that
        # repeated loads of the same example reuse the cached reconstruction.
        workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
        input_images_dir = os.path.join(workspace_dir, "input_images")
        os.makedirs(input_images_dir, exist_ok=True)

        # Create a fixed folder name based on scene name
        target_dir = os.path.join(input_images_dir, f"example_{scene_name}")
        target_dir_images = os.path.join(target_dir, "images")

        # Check if already cached (GLB file exists)
        glb_path = os.path.join(target_dir, "scene.glb")
        is_cached = os.path.exists(glb_path)

        # Create directories if they don't exist
        os.makedirs(target_dir, exist_ok=True)

        # Copy images if directory is new or empty
        if not os.path.exists(target_dir_images) or len(os.listdir(target_dir_images)) == 0:
            os.makedirs(target_dir_images, exist_ok=True)
            image_paths = []
            for file_path in selected_scene["image_files"]:
                dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
                shutil.copy(file_path, dst_path)
                image_paths.append(dst_path)
        else:
            # Use existing images
            image_paths = sorted(
                [
                    os.path.join(target_dir_images, f)
                    for f in os.listdir(target_dir_images)
                    if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"))
                ]
            )

        # Return cached GLB if available
        if is_cached:
            return (
                glb_path,  # Return cached reconstruction
                target_dir,  # Set target directory
                image_paths,  # Set gallery
                f"Loaded cached scene '{scene_name}' with {selected_scene['num_images']} images.",
            )
        else:
            return (
                None,  # No cached reconstruction
                target_dir,  # Set target directory
                image_paths,  # Set gallery
                (
                    f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. "
                    "Click 'Reconstruct' to begin 3D processing."
                ),
            )
src/depth_anything_3/app/modules/model_inference.py ADDED
@@ -0,0 +1,286 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Model inference module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles all model-related operations including inference,
19
+ data processing, and result preparation.
20
+ """
21
+
22
+ import gc
23
+ import glob
24
+ import os
25
+ from typing import Any, Dict, Optional, Tuple
26
+ import numpy as np
27
+ import torch
28
+
29
+ from depth_anything_3.api import DepthAnything3
30
+ from depth_anything_3.utils.export.glb import export_to_glb
31
+ from depth_anything_3.utils.export.gs import export_to_gs_video
32
+
33
+
34
+ class ModelInference:
35
+ """
36
+ Handles model inference and data processing for Depth Anything 3.
37
+ """
38
+
39
    def __init__(self):
        """Initialize the model inference handler."""
        # Loaded lazily by initialize_model() on the first inference call.
        self.model = None
42
+
43
+ def initialize_model(self, device: str = "cuda") -> None:
44
+ """
45
+ Initialize the DepthAnything3 model.
46
+
47
+ Args:
48
+ device: Device to load the model on
49
+ """
50
+ if self.model is None:
51
+ # Get model directory from environment variable or use default
52
+ model_dir = os.environ.get(
53
+ "DA3_MODEL_DIR", "/dev/shm/da3_models/DA3HF-VITG-METRIC_VITL"
54
+ )
55
+ self.model = DepthAnything3.from_pretrained(model_dir)
56
+ self.model = self.model.to(device)
57
+ else:
58
+ self.model = self.model.to(device)
59
+
60
+ self.model.eval()
61
+
62
+ def run_inference(
63
+ self,
64
+ target_dir: str,
65
+ filter_black_bg: bool = False,
66
+ filter_white_bg: bool = False,
67
+ process_res_method: str = "upper_bound_resize",
68
+ show_camera: bool = True,
69
+ selected_first_frame: Optional[str] = None,
70
+ save_percentage: float = 30.0,
71
+ num_max_points: int = 1_000_000,
72
+ infer_gs: bool = False,
73
+ gs_trj_mode: str = "extend",
74
+ gs_video_quality: str = "high",
75
+ ) -> Tuple[Any, Dict[int, Dict[str, Any]]]:
76
+ """
77
+ Run DepthAnything3 model inference on images.
78
+
79
+ Args:
80
+ target_dir: Directory containing images
81
+ apply_mask: Whether to apply mask for ambiguous depth classes
82
+ mask_edges: Whether to mask edges
83
+ filter_black_bg: Whether to filter black background
84
+ filter_white_bg: Whether to filter white background
85
+ process_res_method: Method for resizing input images
86
+ show_camera: Whether to show camera in 3D view
87
+ selected_first_frame: Selected first frame filename
88
+ save_percentage: Percentage of points to save (0-100)
89
+ infer_gs: Whether to infer 3D Gaussian Splatting
90
+
91
+ Returns:
92
+ Tuple of (prediction, processed_data)
93
+ """
94
+ print(f"Processing images from {target_dir}")
95
+
96
+ # Device check
97
+ device = "cuda" if torch.cuda.is_available() else "cpu"
98
+ device = torch.device(device)
99
+
100
+ # Initialize model if needed
101
+ self.initialize_model(device)
102
+
103
+ # Get image paths
104
+ print("Loading images...")
105
+ image_folder_path = os.path.join(target_dir, "images")
106
+ all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*")))
107
+
108
+ # Filter for image files
109
+ image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
110
+ all_image_paths = [
111
+ path
112
+ for path in all_image_paths
113
+ if any(path.lower().endswith(ext) for ext in image_extensions)
114
+ ]
115
+
116
+ print(f"Found {len(all_image_paths)} images")
117
+ print(f"All image paths: {all_image_paths}")
118
+
119
+ # Apply first frame selection logic
120
+ if selected_first_frame:
121
+ # Find the image with matching filename
122
+ selected_path = None
123
+ for path in all_image_paths:
124
+ if os.path.basename(path) == selected_first_frame:
125
+ selected_path = path
126
+ break
127
+
128
+ if selected_path:
129
+ # Move selected frame to the front
130
+ image_paths = [selected_path] + [
131
+ path for path in all_image_paths if path != selected_path
132
+ ]
133
+ print(f"User selected first frame: {selected_first_frame} -> {selected_path}")
134
+ print(f"Reordered image paths: {image_paths}")
135
+ else:
136
+ # Use default order if no match found
137
+ image_paths = all_image_paths
138
+ print(
139
+ f"Selected frame '{selected_first_frame}' not found in image paths. "
140
+ "Using default order."
141
+ )
142
+ first_frame_display = image_paths[0] if image_paths else "No images"
143
+ print(f"Using default order (first frame): {first_frame_display}")
144
+ else:
145
+ # Use default order (sorted)
146
+ image_paths = all_image_paths
147
+ first_frame_display = image_paths[0] if image_paths else "No images"
148
+ print(f"Using default order (first frame): {first_frame_display}")
149
+
150
+ if len(image_paths) == 0:
151
+ raise ValueError("No images found. Check your upload.")
152
+
153
+ # Map UI options to actual method names
154
+ method_mapping = {"high_res": "lower_bound_resize", "low_res": "upper_bound_resize"}
155
+ actual_method = method_mapping.get(process_res_method, "upper_bound_crop")
156
+
157
+ # Run model inference
158
+ print(f"Running inference with method: {actual_method}")
159
+ with torch.no_grad():
160
+ prediction = self.model.inference(
161
+ image_paths, export_dir=None, process_res_method=actual_method, infer_gs=infer_gs
162
+ )
163
+ # num_max_points: int = 1_000_000,
164
+ export_to_glb(
165
+ prediction,
166
+ filter_black_bg=filter_black_bg,
167
+ filter_white_bg=filter_white_bg,
168
+ export_dir=target_dir,
169
+ show_cameras=show_camera,
170
+ conf_thresh_percentile=save_percentage,
171
+ num_max_points=int(num_max_points),
172
+ )
173
+
174
+ # export to gs video if needed
175
+ if infer_gs:
176
+ mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"}
177
+ print(f"GS mode: {gs_trj_mode}; Backend mode: {mode_mapping[gs_trj_mode]}")
178
+ export_to_gs_video(
179
+ prediction,
180
+ export_dir=target_dir,
181
+ chunk_size=4,
182
+ trj_mode=mode_mapping.get(gs_trj_mode, "extend"),
183
+ enable_tqdm=True,
184
+ vis_depth="hcat",
185
+ video_quality=gs_video_quality,
186
+ )
187
+
188
+ # Save predictions.npz for caching metric depth data
189
+ self._save_predictions_cache(target_dir, prediction)
190
+
191
+ # Process results
192
+ processed_data = self._process_results(target_dir, prediction, image_paths)
193
+
194
+ # Clean up
195
+ torch.cuda.empty_cache()
196
+
197
+ return prediction, processed_data
198
+
199
+ def _save_predictions_cache(self, target_dir: str, prediction: Any) -> None:
200
+ """
201
+ Save predictions data to predictions.npz for caching.
202
+
203
+ Args:
204
+ target_dir: Directory to save the cache
205
+ prediction: Model prediction object
206
+ """
207
+ try:
208
+ output_file = os.path.join(target_dir, "predictions.npz")
209
+
210
+ # Build save dict with prediction data
211
+ save_dict = {}
212
+
213
+ # Save processed images if available
214
+ if prediction.processed_images is not None:
215
+ save_dict["images"] = prediction.processed_images
216
+
217
+ # Save depth data
218
+ if prediction.depth is not None:
219
+ save_dict["depths"] = np.round(prediction.depth, 6)
220
+
221
+ # Save confidence if available
222
+ if prediction.conf is not None:
223
+ save_dict["conf"] = np.round(prediction.conf, 2)
224
+
225
+ # Save camera parameters
226
+ if prediction.extrinsics is not None:
227
+ save_dict["extrinsics"] = prediction.extrinsics
228
+ if prediction.intrinsics is not None:
229
+ save_dict["intrinsics"] = prediction.intrinsics
230
+
231
+ # Save to file
232
+ np.savez_compressed(output_file, **save_dict)
233
+ print(f"Saved predictions cache to: {output_file}")
234
+
235
+ except Exception as e:
236
+ print(f"Warning: Failed to save predictions cache: {e}")
237
+
238
+ def _process_results(
239
+ self, target_dir: str, prediction: Any, image_paths: list
240
+ ) -> Dict[int, Dict[str, Any]]:
241
+ """
242
+ Process model results into structured data.
243
+
244
+ Args:
245
+ target_dir: Directory containing results
246
+ prediction: Model prediction object
247
+ image_paths: List of input image paths
248
+
249
+ Returns:
250
+ Dictionary containing processed data for each view
251
+ """
252
+ processed_data = {}
253
+
254
+ # Read generated depth visualization files
255
+ depth_vis_dir = os.path.join(target_dir, "depth_vis")
256
+
257
+ if os.path.exists(depth_vis_dir):
258
+ depth_files = sorted(glob.glob(os.path.join(depth_vis_dir, "*.jpg")))
259
+ for i, depth_file in enumerate(depth_files):
260
+ # Use processed images directly from API
261
+ processed_image = None
262
+ if prediction.processed_images is not None and i < len(
263
+ prediction.processed_images
264
+ ):
265
+ processed_image = prediction.processed_images[i]
266
+
267
+ processed_data[i] = {
268
+ "depth_image": depth_file,
269
+ "image": processed_image,
270
+ "original_image_path": image_paths[i] if i < len(image_paths) else None,
271
+ "depth": prediction.depth[i] if i < len(prediction.depth) else None,
272
+ "intrinsics": (
273
+ prediction.intrinsics[i]
274
+ if prediction.intrinsics is not None and i < len(prediction.intrinsics)
275
+ else None
276
+ ),
277
+ "mask": None, # No mask information available
278
+ }
279
+
280
+ return processed_data
281
+
282
+ def cleanup(self) -> None:
283
+ """Clean up GPU memory."""
284
+ if torch.cuda.is_available():
285
+ torch.cuda.empty_cache()
286
+ gc.collect()
src/depth_anything_3/app/modules/ui_components.py ADDED
@@ -0,0 +1,474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ UI components module for Depth Anything 3 Gradio app.
17
+
18
+ This module contains UI component definitions and layout functions.
19
+ """
20
+
21
+ import os
22
+ from typing import Any, Dict, List, Tuple
23
+ import gradio as gr
24
+
25
+ from depth_anything_3.app.modules.utils import get_logo_base64, get_scene_info
26
+
27
+
28
class UIComponents:
    """
    Handles UI component creation and layout for the Gradio app.

    Each ``create_*`` method builds one section of the interface and returns
    the interactive components that the event handlers wire together.
    """

    def __init__(self):
        """Initialize the UI components handler."""

    def create_upload_section(self) -> Tuple[gr.Video, gr.Slider, gr.File, gr.Gallery, gr.Button]:
        """
        Create the upload section with video, images, and gallery components.

        Returns:
            A tuple of Gradio components: (input_video, s_time_interval, input_images,
            image_gallery, select_first_frame_btn).
        """
        input_video = gr.Video(label="Upload Video", interactive=True)
        # Frame-sampling rate used when extracting frames from an uploaded video.
        s_time_interval = gr.Slider(
            minimum=0.1,
            maximum=60,
            value=10,
            step=0.1,
            label="Sampling FPS (Frames Per Second)",
            interactive=True,
            visible=True,
        )
        input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
        image_gallery = gr.Gallery(
            label="Preview",
            columns=4,
            height="300px",
            show_download_button=True,
            object_fit="contain",
            preview=True,
            interactive=False,
        )

        # Select first frame button (moved below image gallery)
        select_first_frame_btn = gr.Button("Select First Frame", scale=1)

        return input_video, s_time_interval, input_images, image_gallery, select_first_frame_btn

    def create_3d_viewer_section(self) -> gr.Model3D:
        """
        Create the 3D viewer component for the reconstructed GLB scene.

        Returns:
            3D model viewer component.
        """
        return gr.Model3D(
            height=520,
            zoom_speed=0.5,
            pan_speed=0.5,
            # Fully transparent clear color so the page background shows through.
            clear_color=[0.0, 0.0, 0.0, 0.0],
            key="persistent_3d_viewer",
            elem_id="reconstruction_3d_viewer",
        )

    def create_nvs_video(self) -> Tuple[gr.Video, gr.Markdown]:
        """
        Create the 3DGS rendered video display component and info message.

        The info markdown is shown while GS inference is disabled; the video
        component starts hidden and is revealed once a rendering exists.

        Returns:
            Tuple of (video component, info message component).
        """
        with gr.Column():
            gs_info = gr.Markdown(
                (
                    "‼️ **3D Gaussian Splatting rendering is currently DISABLED.** <br><br><br>"
                    "To render novel views from 3DGS, "
                    "enable **Infer 3D Gaussian Splatting** below. <br>"
                    "Next, in **Visualization Options**, "
                    "*optionally* configure the **rendering trajectory** (default: smooth) "
                    "and **video quality** (default: low), "
                    "then click **Reconstruct**."
                ),
                visible=True,
                height=520,
            )
            gs_video = gr.Video(
                height=520,
                label="3DGS Rendered NVS Video (depth shown for reference only)",
                interactive=False,
                visible=False,
            )
        return gs_video, gs_info

    def create_depth_section(self) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image]:
        """
        Create the depth visualization section.

        Returns:
            A tuple of (prev_depth_btn, depth_view_selector, next_depth_btn, depth_map).
        """
        with gr.Row(elem_classes=["navigation-row"]):
            prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
            # Choices are repopulated dynamically once a reconstruction exists.
            depth_view_selector = gr.Dropdown(
                choices=["View 1"],
                value="View 1",
                label="Select View",
                scale=2,
                interactive=True,
                allow_custom_value=True,
            )
            next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
        depth_map = gr.Image(
            type="numpy",
            label="Colorized Depth Map",
            format="png",
            interactive=False,
        )

        return prev_depth_btn, depth_view_selector, next_depth_btn, depth_map

    def create_measure_section(
        self,
    ) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image, gr.Image, gr.Markdown]:
        """
        Create the measurement section.

        Returns:
            A tuple of (prev_measure_btn, measure_view_selector, next_measure_btn, measure_image,
            measure_depth_image, measure_text).
        """
        from depth_anything_3.app.css_and_html import MEASURE_INSTRUCTIONS_HTML

        gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
        with gr.Row(elem_classes=["navigation-row"]):
            prev_measure_btn = gr.Button("◀ Previous", size="sm", scale=1)
            measure_view_selector = gr.Dropdown(
                choices=["View 1"],
                value="View 1",
                label="Select View",
                scale=2,
                interactive=True,
                allow_custom_value=True,
            )
            next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
        with gr.Row():
            # sources=[] disables uploads; the image only receives click events.
            measure_image = gr.Image(
                type="numpy",
                show_label=False,
                format="webp",
                interactive=False,
                sources=[],
                label="RGB Image",
                scale=1,
                height=275,
            )
            measure_depth_image = gr.Image(
                type="numpy",
                show_label=False,
                format="webp",
                interactive=False,
                sources=[],
                label="Depth Visualization (Right Half)",
                scale=1,
                height=275,
            )
        gr.Markdown(
            "**Note:** Images have been adjusted to model processing size. "
            "Click two points on the RGB image to measure distance."
        )
        measure_text = gr.Markdown("")

        return (
            prev_measure_btn,
            measure_view_selector,
            next_measure_btn,
            measure_image,
            measure_depth_image,
            measure_text,
        )

    def create_inference_control_section(self) -> Tuple[gr.Dropdown, gr.Checkbox]:
        """
        Create the inference control section (options applied before inference).

        Returns:
            Tuple of (process_res_method_dropdown, infer_gs).
        """
        with gr.Row():
            process_res_method_dropdown = gr.Dropdown(
                choices=["high_res", "low_res"],
                value="low_res",
                label="Image Processing Method",
                info="low_res for much more images",
                scale=1,
            )
            # Modify line 220, add color class
            infer_gs = gr.Checkbox(
                label="Infer 3D Gaussian Splatting",
                value=False,
                info=(
                    'Enable novel view rendering from 3DGS (<i class="fas fa-triangle-exclamation '
                    'fa-color-red"></i> requires extra processing time)'
                ),
                scale=1,
            )

        return (process_res_method_dropdown, infer_gs)

    def create_display_control_section(
        self,
    ) -> Tuple[
        gr.Checkbox,
        gr.Checkbox,
        gr.Checkbox,
        gr.Slider,
        gr.Slider,
        gr.Dropdown,
        gr.Dropdown,
        gr.Button,
        gr.ClearButton,
    ]:
        """
        Create the display control section (options for visualization).

        Returns:
            Tuple of display control components including the Reconstruct and
            Clear buttons.
        """
        with gr.Column():
            # 3DGS options at the top; hidden until "Infer 3DGS" is enabled.
            with gr.Row():
                gs_trj_mode = gr.Dropdown(
                    choices=["smooth", "extend"],
                    value="smooth",
                    label=("Rendering trajectory for 3DGS viewpoints (requires n_views ≥ 2)"),
                    info=("'smooth' for view interpolation; 'extend' for longer trajectory"),
                    visible=False,  # initially hidden
                )
                gs_video_quality = gr.Dropdown(
                    choices=["low", "medium", "high"],
                    value="low",
                    label=("Video quality for 3DGS rendered outputs"),
                    info=("'low' for faster loading speed; 'high' for better visual quality"),
                    visible=False,  # initially hidden
                )

            # Reconstruct and Clear buttons (before Visualization Options)
            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                clear_btn = gr.ClearButton(scale=1)

            gr.Markdown("### Visualization Options: (Click Reconstruct to update)")
            show_cam = gr.Checkbox(label="Show Camera", value=True)
            filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
            filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
            save_percentage = gr.Slider(
                minimum=0,
                maximum=100,
                value=10,
                step=1,
                label="Filter Percentage",
                info="Confidence Threshold (%): Higher values filter more points.",
            )
            # Slider value is in thousands of points (see label).
            num_max_points = gr.Slider(
                minimum=1000,
                maximum=100000,
                value=1000,
                step=1000,
                label="Max Points (K points)",
                info="Maximum number of points to export to GLB (in thousands)",
            )

        return (
            show_cam,
            filter_black_bg,
            filter_white_bg,
            save_percentage,
            num_max_points,
            gs_trj_mode,
            gs_video_quality,
            submit_btn,
            clear_btn,
        )

    def create_control_section(
        self,
    ) -> Tuple[
        gr.Button,
        gr.ClearButton,
        gr.Dropdown,
        gr.Checkbox,
        gr.Checkbox,
        gr.Checkbox,
        gr.Checkbox,
        gr.Checkbox,
        gr.Dropdown,
        gr.Checkbox,
        gr.Textbox,
    ]:
        """
        Create the control section with buttons and options.

        Returns:
            Tuple of control components.
        """
        with gr.Row():
            submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
            clear_btn = gr.ClearButton(
                scale=1,
            )

        with gr.Row():
            frame_filter = gr.Dropdown(
                choices=["All"], value="All", label="Show Points from Frame"
            )
        with gr.Column():
            gr.Markdown("### Visualization Option: (Click Reconstruct to update)")
            show_cam = gr.Checkbox(label="Show Camera", value=True)
            show_mesh = gr.Checkbox(label="Show Mesh", value=True)
            filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
            filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
            gr.Markdown("### Reconstruction Options: (updated on next run)")
            apply_mask_checkbox = gr.Checkbox(
                label="Apply mask for predicted ambiguous depth classes & edges",
                value=True,
            )
            process_res_method_dropdown = gr.Dropdown(
                choices=[
                    "upper_bound_resize",
                    "upper_bound_crop",
                    "lower_bound_resize",
                    "lower_bound_crop",
                ],
                value="upper_bound_resize",
                label="Image Processing Method",
                info="Method for resizing input images",
            )
            save_to_gallery_checkbox = gr.Checkbox(
                label="Save to Gallery",
                value=False,
                info="Save current reconstruction results to gallery directory",
            )
            gallery_name_input = gr.Textbox(
                label="Gallery Name",
                placeholder="Enter a name for the gallery folder",
                value="",
                info="Leave empty for auto-generated name with timestamp",
            )

        return (
            submit_btn,
            clear_btn,
            frame_filter,
            show_cam,
            show_mesh,
            filter_black_bg,
            filter_white_bg,
            apply_mask_checkbox,
            process_res_method_dropdown,
            save_to_gallery_checkbox,
            gallery_name_input,
        )

    def create_example_scenes_section(self) -> List[Dict[str, Any]]:
        """
        Create the example scenes section.

        Returns:
            List of scene information dictionaries.
        """
        # Get workspace directory from environment variable
        workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
        examples_dir = os.path.join(workspace_dir, "examples")

        # Get scene information
        scenes = get_scene_info(examples_dir)

        return scenes

    def create_example_scene_grid(self, scenes: List[Dict[str, Any]]) -> List[gr.Image]:
        """
        Create the example scene grid (4 clickable thumbnails per row).

        Args:
            scenes: List of scene information dictionaries.

        Returns:
            List of scene image components.
        """
        scene_components = []

        if scenes:
            for i in range(0, len(scenes), 4):  # Process 4 scenes per row
                with gr.Row():
                    for j in range(4):
                        scene_idx = i + j
                        if scene_idx < len(scenes):
                            scene = scenes[scene_idx]
                            with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
                                # Clickable thumbnail
                                scene_img = gr.Image(
                                    value=scene["thumbnail"],
                                    height=150,
                                    interactive=False,
                                    show_label=False,
                                    elem_id=f"scene_thumb_{scene['name']}",
                                    sources=[],
                                )
                                scene_components.append(scene_img)

                                # Scene name and image count as text below thumbnail
                                gr.Markdown(
                                    f"**{scene['name']}** \n {scene['num_images']} images",
                                    elem_classes=["scene-info"],
                                )
                        else:
                            # Empty column to maintain grid structure
                            with gr.Column(scale=1):
                                pass

        return scene_components

    def create_header_section(self) -> gr.HTML:
        """
        Create the header section with logo and title.

        Returns:
            Header HTML component.
        """
        from depth_anything_3.app.css_and_html import get_header_html

        return gr.HTML(get_header_html(get_logo_base64()))

    def create_description_section(self) -> gr.HTML:
        """
        Create the description section.

        Returns:
            Description HTML component.
        """
        from depth_anything_3.app.css_and_html import get_description_html

        return gr.HTML(get_description_html())

    def create_acknowledgements_section(self) -> gr.HTML:
        """
        Create the acknowledgements section.

        Returns:
            Acknowledgements HTML component.
        """
        from depth_anything_3.app.css_and_html import get_acknowledgements_html

        return gr.HTML(get_acknowledgements_html())
+ return gr.HTML(get_acknowledgements_html())
src/depth_anything_3/app/modules/utils.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Utility functions for Depth Anything 3 Gradio app.
17
+
18
+ This module contains helper functions for data processing, visualization,
19
+ and file operations.
20
+ """
21
+
22
+ import gc
23
+ import json
24
+ import os
25
+ import shutil
26
+ from datetime import datetime
27
+ from typing import Any, Dict, List, Optional, Tuple
28
+ import numpy as np
29
+ import torch
30
+
31
+
32
+ def create_depth_visualization(depth: np.ndarray) -> Optional[np.ndarray]:
33
+ """
34
+ Create a colored depth visualization.
35
+
36
+ Args:
37
+ depth: Depth array
38
+
39
+ Returns:
40
+ Colored depth visualization or None
41
+ """
42
+ if depth is None:
43
+ return None
44
+
45
+ # Normalize depth to 0-1 range
46
+ depth_min = depth[depth > 0].min() if (depth > 0).any() else 0
47
+ depth_max = depth.max()
48
+
49
+ if depth_max <= depth_min:
50
+ return None
51
+
52
+ # Normalize depth
53
+ depth_norm = (depth - depth_min) / (depth_max - depth_min)
54
+ depth_norm = np.clip(depth_norm, 0, 1)
55
+
56
+ # Apply colormap (using matplotlib's viridis colormap)
57
+ import matplotlib.cm as cm
58
+
59
+ # Convert to colored image
60
+ depth_colored = cm.viridis(depth_norm)[:, :, :3] # Remove alpha channel
61
+ depth_colored = (depth_colored * 255).astype(np.uint8)
62
+
63
+ return depth_colored
64
+
65
+
66
+ def save_to_gallery_func(
67
+ target_dir: str, processed_data: Dict[int, Dict[str, Any]], gallery_name: Optional[str] = None
68
+ ) -> Tuple[bool, str]:
69
+ """
70
+ Save the current reconstruction results to the gallery directory.
71
+
72
+ Args:
73
+ target_dir: Source directory containing reconstruction results
74
+ processed_data: Processed data dictionary
75
+ gallery_name: Name for the gallery folder
76
+
77
+ Returns:
78
+ Tuple of (success, message)
79
+ """
80
+ try:
81
+ # Get gallery directory from environment variable or use default
82
+ gallery_dir = os.environ.get(
83
+ "DA3_GALLERY_DIR",
84
+ "workspace/gallery",
85
+ )
86
+ if not os.path.exists(gallery_dir):
87
+ os.makedirs(gallery_dir)
88
+
89
+ # Use provided name or create a unique name
90
+ if gallery_name is None or gallery_name.strip() == "":
91
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
92
+ gallery_name = f"reconstruction_{timestamp}"
93
+
94
+ gallery_path = os.path.join(gallery_dir, gallery_name)
95
+
96
+ # Check if directory already exists
97
+ if os.path.exists(gallery_path):
98
+ return False, f"Save failed: folder '{gallery_name}' already exists"
99
+
100
+ # Create the gallery directory
101
+ os.makedirs(gallery_path, exist_ok=True)
102
+
103
+ # Copy GLB file
104
+ glb_source = os.path.join(target_dir, "scene.glb")
105
+ glb_dest = os.path.join(gallery_path, "scene.glb")
106
+ if os.path.exists(glb_source):
107
+ shutil.copy2(glb_source, glb_dest)
108
+
109
+ # Copy depth visualization images
110
+ depth_vis_dir = os.path.join(target_dir, "depth_vis")
111
+ if os.path.exists(depth_vis_dir):
112
+ gallery_depth_vis = os.path.join(gallery_path, "depth_vis")
113
+ shutil.copytree(depth_vis_dir, gallery_depth_vis)
114
+
115
+ # Copy original images
116
+ images_source = os.path.join(target_dir, "images")
117
+ if os.path.exists(images_source):
118
+ gallery_images = os.path.join(gallery_path, "images")
119
+ shutil.copytree(images_source, gallery_images)
120
+
121
+ scene_preview_source = os.path.join(target_dir, "scene.jpg")
122
+ scene_preview_dest = os.path.join(gallery_path, "scene.jpg")
123
+ shutil.copy2(scene_preview_source, scene_preview_dest)
124
+
125
+ # Save metadata
126
+ metadata = {
127
+ "timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
128
+ "num_images": len(processed_data) if processed_data else 0,
129
+ "gallery_name": gallery_name,
130
+ }
131
+
132
+ with open(os.path.join(gallery_path, "metadata.json"), "w") as f:
133
+ json.dump(metadata, f, indent=2)
134
+
135
+ print(f"Saved reconstruction to gallery: {gallery_path}")
136
+ return True, f"Save successful: saved to {gallery_path}"
137
+
138
+ except Exception as e:
139
+ print(f"Error saving to gallery: {e}")
140
+ return False, f"Save failed: {str(e)}"
141
+
142
+
143
+ def get_scene_info(examples_dir: str) -> List[Dict[str, Any]]:
144
+ """
145
+ Get information about scenes in the examples directory.
146
+
147
+ Args:
148
+ examples_dir: Path to examples directory
149
+
150
+ Returns:
151
+ List of scene information dictionaries
152
+ """
153
+ import glob
154
+
155
+ scenes = []
156
+ if not os.path.exists(examples_dir):
157
+ return scenes
158
+
159
+ for scene_folder in sorted(os.listdir(examples_dir)):
160
+ scene_path = os.path.join(examples_dir, scene_folder)
161
+ if os.path.isdir(scene_path):
162
+ # Find all image files in the scene folder
163
+ image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
164
+ image_files = []
165
+ for ext in image_extensions:
166
+ image_files.extend(glob.glob(os.path.join(scene_path, ext)))
167
+ image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
168
+
169
+ if image_files:
170
+ # Sort images and get the first one for thumbnail
171
+ image_files = sorted(image_files)
172
+ first_image = image_files[0]
173
+ num_images = len(image_files)
174
+
175
+ scenes.append(
176
+ {
177
+ "name": scene_folder,
178
+ "path": scene_path,
179
+ "thumbnail": first_image,
180
+ "num_images": num_images,
181
+ "image_files": image_files,
182
+ }
183
+ )
184
+
185
+ return scenes
186
+
187
+
188
+ def cleanup_memory() -> None:
189
+ """Clean up GPU memory and garbage collect."""
190
+ gc.collect()
191
+ if torch.cuda.is_available():
192
+ torch.cuda.empty_cache()
193
+
194
+
195
+ def get_logo_base64() -> Optional[str]:
196
+ """
197
+ Convert WAI logo to base64 for embedding in HTML.
198
+
199
+ Returns:
200
+ Base64 encoded logo string or None
201
+ """
202
+ import base64
203
+
204
+ logo_path = "examples/WAI-Logo/wai_logo.png"
205
+ try:
206
+ with open(logo_path, "rb") as img_file:
207
+ img_data = img_file.read()
208
+ base64_str = base64.b64encode(img_data).decode()
209
+ return f"data:image/png;base64,{base64_str}"
210
+ except FileNotFoundError:
211
+ return None
src/depth_anything_3/app/modules/visualization.py ADDED
@@ -0,0 +1,434 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Visualization module for Depth Anything 3 Gradio app.
17
+
18
+ This module handles visualization updates, navigation, and measurement functionality.
19
+ """
20
+
21
+ import os
22
+ from typing import Any, Dict, List, Optional, Tuple
23
+ import cv2
24
+ import gradio as gr
25
+ import numpy as np
26
+
27
+
28
class VisualizationHandler:
    """
    Handles visualization updates and navigation for the Gradio app.

    The handler itself is stateless: per-session state (the ``processed_data``
    mapping and clicked measurement points) is passed into each method and
    returned so that Gradio can keep it in component state.

    ``processed_data`` maps a view key to a dict with keys used here:
    ``image`` (HxWx3 array), ``depth`` (HxW array of depths, presumably in
    meters — the UI formats them as "m"), ``depth_image`` (path to a
    side-by-side RGB|depth visualization) and ``intrinsics`` (3x3 camera
    matrix). NOTE(review): exact shapes/dtypes come from the inference
    module that produces this dict — confirm there.
    """

    def __init__(self):
        """Initialize the visualization handler."""

    def update_view_selectors(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]]
    ) -> Tuple[gr.Dropdown, gr.Dropdown]:
        """
        Update view selector dropdowns based on available views.

        Args:
            processed_data: Processed data dictionary

        Returns:
            Tuple of (depth_view_selector, measure_view_selector), both reset
            to the first view.
        """
        if processed_data is None or len(processed_data) == 0:
            choices = ["View 1"]
        else:
            num_views = len(processed_data)
            choices = [f"View {i + 1}" for i in range(num_views)]

        return (
            gr.Dropdown(choices=choices, value=choices[0]),  # depth_view_selector
            gr.Dropdown(choices=choices, value=choices[0]),  # measure_view_selector
        )

    def get_view_data_by_index(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
    ) -> Optional[Dict[str, Any]]:
        """
        Get view data by index, handling bounds.

        Args:
            processed_data: Processed data dictionary
            view_index: Index of the view to get (out-of-range falls back to 0)

        Returns:
            View data dictionary or None
        """
        if processed_data is None or len(processed_data) == 0:
            return None

        view_keys = list(processed_data.keys())
        if view_index < 0 or view_index >= len(view_keys):
            view_index = 0

        return processed_data[view_keys[view_index]]

    def update_depth_view(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
    ) -> Optional[str]:
        """
        Update depth view for a specific view index.

        Args:
            processed_data: Processed data dictionary
            view_index: Index of the view to update

        Returns:
            Path to depth visualization image or None
        """
        view_data = self.get_view_data_by_index(processed_data, view_index)
        if view_data is None or view_data.get("depth_image") is None:
            return None

        # Return the depth visualization image directly
        return view_data["depth_image"]

    def navigate_depth_view(
        self,
        processed_data: Optional[Dict[int, Dict[str, Any]]],
        current_selector_value: str,
        direction: int,
    ) -> Tuple[str, Optional[str]]:
        """
        Navigate depth view (direction: -1 for previous, +1 for next).

        Args:
            processed_data: Processed data dictionary
            current_selector_value: Current selector value (e.g. "View 3")
            direction: Direction to navigate (-1 for previous, +1 for next)

        Returns:
            Tuple of (new_selector_value, depth_vis); navigation wraps around.
        """
        if processed_data is None or len(processed_data) == 0:
            return "View 1", None

        # Parse current view number out of the "View N" label.
        try:
            current_view = int(current_selector_value.split()[1]) - 1
        except Exception:  # malformed selector string -> fall back to view 0
            current_view = 0

        num_views = len(processed_data)
        new_view = (current_view + direction) % num_views

        new_selector_value = f"View {new_view + 1}"
        depth_vis = self.update_depth_view(processed_data, new_view)

        return new_selector_value, depth_vis

    def update_measure_view(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
        """
        Update measure view for a specific view index.

        Args:
            processed_data: Processed data dictionary
            view_index: Index of the view to update

        Returns:
            Tuple of (measure_image, depth_right_half, measure_points);
            measure_points is always reset to [].
        """
        view_data = self.get_view_data_by_index(processed_data, view_index)
        if view_data is None:
            return None, None, []  # image, depth_right_half, measure_points

        # Get the processed (resized) image
        if "image" in view_data and view_data["image"] is not None:
            image = view_data["image"].copy()
        else:
            return None, None, []

        # Ensure image is in uint8 format (heuristic: values <= 1.0 are
        # treated as a normalized float image).
        if image.dtype != np.uint8:
            if image.max() <= 1.0:
                image = (image * 255).astype(np.uint8)
            else:
                image = image.astype(np.uint8)

        # Extract right half of the depth visualization (pure depth part)
        depth_image_path = view_data.get("depth_image", None)
        depth_right_half = None

        if depth_image_path and os.path.exists(depth_image_path):
            try:
                # Load the combined depth visualization image.
                depth_combined = cv2.imread(depth_image_path)
                # BUGFIX: cv2.imread returns None on an unreadable file; check
                # *before* cvtColor (the original converted first, so the
                # None-guard was dead and a bad file raised inside cvtColor).
                if depth_combined is not None:
                    depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB)
                    height, width = depth_combined.shape[:2]
                    # Extract right half (depth visualization part)
                    depth_right_half = depth_combined[:, width // 2 :]
            except Exception as e:
                print(f"Error extracting depth right half: {e}")

        return image, depth_right_half, []

    def navigate_measure_view(
        self,
        processed_data: Optional[Dict[int, Dict[str, Any]]],
        current_selector_value: str,
        direction: int,
    ) -> Tuple[str, Optional[np.ndarray], Optional[np.ndarray], List]:
        """
        Navigate measure view (direction: -1 for previous, +1 for next).

        Args:
            processed_data: Processed data dictionary
            current_selector_value: Current selector value (e.g. "View 3")
            direction: Direction to navigate (-1 for previous, +1 for next)

        Returns:
            Tuple of (new_selector_value, measure_image, depth_right_half,
            measure_points). NOTE: the third element is the right-half depth
            image array (annotation fixed; it is not a path).
        """
        if processed_data is None or len(processed_data) == 0:
            return "View 1", None, None, []

        # Parse current view number out of the "View N" label.
        try:
            current_view = int(current_selector_value.split()[1]) - 1
        except Exception:  # malformed selector string -> fall back to view 0
            current_view = 0

        num_views = len(processed_data)
        new_view = (current_view + direction) % num_views

        new_selector_value = f"View {new_view + 1}"
        measure_image, depth_right_half, measure_points = self.update_measure_view(
            processed_data, new_view
        )

        return new_selector_value, measure_image, depth_right_half, measure_points

    def populate_visualization_tabs(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]]
    ) -> Tuple[Optional[str], Optional[np.ndarray], Optional[np.ndarray], List]:
        """
        Populate the depth and measure tabs with processed data (first view).

        Args:
            processed_data: Processed data dictionary

        Returns:
            Tuple of (depth_vis, measure_img, depth_right_half,
            measure_points). NOTE: the third element is the right-half depth
            image array (annotation fixed; it is not a path).
        """
        if processed_data is None or len(processed_data) == 0:
            return None, None, None, []

        # Use update function to get depth visualization
        depth_vis = self.update_depth_view(processed_data, 0)
        measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0)

        return depth_vis, measure_img, depth_right_half, []

    def reset_measure(
        self, processed_data: Optional[Dict[int, Dict[str, Any]]]
    ) -> Tuple[Optional[np.ndarray], List, str]:
        """
        Reset measure points.

        Args:
            processed_data: Processed data dictionary

        Returns:
            Tuple of (image, measure_points, text): the first view's raw image,
            an empty point list and an empty info string.
        """
        if processed_data is None or len(processed_data) == 0:
            return None, [], ""

        # Return the first view image
        first_view = list(processed_data.values())[0]
        return first_view["image"], [], ""

    def measure(
        self,
        processed_data: Optional[Dict[int, Dict[str, Any]]],
        measure_points: List,
        current_view_selector: str,
        event: gr.SelectData,
    ) -> List:
        """
        Handle measurement on images.

        A click adds a point; after two points the per-point depths and the
        3D distance between them are reported, then the point list resets.

        Args:
            processed_data: Processed data dictionary
            measure_points: List of current measure points (mutated)
            current_view_selector: Current view selector value (e.g. "View 2")
            event: Gradio select event (event.index is the clicked x, y)

        Returns:
            List of [image, depth_right_half, measure_points, text].
            BUGFIX: error paths used to return 3-item lists, which does not
            match the 4-output success path that Gradio wires these returns
            to; all paths now return 4 items.
        """
        try:
            print(f"Measure function called with selector: {current_view_selector}")

            if processed_data is None or len(processed_data) == 0:
                return [None, None, [], "No data available"]

            # Use the currently selected view instead of always using the first view
            try:
                current_view_index = int(current_view_selector.split()[1]) - 1
            except Exception:  # malformed selector string -> fall back to view 0
                current_view_index = 0

            print(f"Using view index: {current_view_index}")

            # Get view data safely
            if current_view_index < 0 or current_view_index >= len(processed_data):
                current_view_index = 0

            view_keys = list(processed_data.keys())
            current_view = processed_data[view_keys[current_view_index]]

            if current_view is None:
                return [None, None, [], "No view data available"]

            point2d = event.index[0], event.index[1]
            print(f"Clicked point: {point2d}")

            measure_points.append(point2d)

            # Get image and depth visualization
            image, depth_right_half, _ = self.update_measure_view(
                processed_data, current_view_index
            )
            if image is None:
                return [None, None, [], "No image available"]

            image = image.copy()

            # Ensure image is in uint8 format for proper cv2 operations
            try:
                if image.dtype != np.uint8:
                    if image.max() <= 1.0:
                        # Image is in [0, 1] range, convert to [0, 255]
                        image = (image * 255).astype(np.uint8)
                    else:
                        # Image is already in [0, 255] range
                        image = image.astype(np.uint8)
            except Exception as e:
                print(f"Image conversion error: {e}")
                return [None, None, [], f"Image conversion error: {e}"]

            # Draw circles for points that fall inside the image bounds.
            try:
                for p in measure_points:
                    if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
                        image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
            except Exception as e:
                print(f"Drawing error: {e}")
                return [None, None, [], f"Drawing error: {e}"]

            # Get depth information from processed_data
            depth_text = ""
            try:
                for i, p in enumerate(measure_points):
                    if (
                        current_view["depth"] is not None
                        and 0 <= p[1] < current_view["depth"].shape[0]
                        and 0 <= p[0] < current_view["depth"].shape[1]
                    ):
                        d = current_view["depth"][p[1], p[0]]
                        depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n"
                    else:
                        depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n"  # noqa: E501
            except Exception as e:
                print(f"Depth text error: {e}")
                depth_text = f"Error computing depth: {e}\n"

            if len(measure_points) == 2:
                try:
                    point1, point2 = measure_points
                    # Draw line
                    if (
                        0 <= point1[0] < image.shape[1]
                        and 0 <= point1[1] < image.shape[0]
                        and 0 <= point2[0] < image.shape[1]
                        and 0 <= point2[1] < image.shape[0]
                    ):
                        image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)

                    # Compute 3D distance using depth information and camera intrinsics
                    distance_text = "- **Distance: Unable to calculate 3D distance**"
                    if (
                        current_view["depth"] is not None
                        and 0 <= point1[1] < current_view["depth"].shape[0]
                        and 0 <= point1[0] < current_view["depth"].shape[1]
                        and 0 <= point2[1] < current_view["depth"].shape[0]
                        and 0 <= point2[0] < current_view["depth"].shape[1]
                    ):
                        try:
                            # Get depth values at the two points
                            d1 = current_view["depth"][point1[1], point1[0]]
                            d2 = current_view["depth"][point2[1], point2[0]]

                            # Convert 2D pixel coordinates to 3D world coordinates
                            if current_view["intrinsics"] is not None:
                                # Get camera intrinsics
                                K = current_view["intrinsics"]  # 3x3 intrinsic matrix
                                fx, fy = K[0, 0], K[1, 1]  # focal lengths
                                cx, cy = K[0, 2], K[1, 2]  # principal point

                                # Back-project with the pinhole model:
                                # X = (u - cx) * d / fx, Y = (v - cy) * d / fy, Z = d
                                u1, v1 = point1[0], point1[1]
                                x1 = (u1 - cx) * d1 / fx
                                y1 = (v1 - cy) * d1 / fy
                                z1 = d1

                                u2, v2 = point2[0], point2[1]
                                x2 = (u2 - cx) * d2 / fx
                                y2 = (v2 - cy) * d2 / fy
                                z2 = d2

                                # Calculate 3D Euclidean distance
                                p1_3d = np.array([x1, y1, z1])
                                p2_3d = np.array([x2, y2, z2])
                                distance_3d = np.linalg.norm(p1_3d - p2_3d)

                                distance_text = f"- **Distance: {distance_3d:.2f}m**"
                            else:
                                # Fallback to simplified calculation if no intrinsics
                                pixel_distance = np.sqrt(
                                    (point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
                                )
                                avg_depth = (d1 + d2) / 2
                                scale_factor = avg_depth / 1000  # Rough scaling factor
                                estimated_3d_distance = pixel_distance * scale_factor
                                distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**"  # noqa: E501

                        except Exception as e:
                            print(f"Distance computation error: {e}")
                            distance_text = f"- **Distance computation error: {e}**"

                    # Two points complete a measurement: reset for the next one.
                    measure_points = []
                    text = depth_text + distance_text
                    print(f"Measurement complete: {text}")
                    return [image, depth_right_half, measure_points, text]
                except Exception as e:
                    print(f"Final measurement error: {e}")
                    return [None, None, [], f"Measurement error: {e}"]
            else:
                print(f"Single point measurement: {depth_text}")
                return [image, depth_right_half, measure_points, depth_text]

        except Exception as e:
            print(f"Overall measure function error: {e}")
            return [None, None, [], f"Measure function error: {e}"]
src/depth_anything_3/cfg.py ADDED
@@ -0,0 +1,144 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Configuration utility functions
17
+ """
18
+
19
+ import importlib
20
+ from pathlib import Path
21
+ from typing import Any, Callable, List, Union
22
+ from omegaconf import DictConfig, ListConfig, OmegaConf
23
+
24
# Register an "eval" resolver so configs can use "${eval: ...}" expressions.
# NOTE(review): registration presumably fails if a resolver named "eval" is
# already registered (e.g. on repeated import) — in that case we only log and
# keep the existing one; confirm against the installed OmegaConf version.
try:
    OmegaConf.register_new_resolver("eval", eval)
except Exception as e:
    # if eval is not available, we can just pass
    print(f"Error registering eval resolver: {e}")
29
+
30
+
31
def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
    """
    Load a configuration. Will resolve inheritance.
    Supports both file paths and module paths (e.g., depth_anything_3.configs.giant).

    Args:
        path: Either a filesystem path to a YAML file, or a dotted module-style
            path. For a module path the first component (the package name) is
            dropped and the rest is resolved relative to this file's directory
            with a ``.yaml`` suffix appended.
        argv: Optional dotlist overrides (e.g. ``["a.b=1"]``) merged on top of
            the loaded config before inheritance is resolved.

    Returns:
        The loaded configuration with every ``__inherit__`` directive resolved.
    """
    # Check if path is a module path (contains dots but no slashes and doesn't end with .yaml)
    if "." in path and "/" not in path and not path.endswith(".yaml"):
        # It's a module path, load from package resources
        path_parts = path.split(".")[1:]  # drop the leading package name
        config_path = Path(__file__).resolve().parent
        for part in path_parts:
            config_path = config_path.joinpath(part)
        config_path = config_path.with_suffix(".yaml")
        config = OmegaConf.load(str(config_path))
    else:
        # It's a file path (absolute, relative, or with .yaml extension)
        config = OmegaConf.load(path)

    if argv is not None:
        config_argv = OmegaConf.from_dotlist(argv)
        config = OmegaConf.merge(config, config_argv)
    # Resolve __inherit__ directives everywhere in the tree.
    config = resolve_recursive(config, resolve_inheritance)
    return config
54
+
55
+
56
def resolve_recursive(
    config: Any,
    resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
) -> Any:
    """Apply ``resolver`` to ``config``, then to every nested container.

    The resolver runs on the current node first (so it may replace the node),
    and the (possibly new) node's children are then resolved in place.
    Non-container values are returned unchanged apart from the resolver call.
    """
    config = resolver(config)
    if isinstance(config, DictConfig):
        for key in config.keys():
            child = config.get(key)
            if isinstance(child, (DictConfig, ListConfig)):
                config[key] = resolve_recursive(child, resolver)
    if isinstance(config, ListConfig):
        for idx in range(len(config)):
            child = config.get(idx)
            if isinstance(child, (DictConfig, ListConfig)):
                config[idx] = resolve_recursive(child, resolver)
    return config
72
+
73
+
74
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
    """
    Recursively resolve inheritance if the config contains:
    __inherit__: path/to/parent.yaml or a ListConfig of such paths.

    Parents are loaded in order and merged left-to-right (later parents
    override earlier ones); the child's own keys override all parents. The
    ``__inherit__`` key is removed from the returned config. ListConfig
    inputs are returned unchanged.
    """
    if isinstance(config, DictConfig):
        inherit = config.pop("__inherit__", None)

        if inherit:
            # Normalize to a list so single- and multi-parent cases share code.
            inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]

            parent_config = None
            for parent_path in inherit_list:
                assert isinstance(parent_path, str)
                # load_config recursively resolves the parent's own inheritance.
                parent_config = (
                    load_config(parent_path)
                    if parent_config is None
                    else OmegaConf.merge(parent_config, load_config(parent_path))
                )

            # Child keys (if any) take precedence over the merged parents.
            if len(config.keys()) > 0:
                config = OmegaConf.merge(parent_config, config)
            else:
                config = parent_config
    return config
99
+
100
+
101
def import_item(path: str, name: str) -> Any:
    """
    Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
    """
    module = importlib.import_module(path)
    return getattr(module, name)
106
+
107
+
108
def create_object(config: DictConfig) -> Any:
    """
    Create an object from config.
    The config is expected to contain the following:
    __object__:
        path: path.to.module
        name: MyClass
        args: as_config | as_params (default to as_config)

    With ``as_config`` the target is called with the whole config as its only
    positional argument; with ``as_params`` the config (minus ``__object__``)
    is expanded as keyword arguments.

    Raises:
        NotImplementedError: If ``args`` is neither ``as_config`` nor
            ``as_params``.
    """
    # Wrap defensively so plain mappings are accepted too.
    config = DictConfig(config)
    item = import_item(
        path=config.__object__.path,
        name=config.__object__.name,
    )
    args = config.__object__.get("args", "as_config")
    if args == "as_config":
        return item(config)
    if args == "as_params":
        # Convert to a plain dict first so pop() works and values are resolved.
        config = OmegaConf.to_object(config)
        config.pop("__object__")
        return item(**config)
    raise NotImplementedError(f"Unknown args type: {args}")
130
+
131
+
132
def create_dataset(path: str, *args, **kwargs) -> Any:
    """
    Create a dataset. Requires the file to contain a "create_dataset" function.
    """
    factory = import_item(path, "create_dataset")
    return factory(*args, **kwargs)
137
+
138
+
139
def to_dict_recursive(config_obj):
    """Recursively convert OmegaConf containers into plain dicts and lists.

    Non-container values (leaves) are returned as-is.
    """
    if isinstance(config_obj, ListConfig):
        return [to_dict_recursive(entry) for entry in config_obj]
    if isinstance(config_obj, DictConfig):
        return {key: to_dict_recursive(value) for key, value in config_obj.items()}
    return config_obj
src/depth_anything_3/cli.py ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: E402
2
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """
16
+ Refactored Depth Anything 3 CLI
17
+ Clean, modular command-line interface
18
+ """
19
+
20
+ from __future__ import annotations
21
+
22
+ import os
23
+ import typer
24
+
25
+ from depth_anything_3.services import start_server
26
+ from depth_anything_3.services.gallery import gallery as gallery_main
27
+ from depth_anything_3.services.inference_service import run_inference
28
+ from depth_anything_3.services.input_handlers import (
29
+ ColmapHandler,
30
+ ImageHandler,
31
+ ImagesHandler,
32
+ InputHandler,
33
+ VideoHandler,
34
+ parse_export_feat,
35
+ )
36
+ from depth_anything_3.utils.constants import DEFAULT_EXPORT_DIR, DEFAULT_GALLERY_DIR, DEFAULT_GRADIO_DIR, DEFAULT_MODEL
37
+
38
# Ask PyTorch's CUDA caching allocator for expandable segments to reduce
# fragmentation. NOTE(review): this presumably must be set before the first
# CUDA allocation to take effect — it is done here at import time.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

# Top-level Typer application; commands below register themselves on it.
app = typer.Typer(help="Depth Anything 3 - Video depth estimation CLI", add_completion=False)
41
+
42
+
43
+ # ============================================================================
44
+ # Input type detection utilities
45
+ # ============================================================================
46
+
47
# Supported file extensions
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".tif"}
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}


def detect_input_type(input_path: str) -> str:
    """
    Detect input type from path.

    Returns:
        - "image": Single image file
        - "images": Directory containing images
        - "video": Video file
        - "colmap": COLMAP directory structure
        - "unknown": Cannot determine type
    """
    if not os.path.exists(input_path):
        return "unknown"

    # A regular file is classified purely by its (case-insensitive) extension.
    if os.path.isfile(input_path):
        suffix = os.path.splitext(input_path)[1].lower()
        if suffix in IMAGE_EXTENSIONS:
            return "image"
        return "video" if suffix in VIDEO_EXTENSIONS else "unknown"

    if not os.path.isdir(input_path):
        return "unknown"

    # COLMAP layout: <root>/images/ alongside <root>/sparse/.
    has_images_dir = os.path.isdir(os.path.join(input_path, "images"))
    has_sparse_dir = os.path.isdir(os.path.join(input_path, "sparse"))
    if has_images_dir and has_sparse_dir:
        return "colmap"

    # Otherwise: a directory of images if any entry is an image file.
    contains_image = any(
        os.path.isfile(os.path.join(input_path, entry))
        and os.path.splitext(entry)[1].lower() in IMAGE_EXTENSIONS
        for entry in os.listdir(input_path)
    )
    return "images" if contains_image else "unknown"
95
+
96
+
97
+ # ============================================================================
98
+ # Common parameters and configuration
99
+ # ============================================================================
100
+
101
+ # ============================================================================
102
+ # Inference commands
103
+ # ============================================================================
104
+
105
+
106
@app.command()
def auto(
    input_path: str = typer.Argument(
        ..., help="Path to input (image, directory, video, or COLMAP)"
    ),
    model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
    export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
    export_format: str = typer.Option("glb", help="Export format"),
    device: str = typer.Option("cuda", help="Device to use"),
    use_backend: bool = typer.Option(False, help="Use backend service for inference"),
    backend_url: str = typer.Option(
        "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
    ),
    process_res: int = typer.Option(504, help="Processing resolution"),
    process_res_method: str = typer.Option(
        "upper_bound_resize", help="Processing resolution method"
    ),
    export_feat: str = typer.Option(
        "",
        help="[FEAT_VIS]Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
    ),
    auto_cleanup: bool = typer.Option(
        False, help="Automatically clean export directory if it exists (no prompt)"
    ),
    # Video-specific options
    fps: float = typer.Option(1.0, help="[Video] Sampling FPS for frame extraction"),
    # COLMAP-specific options
    sparse_subdir: str = typer.Option(
        "", help="[COLMAP] Sparse reconstruction subdirectory (e.g., '0' for sparse/0/)"
    ),
    align_to_input_ext_scale: bool = typer.Option(
        True, help="[COLMAP] Align prediction to input extrinsics scale"
    ),
    # GLB export options
    conf_thresh_percentile: float = typer.Option(
        40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
    ),
    num_max_points: int = typer.Option(
        1_000_000, help="[GLB] Maximum number of points in the point cloud"
    ),
    show_cameras: bool = typer.Option(
        True, help="[GLB] Show camera wireframes in the exported scene"
    ),
    # Feat_vis export options
    feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
    """
    Automatically detect input type and run appropriate processing.

    Supports:
    - Single image file (.jpg, .png, etc.)
    - Directory of images
    - Video file (.mp4, .avi, etc.)
    - COLMAP directory (with 'images' and 'sparse' subdirectories)
    """
    # Detect input type
    input_type = detect_input_type(input_path)

    if input_type == "unknown":
        typer.echo(f"❌ Error: Cannot determine input type for: {input_path}", err=True)
        typer.echo("Supported inputs:", err=True)
        typer.echo(" - Single image file (.jpg, .png, etc.)", err=True)
        typer.echo(" - Directory containing images", err=True)
        typer.echo(" - Video file (.mp4, .avi, etc.)", err=True)
        typer.echo(" - COLMAP directory (with 'images/' and 'sparse/' subdirectories)", err=True)
        raise typer.Exit(1)

    # Display detected type
    typer.echo(f"🔍 Detected input type: {input_type.upper()}")
    typer.echo(f"📁 Input path: {input_path}")
    typer.echo()

    # Determine backend URL based on use_backend flag
    final_backend_url = backend_url if use_backend else None

    # Parse export_feat parameter
    export_feat_layers = parse_export_feat(export_feat)

    # Per-branch preparation: collect the image list (plus COLMAP-only extras)
    # and prepare the export directory. The shared run_inference call below
    # replaces the four near-identical call sites the branches used to carry.
    extra_kwargs = {}
    if input_type == "image":
        typer.echo("Processing single image...")
        image_files = ImageHandler.process(input_path)
        export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)

    elif input_type == "images":
        typer.echo("Processing directory of images...")
        # Use default extensions
        image_files = ImagesHandler.process(input_path, "png,jpg,jpeg")
        export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)

    elif input_type == "video":
        typer.echo(f"Processing video with FPS={fps}...")
        # Frames are extracted into export_dir, so it must be prepared first.
        export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
        image_files = VideoHandler.process(input_path, export_dir, fps)

    else:  # input_type == "colmap"
        typer.echo(
            f"Processing COLMAP directory (sparse subdirectory: '{sparse_subdir or 'default'}')..."
        )
        image_files, extrinsics, intrinsics = ColmapHandler.process(input_path, sparse_subdir)
        export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
        # Only the COLMAP path supplies known camera parameters.
        extra_kwargs = {
            "extrinsics": extrinsics,
            "intrinsics": intrinsics,
            "align_to_input_ext_scale": align_to_input_ext_scale,
        }

    # Run inference (single shared call site for all input types).
    run_inference(
        image_paths=image_files,
        export_dir=export_dir,
        model_dir=model_dir,
        device=device,
        backend_url=final_backend_url,
        export_format=export_format,
        process_res=process_res,
        process_res_method=process_res_method,
        export_feat_layers=export_feat_layers,
        conf_thresh_percentile=conf_thresh_percentile,
        num_max_points=num_max_points,
        show_cameras=show_cameras,
        feat_vis_fps=feat_vis_fps,
        **extra_kwargs,
    )

    typer.echo()
    typer.echo("✅ Processing completed successfully!")
292
+
293
+
294
@app.command()
def image(
    image_path: str = typer.Argument(..., help="Path to input image file"),
    model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
    export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
    export_format: str = typer.Option("glb", help="Export format"),
    device: str = typer.Option("cuda", help="Device to use"),
    use_backend: bool = typer.Option(False, help="Use backend service for inference"),
    backend_url: str = typer.Option(
        "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
    ),
    process_res: int = typer.Option(504, help="Processing resolution"),
    process_res_method: str = typer.Option(
        "upper_bound_resize", help="Processing resolution method"
    ),
    export_feat: str = typer.Option(
        "",
        help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
    ),
    auto_cleanup: bool = typer.Option(
        False, help="Automatically clean export directory if it exists (no prompt)"
    ),
    # GLB export options
    conf_thresh_percentile: float = typer.Option(
        40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
    ),
    num_max_points: int = typer.Option(
        1_000_000, help="[GLB] Maximum number of points in the point cloud"
    ),
    show_cameras: bool = typer.Option(
        True, help="[GLB] Show camera wireframes in the exported scene"
    ),
    # Feat_vis export options
    feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
    """Run camera pose and depth estimation on a single image.

    Pipeline: validate the input image path, prepare (and optionally clean)
    the export directory, then run inference either locally or through the
    HTTP backend service, writing results in ``export_format``.
    """
    # Process input: validates the path and yields the list of image files.
    image_files = ImageHandler.process(image_path)

    # Handle export directory (created if missing; cleaned without prompting
    # when auto_cleanup is set — see the option's help text).
    export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)

    # Parse export_feat parameter: comma-separated layer indices ("" -> none).
    export_feat_layers = parse_export_feat(export_feat)

    # Determine backend URL based on use_backend flag; a None URL means
    # run_inference executes the model locally instead of calling the service.
    final_backend_url = backend_url if use_backend else None

    # Run inference
    run_inference(
        image_paths=image_files,
        export_dir=export_dir,
        model_dir=model_dir,
        device=device,
        backend_url=final_backend_url,
        export_format=export_format,
        process_res=process_res,
        process_res_method=process_res_method,
        export_feat_layers=export_feat_layers,
        conf_thresh_percentile=conf_thresh_percentile,
        num_max_points=num_max_points,
        show_cameras=show_cameras,
        feat_vis_fps=feat_vis_fps,
    )
358
+
359
+
360
@app.command()
def images(
    images_dir: str = typer.Argument(..., help="Path to directory containing input images"),
    image_extensions: str = typer.Option(
        "png,jpg,jpeg", help="Comma-separated image file extensions to process"
    ),
    model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
    export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
    export_format: str = typer.Option("glb", help="Export format"),
    device: str = typer.Option("cuda", help="Device to use"),
    use_backend: bool = typer.Option(False, help="Use backend service for inference"),
    backend_url: str = typer.Option(
        "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
    ),
    process_res: int = typer.Option(504, help="Processing resolution"),
    process_res_method: str = typer.Option(
        "upper_bound_resize", help="Processing resolution method"
    ),
    export_feat: str = typer.Option(
        "",
        help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
    ),
    auto_cleanup: bool = typer.Option(
        False, help="Automatically clean export directory if it exists (no prompt)"
    ),
    # GLB export options
    conf_thresh_percentile: float = typer.Option(
        40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
    ),
    num_max_points: int = typer.Option(
        1_000_000, help="[GLB] Maximum number of points in the point cloud"
    ),
    show_cameras: bool = typer.Option(
        True, help="[GLB] Show camera wireframes in the exported scene"
    ),
    # Feat_vis export options
    feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
    """Run camera pose and depth estimation on a directory of images.

    Same pipeline as the ``image`` command but gathers every file under
    ``images_dir`` whose extension appears in ``image_extensions``.
    """
    # Process input: collect matching image files from the directory.
    image_files = ImagesHandler.process(images_dir, image_extensions)

    # Handle export directory (created if missing; cleaned without prompting
    # when auto_cleanup is set).
    export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)

    # Parse export_feat parameter: comma-separated layer indices ("" -> none).
    export_feat_layers = parse_export_feat(export_feat)

    # Determine backend URL based on use_backend flag; None -> local inference.
    final_backend_url = backend_url if use_backend else None

    # Run inference
    run_inference(
        image_paths=image_files,
        export_dir=export_dir,
        model_dir=model_dir,
        device=device,
        backend_url=final_backend_url,
        export_format=export_format,
        process_res=process_res,
        process_res_method=process_res_method,
        export_feat_layers=export_feat_layers,
        conf_thresh_percentile=conf_thresh_percentile,
        num_max_points=num_max_points,
        show_cameras=show_cameras,
        feat_vis_fps=feat_vis_fps,
    )
427
+
428
+
429
@app.command()
def colmap(
    colmap_dir: str = typer.Argument(
        ..., help="Path to COLMAP directory containing 'images' and 'sparse' subdirectories"
    ),
    sparse_subdir: str = typer.Option(
        "", help="Sparse reconstruction subdirectory (e.g., '0' for sparse/0/, empty for sparse/)"
    ),
    align_to_input_ext_scale: bool = typer.Option(
        True, help="Align prediction to input extrinsics scale"
    ),
    model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
    export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
    export_format: str = typer.Option("glb", help="Export format"),
    device: str = typer.Option("cuda", help="Device to use"),
    use_backend: bool = typer.Option(False, help="Use backend service for inference"),
    backend_url: str = typer.Option(
        "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
    ),
    process_res: int = typer.Option(504, help="Processing resolution"),
    process_res_method: str = typer.Option(
        "upper_bound_resize", help="Processing resolution method"
    ),
    export_feat: str = typer.Option(
        "",
        # Consistency fix: every sibling command prefixes this help text with
        # "[FEAT_VIS]" to group feature-visualization options; match them here.
        help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
    ),
    auto_cleanup: bool = typer.Option(
        False, help="Automatically clean export directory if it exists (no prompt)"
    ),
    # GLB export options
    conf_thresh_percentile: float = typer.Option(
        40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
    ),
    num_max_points: int = typer.Option(
        1_000_000, help="[GLB] Maximum number of points in the point cloud"
    ),
    show_cameras: bool = typer.Option(
        True, help="[GLB] Show camera wireframes in the exported scene"
    ),
    # Feat_vis export options
    feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
    """Run pose conditioned depth estimation on COLMAP data.

    Unlike the plain image commands, this reads known camera extrinsics and
    intrinsics from the COLMAP sparse reconstruction and forwards them to
    inference, optionally aligning the predicted scale to the input poses.
    """
    # Process input: image list plus per-view camera parameters from COLMAP.
    image_files, extrinsics, intrinsics = ColmapHandler.process(colmap_dir, sparse_subdir)

    # Handle export directory (created if missing; cleaned without prompting
    # when auto_cleanup is set).
    export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)

    # Parse export_feat parameter: comma-separated layer indices ("" -> none).
    export_feat_layers = parse_export_feat(export_feat)

    # Determine backend URL based on use_backend flag; None -> local inference.
    final_backend_url = backend_url if use_backend else None

    # Run inference, conditioning on the COLMAP camera parameters.
    run_inference(
        image_paths=image_files,
        export_dir=export_dir,
        model_dir=model_dir,
        device=device,
        backend_url=final_backend_url,
        export_format=export_format,
        process_res=process_res,
        process_res_method=process_res_method,
        export_feat_layers=export_feat_layers,
        extrinsics=extrinsics,
        intrinsics=intrinsics,
        align_to_input_ext_scale=align_to_input_ext_scale,
        conf_thresh_percentile=conf_thresh_percentile,
        num_max_points=num_max_points,
        show_cameras=show_cameras,
        feat_vis_fps=feat_vis_fps,
    )
504
+
505
+
506
@app.command()
def video(
    video_path: str = typer.Argument(..., help="Path to input video file"),
    fps: float = typer.Option(1.0, help="Sampling FPS for frame extraction"),
    model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
    export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
    export_format: str = typer.Option("glb", help="Export format"),
    device: str = typer.Option("cuda", help="Device to use"),
    use_backend: bool = typer.Option(False, help="Use backend service for inference"),
    backend_url: str = typer.Option(
        "http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
    ),
    process_res: int = typer.Option(504, help="Processing resolution"),
    process_res_method: str = typer.Option(
        "upper_bound_resize", help="Processing resolution method"
    ),
    export_feat: str = typer.Option(
        "",
        help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
    ),
    auto_cleanup: bool = typer.Option(
        False, help="Automatically clean export directory if it exists (no prompt)"
    ),
    # GLB export options
    conf_thresh_percentile: float = typer.Option(
        40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
    ),
    num_max_points: int = typer.Option(
        1_000_000, help="[GLB] Maximum number of points in the point cloud"
    ),
    show_cameras: bool = typer.Option(
        True, help="[GLB] Show camera wireframes in the exported scene"
    ),
    # Feat_vis export options
    feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
):
    """Run depth estimation on video by extracting frames and processing them.

    Frames are sampled from ``video_path`` at ``fps`` frames per second and
    then fed through the same pipeline as the image commands.
    """
    # Handle export directory FIRST: VideoHandler.process receives it,
    # presumably as the destination for the extracted frames — so it must
    # exist before extraction (TODO confirm against VideoHandler).
    export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)

    # Process input: sample frames from the video at the requested rate.
    image_files = VideoHandler.process(video_path, export_dir, fps)

    # Parse export_feat parameter: comma-separated layer indices ("" -> none).
    export_feat_layers = parse_export_feat(export_feat)

    # Determine backend URL based on use_backend flag; None -> local inference.
    final_backend_url = backend_url if use_backend else None

    # Run inference
    run_inference(
        image_paths=image_files,
        export_dir=export_dir,
        model_dir=model_dir,
        device=device,
        backend_url=final_backend_url,
        export_format=export_format,
        process_res=process_res,
        process_res_method=process_res_method,
        export_feat_layers=export_feat_layers,
        conf_thresh_percentile=conf_thresh_percentile,
        num_max_points=num_max_points,
        show_cameras=show_cameras,
        feat_vis_fps=feat_vis_fps,
    )
571
+
572
+
573
+ # ============================================================================
574
+ # Service management commands
575
+ # ============================================================================
576
+
577
+
578
@app.command()
def backend(
    model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
    device: str = typer.Option("cuda", help="Device to use"),
    host: str = typer.Option("127.0.0.1", help="Host to bind to"),
    port: int = typer.Option(8008, help="Port to bind to"),
    gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path (optional)"),
):
    """Start model backend service with integrated gallery.

    Prints the server URLs, then blocks inside ``start_server`` until the
    process is interrupted. The gallery endpoint is enabled only when
    ``gallery_dir`` exists on disk.
    """
    typer.echo("=" * 60)
    typer.echo("🚀 Starting Depth Anything 3 Backend Server")
    typer.echo("=" * 60)
    typer.echo(f"Model directory: {model_dir}")
    typer.echo(f"Device: {device}")

    # Check if gallery directory exists. Previously a missing directory
    # disabled the gallery silently; now the user gets a notice.
    if gallery_dir and os.path.exists(gallery_dir):
        typer.echo(f"Gallery directory: {gallery_dir}")
    else:
        if gallery_dir:
            typer.echo(f"Gallery directory not found, gallery disabled: {gallery_dir}")
        gallery_dir = None  # Disable gallery if directory doesn't exist

    typer.echo()
    typer.echo("📡 Server URLs (Ctrl/CMD+Click to open):")
    typer.echo(f" 🏠 Home: http://{host}:{port}")
    typer.echo(f" 📊 Dashboard: http://{host}:{port}/dashboard")
    typer.echo(f" 📈 API Status: http://{host}:{port}/status")

    if gallery_dir:
        typer.echo(f" 🎨 Gallery: http://{host}:{port}/gallery/")

    typer.echo("=" * 60)

    try:
        start_server(model_dir, device, host, port, gallery_dir)
    except KeyboardInterrupt:
        typer.echo("\n👋 Backend server stopped.")
    except Exception as e:
        typer.echo(f"❌ Failed to start backend: {e}")
        # Chain the cause so the original traceback is preserved for debugging.
        raise typer.Exit(1) from e
617
+
618
+
619
+ # ============================================================================
620
+ # Application launch commands
621
+ # ============================================================================
622
+
623
+
624
@app.command()
def gradio(
    model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
    workspace_dir: str = typer.Option(DEFAULT_GRADIO_DIR, help="Workspace directory path"),
    gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path"),
    host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
    port: int = typer.Option(7860, help="Port number to bind to"),
    share: bool = typer.Option(False, help="Create a public link for the app"),
    debug: bool = typer.Option(False, help="Enable debug mode"),
    cache_examples: bool = typer.Option(
        False, help="Pre-cache all example scenes at startup for faster loading"
    ),
    cache_gs_tag: str = typer.Option(
        "",
        help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.",
    ),
):
    """Launch Depth Anything 3 Gradio interactive web application.

    Creates the workspace/gallery directories, optionally pre-caches the
    example scenes, then starts the Gradio server and blocks until it exits.
    """
    # Imported lazily so the CLI stays fast when this command is not used.
    from depth_anything_3.app.gradio_app import DepthAnything3App

    # Create necessary directories
    os.makedirs(workspace_dir, exist_ok=True)
    os.makedirs(gallery_dir, exist_ok=True)

    typer.echo("Launching Depth Anything 3 Gradio application...")
    typer.echo(f"Model directory: {model_dir}")
    typer.echo(f"Workspace directory: {workspace_dir}")
    typer.echo(f"Gallery directory: {gallery_dir}")
    typer.echo(f"Host: {host}")
    typer.echo(f"Port: {port}")
    typer.echo(f"Share: {share}")
    typer.echo(f"Debug mode: {debug}")
    typer.echo(f"Cache examples: {cache_examples}")
    if cache_examples:
        if cache_gs_tag:
            typer.echo(
                f"Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)"
            )
        else:
            typer.echo("Cache GS Tag: None (all scenes will use low-res only)")

    try:
        # Initialize the application. Named `da3_app` (not `app`) so the
        # module-level Typer `app` is not shadowed inside this command.
        da3_app = DepthAnything3App(
            model_dir=model_dir, workspace_dir=workspace_dir, gallery_dir=gallery_dir
        )

        # Pre-cache examples if requested
        if cache_examples:
            typer.echo("\n" + "=" * 60)
            typer.echo("Pre-caching mode enabled")
            if cache_gs_tag:
                typer.echo(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS")
                typer.echo("Other scenes will use LOW-RES only")
            else:
                typer.echo("All scenes will use LOW-RES only")
            typer.echo("=" * 60)
            da3_app.cache_examples(
                show_cam=True,
                filter_black_bg=False,
                filter_white_bg=False,
                save_percentage=20.0,
                num_max_points=1000,
                cache_gs_tag=cache_gs_tag,
                gs_trj_mode="smooth",
                gs_video_quality="low",
            )

        # Prepare launch arguments
        launch_kwargs = {"share": share, "debug": debug}

        da3_app.launch(host=host, port=port, **launch_kwargs)

    except KeyboardInterrupt:
        typer.echo("\nGradio application stopped.")
    except Exception as e:
        typer.echo(f"Failed to launch Gradio application: {e}")
        # Chain the cause so the original traceback is preserved for debugging.
        raise typer.Exit(1) from e
702
+
703
+
704
@app.command()
def gallery(
    gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery root directory"),
    host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
    port: int = typer.Option(8007, help="Port number to bind to"),
    open_browser: bool = typer.Option(False, help="Open browser after launch"),
):
    """Launch Depth Anything 3 Gallery server.

    ``gallery_main`` parses ``sys.argv`` itself, so this command builds the
    argv the gallery server expects before delegating to it.
    """

    # Validate gallery directory
    if not os.path.exists(gallery_dir):
        raise typer.BadParameter(f"Gallery directory not found: {gallery_dir}")

    typer.echo("Launching Depth Anything 3 Gallery server...")
    typer.echo(f"Gallery directory: {gallery_dir}")
    typer.echo(f"Host: {host}")
    typer.echo(f"Port: {port}")
    typer.echo(f"Auto-open browser: {open_browser}")

    import sys

    # Fix: save and restore sys.argv instead of clobbering it for the rest
    # of the process — the mutation previously leaked past this command.
    saved_argv = sys.argv
    try:
        # Set command line arguments for gallery_main's own parser.
        sys.argv = ["gallery", "--dir", gallery_dir, "--host", host, "--port", str(port)]
        if open_browser:
            sys.argv.append("--open")

        # Launch gallery server
        gallery_main()

    except KeyboardInterrupt:
        typer.echo("\nGallery server stopped.")
    except Exception as e:
        typer.echo(f"Failed to launch Gallery server: {e}")
        # Chain the cause so the original traceback is preserved for debugging.
        raise typer.Exit(1) from e
    finally:
        sys.argv = saved_argv
739
+
740
+
741
if __name__ == "__main__":
    # CLI entry point: dispatch to the Typer application defined above.
    app()
src/depth_anything_3/configs/da3-base.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Depth Anything 3 "base" model (ViT-B backbone).
# Each __object__ block is instantiated from the given module path/class name;
# "args: as_params" presumably passes the sibling keys as constructor
# parameters — confirm with depth_anything_3.cfg.
__object__:
  path: depth_anything_3.model.da3
  name: DepthAnything3Net
  args: as_params

# Backbone: DINOv2 ViT-B, tapping features from four intermediate layers.
net:
  __object__:
    path: depth_anything_3.model.dinov2.dinov2
    name: DinoV2
    args: as_params

  name: vitb
  out_layers: [5, 7, 9, 11]
  alt_start: 4
  qknorm_start: 4
  rope_start: 4
  cat_token: True

# Dual DPT head with two output channels.
# NOTE(review): the anchors below are unreferenced in this config (the giant
# config uses matching anchors for its gs_head); harmless but vestigial.
head:
  __object__:
    path: depth_anything_3.model.dualdpt
    name: DualDPT
    args: as_params

  dim_in: &head_dim_in 1536
  output_dim: 2
  features: &head_features 128
  out_channels: &head_out_channels [96, 192, 384, 768]


# Camera encoder: maps known camera parameters to backbone-width tokens.
cam_enc:
  __object__:
    path: depth_anything_3.model.cam_enc
    name: CameraEnc
    args: as_params

  dim_out: 768

# Camera decoder: predicts pose encodings from fused features.
cam_dec:
  __object__:
    path: depth_anything_3.model.cam_dec
    name: CameraDec
    args: as_params

  dim_in: 1536
src/depth_anything_3/configs/da3-giant.yaml ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Depth Anything 3 "giant" model (ViT-g backbone) with 3D Gaussian Splatting
# heads in addition to the standard depth/camera heads.
__object__:
  path: depth_anything_3.model.da3
  name: DepthAnything3Net
  args: as_params

# Backbone: DINOv2 ViT-g, tapping features from four intermediate layers.
net:
  __object__:
    path: depth_anything_3.model.dinov2.dinov2
    name: DinoV2
    args: as_params

  name: vitg
  out_layers: [19, 27, 33, 39]
  alt_start: 13
  qknorm_start: 13
  rope_start: 13
  cat_token: True

# Dual DPT head; the anchors are reused verbatim by gs_head below so the two
# heads always share dimensions.
head:
  __object__:
    path: depth_anything_3.model.dualdpt
    name: DualDPT
    args: as_params

  dim_in: &head_dim_in 3072
  output_dim: 2
  features: &head_features 256
  out_channels: &head_out_channels [256, 512, 1024, 1024]


# Camera encoder: maps known camera parameters to backbone-width tokens.
cam_enc:
  __object__:
    path: depth_anything_3.model.cam_enc
    name: CameraEnc
    args: as_params

  dim_out: 1536

# Camera decoder: predicts pose encodings from fused features.
cam_dec:
  __object__:
    path: depth_anything_3.model.cam_dec
    name: CameraDec
    args: as_params

  dim_in: 3072


# Gaussian-parameter DPT head; mirrors `head` via the YAML anchors above.
gs_head:
  __object__:
    path: depth_anything_3.model.gsdpt
    name: GSDPT
    args: as_params

  dim_in: *head_dim_in
  output_dim: 38 # should align with gs_adapter's setting, for gs params
  features: *head_features
  out_channels: *head_out_channels


# Adapter converting raw head outputs into Gaussian primitives.
gs_adapter:
  __object__:
    path: depth_anything_3.model.gs_adapter
    name: GaussianAdapter
    args: as_params

  sh_degree: 2
  pred_color: false # predict SH coefficient if false
  pred_offset_depth: true
  pred_offset_xy: true
  gaussian_scale_min: 1e-5
  gaussian_scale_max: 30.0
src/depth_anything_3/configs/da3-large.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Depth Anything 3 "large" model (ViT-L backbone).
__object__:
  path: depth_anything_3.model.da3
  name: DepthAnything3Net
  args: as_params

# Backbone: DINOv2 ViT-L, tapping features from four intermediate layers.
net:
  __object__:
    path: depth_anything_3.model.dinov2.dinov2
    name: DinoV2
    args: as_params

  name: vitl
  out_layers: [11, 15, 19, 23]
  alt_start: 8
  qknorm_start: 8
  rope_start: 8
  cat_token: True

# Dual DPT head with two output channels.
# NOTE(review): the anchors below are unreferenced in this config (the giant
# config uses matching anchors for its gs_head); harmless but vestigial.
head:
  __object__:
    path: depth_anything_3.model.dualdpt
    name: DualDPT
    args: as_params

  dim_in: &head_dim_in 2048
  output_dim: 2
  features: &head_features 256
  out_channels: &head_out_channels [256, 512, 1024, 1024]


# Camera encoder: maps known camera parameters to backbone-width tokens.
cam_enc:
  __object__:
    path: depth_anything_3.model.cam_enc
    name: CameraEnc
    args: as_params

  dim_out: 1024

# Camera decoder: predicts pose encodings from fused features.
cam_dec:
  __object__:
    path: depth_anything_3.model.cam_dec
    name: CameraDec
    args: as_params

  dim_in: 2048
src/depth_anything_3/configs/da3-small.yaml ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Depth Anything 3 "small" model (ViT-S backbone).
__object__:
  path: depth_anything_3.model.da3
  name: DepthAnything3Net
  args: as_params

# Backbone: DINOv2 ViT-S, tapping features from four intermediate layers.
net:
  __object__:
    path: depth_anything_3.model.dinov2.dinov2
    name: DinoV2
    args: as_params

  name: vits
  out_layers: [5, 7, 9, 11]
  alt_start: 4
  qknorm_start: 4
  rope_start: 4
  cat_token: True

# Dual DPT head with two output channels.
# NOTE(review): the anchors below are unreferenced in this config (the giant
# config uses matching anchors for its gs_head); harmless but vestigial.
head:
  __object__:
    path: depth_anything_3.model.dualdpt
    name: DualDPT
    args: as_params

  dim_in: &head_dim_in 768
  output_dim: 2
  features: &head_features 64
  out_channels: &head_out_channels [48, 96, 192, 384]


# Camera encoder: maps known camera parameters to backbone-width tokens.
cam_enc:
  __object__:
    path: depth_anything_3.model.cam_enc
    name: CameraEnc
    args: as_params

  dim_out: 384

# Camera decoder: predicts pose encodings from fused features.
cam_dec:
  __object__:
    path: depth_anything_3.model.cam_dec
    name: CameraDec
    args: as_params

  dim_in: 768
src/depth_anything_3/configs/da3metric-large.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Metric-depth variant (ViT-L backbone, single-output DPT head).
__object__:
  path: depth_anything_3.model.da3
  name: DepthAnything3Net
  args: as_params

# Backbone: plain DINOv2 ViT-L; the multi-view extensions used by the
# any-view configs are disabled here (-1 starts, no token concatenation).
net:
  __object__:
    path: depth_anything_3.model.dinov2.dinov2
    name: DinoV2
    args: as_params

  name: vitl
  out_layers: [4, 11, 17, 23]
  alt_start: -1 # -1 means disable
  qknorm_start: -1
  rope_start: -1
  cat_token: False

# Single-channel DPT head (one output map); no camera heads in this config.
head:
  __object__:
    path: depth_anything_3.model.dpt
    name: DPT
    args: as_params

  dim_in: 1024
  output_dim: 1
  features: 256
  out_channels: [256, 512, 1024, 1024]
src/depth_anything_3/configs/da3mono-large.yaml ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Monocular variant (ViT-L backbone, single-output DPT head).
# NOTE(review): identical content to da3metric-large.yaml — presumably the
# two differ only in which checkpoint is loaded; confirm before deduplicating.
__object__:
  path: depth_anything_3.model.da3
  name: DepthAnything3Net
  args: as_params

# Backbone: plain DINOv2 ViT-L; multi-view extensions disabled.
net:
  __object__:
    path: depth_anything_3.model.dinov2.dinov2
    name: DinoV2
    args: as_params

  name: vitl
  out_layers: [4, 11, 17, 23]
  alt_start: -1 # -1 means disable
  qknorm_start: -1
  rope_start: -1
  cat_token: False

# Single-channel DPT head (one output map); no camera heads in this config.
head:
  __object__:
    path: depth_anything_3.model.dpt
    name: DPT
    args: as_params

  dim_in: 1024
  output_dim: 1
  features: 256
  out_channels: [256, 512, 1024, 1024]
src/depth_anything_3/configs/da3nested-giant-large.yaml ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
# Nested model: combines the any-view giant model with the metric large
# model inside a single NestedDepthAnything3Net wrapper.
__object__:
  path: depth_anything_3.model.da3
  name: NestedDepthAnything3Net
  args: as_params

# Multi-view branch: full da3-giant configuration, pulled in by __inherit__.
anyview:
  __inherit__: depth_anything_3.configs.da3-giant

# Metric-depth branch: da3metric-large configuration.
metric:
  __inherit__: depth_anything_3.configs.da3metric-large
src/depth_anything_3/model/__init__.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
from depth_anything_3.model.da3 import DepthAnything3Net, NestedDepthAnything3Net

# NOTE(review): "__export__" is not a standard Python name (the convention is
# "__all__" with a list of strings, not classes). Presumably this list is
# consumed by the project's config/registry machinery — confirm before
# renaming or converting to __all__.
__export__ = [
    NestedDepthAnything3Net,
    DepthAnything3Net,
]
src/depth_anything_3/model/cam_dec.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+
18
+
19
class CameraDec(nn.Module):
    """Decode per-view camera pose encodings from feature tokens.

    Emits a 9-D encoding per (batch, view): translation (3), rotation
    quaternion (4) and field of view (2), concatenated on the last axis.
    When ``camera_encoding`` is given, its quaternion and FoV components are
    passed through unchanged and only the translation is predicted.
    """

    def __init__(self, dim_in=1536):
        super().__init__()
        width = dim_in
        # Shared two-layer MLP trunk applied to every flattened view token.
        self.backbone = nn.Sequential(
            nn.Linear(width, width),
            nn.ReLU(),
            nn.Linear(width, width),
            nn.ReLU(),
        )
        # Independent prediction heads: translation, quaternion, FoV.
        self.fc_t = nn.Linear(width, 3)
        self.fc_qvec = nn.Linear(width, 4)
        self.fc_fov = nn.Sequential(nn.Linear(width, 2), nn.ReLU())

    def forward(self, feat, camera_encoding=None, *args, **kwargs):
        batch, views = feat.shape[:2]
        # Fold batch and view dims together so the MLP sees one token per row.
        tokens = self.backbone(feat.reshape(batch * views, -1))
        tokens = tokens.float()  # heads operate on fp32 features
        trans = self.fc_t(tokens).reshape(batch, views, 3)
        if camera_encoding is None:
            qvec = self.fc_qvec(tokens).reshape(batch, views, 4)
            fov = self.fc_fov(tokens).reshape(batch, views, 2)
        else:
            # Known cameras: reuse the provided rotation and FoV verbatim.
            qvec = camera_encoding[..., 3:7]
            fov = camera_encoding[..., -2:]
        return torch.cat([trans, qvec, fov], dim=-1)
+ return pose_enc
src/depth_anything_3/model/cam_enc.py ADDED
@@ -0,0 +1,80 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch.nn as nn
16
+
17
+ from depth_anything_3.model.utils.attention import Mlp
18
+ from depth_anything_3.model.utils.block import Block
19
+ from depth_anything_3.model.utils.transform import extri_intri_to_pose_encoding
20
+ from depth_anything_3.utils.geometry import affine_inverse
21
+
22
+
23
class CameraEnc(nn.Module):
    """Encode known camera parameters (extrinsics/intrinsics) into tokens.

    The camera is compressed to a 9-D pose encoding, lifted by an MLP to the
    token width, then refined by a small stack of transformer blocks.

    Fix: the original ``forward`` was annotated ``-> tuple`` although it
    returns a single token tensor; the misleading annotation is removed.
    """

    def __init__(
        self,
        dim_out: int = 1024,
        dim_in: int = 9,
        trunk_depth: int = 4,
        target_dim: int = 9,
        num_heads: int = 16,
        mlp_ratio: int = 4,
        init_values: float = 0.01,
        **kwargs,
    ):
        super().__init__()
        self.target_dim = target_dim
        self.trunk_depth = trunk_depth
        # Refinement trunk: `trunk_depth` identical transformer blocks.
        self.trunk = nn.Sequential(
            *[
                Block(
                    dim=dim_out,
                    num_heads=num_heads,
                    mlp_ratio=mlp_ratio,
                    init_values=init_values,
                )
                for _ in range(trunk_depth)
            ]
        )
        self.token_norm = nn.LayerNorm(dim_out)
        self.trunk_norm = nn.LayerNorm(dim_out)
        # Lifts the dim_in-D pose encoding up to the token width (dim_out).
        self.pose_branch = Mlp(
            in_features=dim_in,
            hidden_features=dim_out // 2,
            out_features=dim_out,
            drop=0,
        )

    def forward(
        self,
        ext,
        ixt,
        image_size,
    ):
        """Return refined pose tokens for the given cameras.

        Args:
            ext: camera extrinsics (inverted here via affine_inverse, so they
                are presumably world-to-camera — TODO confirm).
            ixt: camera intrinsics.
            image_size: image size forwarded to the pose encoding.
        """
        c2ws = affine_inverse(ext)
        pose_encoding = extri_intri_to_pose_encoding(
            c2ws,
            ixt,
            image_size,
        )
        pose_tokens = self.pose_branch(pose_encoding)
        pose_tokens = self.token_norm(pose_tokens)
        pose_tokens = self.trunk(pose_tokens)
        pose_tokens = self.trunk_norm(pose_tokens)
        return pose_tokens
src/depth_anything_3/model/da3.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import torch
18
+ import torch.nn as nn
19
+ from addict import Dict
20
+ from omegaconf import DictConfig, OmegaConf
21
+
22
+ from depth_anything_3.cfg import create_object
23
+ from depth_anything_3.model.utils.transform import pose_encoding_to_extri_intri
24
+ from depth_anything_3.utils.alignment import (
25
+ apply_metric_scaling,
26
+ compute_alignment_mask,
27
+ compute_sky_mask,
28
+ least_squares_scale_scalar,
29
+ sample_tensor_for_quantile,
30
+ set_sky_regions_to_max_depth,
31
+ )
32
+ from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, map_pdf_to_opacity
33
+
34
+
35
def _wrap_cfg(cfg_obj):
    """Wrap a plain dict/list config into an OmegaConf container for create_object."""
    return OmegaConf.create(cfg_obj)
37
+
38
+
39
class DepthAnything3Net(nn.Module):
    """
    Depth Anything 3 network for depth estimation and camera pose estimation.

    This network consists of:
    - Backbone: DinoV2 feature extractor
    - Head: DPT or DualDPT for depth prediction
    - Optional camera encoder (conditioning) and decoder (pose estimation)
    - Optional GSDPT head + adapter for 3DGS prediction

    Forward returns a dictionary containing (depending on configured modules):
        - depth: Predicted depth map
        - depth_conf: Depth confidence map
        - extrinsics: Camera extrinsics (B, N, 4, 4)
        - intrinsics: Camera intrinsics (B, N, 3, 3)
        - gaussians: 3D Gaussian Splats (world space), type: model.gs_adapter.Gaussians
        - aux: Auxiliary backbone features for the requested layers
    """

    # ViT patch size; auxiliary feature maps have spatial size H/14 x W/14.
    PATCH_SIZE = 14

    def __init__(self, net, head, cam_dec=None, cam_enc=None, gs_head=None, gs_adapter=None):
        """
        Initialize DepthAnything3Net.

        Each argument may be a ready-made ``nn.Module`` or a yaml/dict config
        that is instantiated through ``create_object``.
        """
        super().__init__()
        self.backbone = net if isinstance(net, nn.Module) else create_object(_wrap_cfg(net))
        self.head = head if isinstance(head, nn.Module) else create_object(_wrap_cfg(head))
        self.cam_dec, self.cam_enc = None, None
        if cam_dec is not None:
            self.cam_dec = (
                cam_dec if isinstance(cam_dec, nn.Module) else create_object(_wrap_cfg(cam_dec))
            )
        # BUGFIX: previously this assigned ``cam_dec`` when ``cam_enc`` was an
        # nn.Module, and was built under the ``cam_dec`` guard even when
        # ``cam_enc`` was None.
        if cam_enc is not None:
            self.cam_enc = (
                cam_enc if isinstance(cam_enc, nn.Module) else create_object(_wrap_cfg(cam_enc))
            )
        self.gs_adapter, self.gs_head = None, None
        if gs_head is not None and gs_adapter is not None:
            self.gs_adapter = (
                gs_adapter
                if isinstance(gs_adapter, nn.Module)
                else create_object(_wrap_cfg(gs_adapter))
            )
            # The GS head must emit one channel per raw gaussian parameter
            # plus one density channel.
            gs_out_dim = self.gs_adapter.d_in + 1
            if isinstance(gs_head, nn.Module):
                assert (
                    gs_head.out_dim == gs_out_dim
                ), f"gs_head.out_dim should be {gs_out_dim}, got {gs_head.out_dim}"
                self.gs_head = gs_head
            else:
                assert (
                    gs_head["output_dim"] == gs_out_dim
                ), f"gs_head output_dim should set to {gs_out_dim}, got {gs_head['output_dim']}"
                self.gs_head = create_object(_wrap_cfg(gs_head))

    def forward(
        self,
        x: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
        export_feat_layers: list[int] | None = None,
        infer_gs: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through the network.

        Args:
            x: Input images (B, N, 3, H, W)
            extrinsics: Optional camera extrinsics (B, N, 4, 4); when given,
                they are encoded into a conditioning token for the backbone.
            intrinsics: Optional camera intrinsics (B, N, 3, 3)
            export_feat_layers: Layer indices whose backbone features are
                returned under ``output.aux`` (None means none).
            infer_gs: Whether to also run the 3DGS head.

        Returns:
            Dictionary containing predictions and auxiliary features.
        """
        # BUGFIX: avoid a mutable default argument; normalize None to [].
        if export_feat_layers is None:
            export_feat_layers = []

        # Encode known cameras into a conditioning token if provided.
        if extrinsics is not None:
            with torch.autocast(device_type=x.device.type, enabled=False):
                cam_token = self.cam_enc(extrinsics, intrinsics, x.shape[-2:])
        else:
            cam_token = None

        feats, aux_feats = self.backbone(
            x, cam_token=cam_token, export_feat_layers=export_feat_layers
        )
        H, W = x.shape[-2], x.shape[-1]

        # Heads run in full precision regardless of any outer autocast context.
        with torch.autocast(device_type=x.device.type, enabled=False):
            output = self._process_depth_head(feats, H, W)
            output = self._process_camera_estimation(feats, H, W, output)
            if infer_gs:
                output = self._process_gs_head(feats, H, W, output, x, extrinsics, intrinsics)

        # Extract auxiliary features if requested
        output.aux = self._extract_auxiliary_features(aux_feats, export_feat_layers, H, W)

        return output

    def _process_depth_head(
        self, feats: list[torch.Tensor], H: int, W: int
    ) -> Dict[str, torch.Tensor]:
        """Process features through the depth prediction head."""
        return self.head(feats, H, W, patch_start_idx=0)

    def _process_camera_estimation(
        self, feats: list[torch.Tensor], H: int, W: int, output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Process camera pose estimation if a camera decoder is available."""
        if self.cam_dec is not None:
            pose_enc = self.cam_dec(feats[-1][1])
            # Remove ray information as it's not needed for pose estimation
            if "ray" in output:
                del output.ray
            if "ray_conf" in output:
                del output.ray_conf

            # Convert pose encoding to extrinsics and intrinsics
            c2w, ixt = pose_encoding_to_extri_intri(pose_enc, (H, W))
            output.extrinsics = affine_inverse(c2w)
            output.intrinsics = ixt

        return output

    def _process_gs_head(
        self,
        feats: list[torch.Tensor],
        H: int,
        W: int,
        output: Dict[str, torch.Tensor],
        in_images: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
    ) -> Dict[str, torch.Tensor]:
        """Process 3DGS parameter estimation if the 3DGS head is available."""
        if self.gs_head is None or self.gs_adapter is None:
            return output
        assert output.get("depth", None) is not None, "must provide MV depth for the GS head."

        # If GT camera poses are provided, prefer them over the predicted ones.
        if extrinsics is not None and intrinsics is not None:
            ctx_extr = extrinsics
            ctx_intr = intrinsics
        else:
            ctx_extr = output.get("extrinsics", None)
            ctx_intr = output.get("intrinsics", None)
            assert (
                ctx_extr is not None and ctx_intr is not None
            ), "must process camera info first if GT is not available"
        gt_extr = extrinsics
        # Homogenize the extrinsics (3x4 -> 4x4) if needed.
        ctx_extr = as_homogeneous(ctx_extr)
        if gt_extr is not None:
            gt_extr = as_homogeneous(gt_extr)

        # Forward through the gs_dpt head to get camera-space parameters.
        gs_outs = self.gs_head(
            feats=feats,
            H=H,
            W=W,
            patch_start_idx=0,
            images=in_images,
        )
        raw_gaussians = gs_outs.raw_gs
        densities = gs_outs.raw_gs_conf

        # Convert to world-space 3DGS parameters, ready to export and render.
        # gt_extr may be None; when present it is used to align the pose scale.
        gs_world = self.gs_adapter(
            extrinsics=ctx_extr,
            intrinsics=ctx_intr,
            depths=output.depth,
            opacities=map_pdf_to_opacity(densities),
            raw_gaussians=raw_gaussians,
            image_shape=(H, W),
            gt_extrinsics=gt_extr,
        )
        output.gaussians = gs_world

        return output

    def _extract_auxiliary_features(
        self, feats: list[torch.Tensor], feat_layers: list[int], H: int, W: int
    ) -> Dict[str, torch.Tensor]:
        """Reshape exported backbone features to (B, N, H/14, W/14, C) per layer."""
        aux_features = Dict()
        assert len(feats) == len(feat_layers)
        for feat, feat_layer in zip(feats, feat_layers):
            # Reshape flat patch tokens back to spatial dimensions.
            feat_reshaped = feat.reshape(
                [
                    feat.shape[0],
                    feat.shape[1],
                    H // self.PATCH_SIZE,
                    W // self.PATCH_SIZE,
                    feat.shape[-1],
                ]
            )
            aux_features[f"feat_layer_{feat_layer}"] = feat_reshaped

        return aux_features
+ return aux_features
246
+
247
+
248
class NestedDepthAnything3Net(nn.Module):
    """
    Nested Depth Anything 3 network with metric scaling capabilities.

    This network combines two DepthAnything3Net branches:
    - ``da3`` (anyview): standard multi-view depth estimation
    - ``da3_metric``: metric depth estimation used for scale alignment

    The network aligns the anyview depth to the metric branch via least
    squares scaling and handles sky-region masking.

    Args:
        anyview: Configuration for the main depth estimation branch
        metric: Configuration for the metric depth branch
    """

    def __init__(self, anyview: DictConfig, metric: DictConfig):
        """Build both branches from their configs."""
        super().__init__()
        self.da3 = create_object(anyview)
        self.da3_metric = create_object(metric)

    def forward(
        self,
        x: torch.Tensor,
        extrinsics: torch.Tensor | None = None,
        intrinsics: torch.Tensor | None = None,
        export_feat_layers: list[int] | None = None,
        infer_gs: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """
        Forward pass through both branches with metric scaling alignment.

        Args:
            x: Input images (B, N, 3, H, W)
            extrinsics: Optional camera extrinsics (B, N, 4, 4)
            intrinsics: Optional camera intrinsics (B, N, 3, 3)
            export_feat_layers: Layer indices to export backbone features from
            infer_gs: Whether to also run the 3DGS head

        Returns:
            Dictionary containing aligned depth predictions and camera parameters
        """
        # BUGFIX: avoid a mutable default argument; normalize None to [].
        if export_feat_layers is None:
            export_feat_layers = []

        # Get predictions from both branches
        output = self.da3(
            x, extrinsics, intrinsics, export_feat_layers=export_feat_layers, infer_gs=infer_gs
        )
        metric_output = self.da3_metric(x, infer_gs=infer_gs)

        # Apply metric scaling and alignment
        output = self._apply_metric_scaling(output, metric_output)
        output = self._apply_depth_alignment(output, metric_output)
        output = self._handle_sky_regions(output, metric_output)

        return output

    def _apply_metric_scaling(
        self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Scale the metric-branch depth using the anyview branch intrinsics."""
        # Note: this mutates metric_output.depth in place; output is unchanged.
        metric_output.depth = apply_metric_scaling(
            metric_output.depth,
            output.intrinsics,
        )
        return output

    def _apply_depth_alignment(
        self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
    ) -> Dict[str, torch.Tensor]:
        """Align anyview depth (and camera translations) to metric scale."""
        # Compute non-sky mask
        non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)

        # Ensure we have enough non-sky pixels
        assert non_sky_mask.sum() > 10, "Insufficient non-sky pixels for alignment"

        # Sample depth confidence for quantile computation (quantile on the
        # full tensor can exceed torch's size limit).
        depth_conf_ns = output.depth_conf[non_sky_mask]
        depth_conf_sampled = sample_tensor_for_quantile(depth_conf_ns, max_samples=100000)
        median_conf = torch.quantile(depth_conf_sampled, 0.5)

        # Keep only confident, non-sky, mutually-valid pixels for alignment.
        align_mask = compute_alignment_mask(
            output.depth_conf, non_sky_mask, output.depth, metric_output.depth, median_conf
        )

        # Compute scale factor using least squares
        valid_depth = output.depth[align_mask]
        valid_metric_depth = metric_output.depth[align_mask]
        scale_factor = least_squares_scale_scalar(valid_metric_depth, valid_depth)

        # Apply scaling to depth and camera translations so geometry stays consistent.
        output.depth *= scale_factor
        output.extrinsics[:, :, :3, 3] *= scale_factor
        output.is_metric = 1
        output.scale_factor = scale_factor.item()

        return output

    def _handle_sky_regions(
        self,
        output: Dict[str, torch.Tensor],
        metric_output: Dict[str, torch.Tensor],
        sky_depth_def: float = 200.0,
    ) -> Dict[str, torch.Tensor]:
        """Handle sky regions by setting them to a maximum depth (capped at sky_depth_def)."""
        non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)

        # Compute maximum depth for non-sky regions.
        # Use sampling to safely compute the quantile on large tensors.
        non_sky_depth = output.depth[non_sky_mask]
        if non_sky_depth.numel() > 100000:
            idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device)
            sampled_depth = non_sky_depth[idx]
        else:
            sampled_depth = non_sky_depth
        non_sky_max = min(torch.quantile(sampled_depth, 0.99), sky_depth_def)

        # Set sky regions to maximum depth and high confidence
        output.depth, output.depth_conf = set_sky_regions_to_max_depth(
            output.depth, output.depth_conf, non_sky_mask, max_depth=non_sky_max
        )

        return output
+ return output
src/depth_anything_3/model/dinov2/dinov2.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+
11
+ from typing import List
12
+ import torch.nn as nn
13
+
14
+ from depth_anything_3.model.dinov2.vision_transformer import (
15
+ vit_base,
16
+ vit_giant2,
17
+ vit_large,
18
+ vit_small,
19
+ )
20
+
21
+
22
class DinoV2(nn.Module):
    """DINOv2 backbone wrapper exposing intermediate transformer layers."""

    def __init__(
        self,
        name: str,
        out_layers: List[int],
        alt_start: int = -1,
        qknorm_start: int = -1,
        rope_start: int = -1,
        cat_token: bool = True,
        **kwargs,
    ):
        super().__init__()
        assert name in {"vits", "vitb", "vitl", "vitg"}
        self.name = name
        self.out_layers = out_layers
        self.alt_start = alt_start
        self.qknorm_start = qknorm_start
        self.rope_start = rope_start
        self.cat_token = cat_token
        factories = {
            "vits": vit_small,
            "vitb": vit_base,
            "vitl": vit_large,
            "vitg": vit_giant2,
        }
        # The giant variant uses a fused SwiGLU FFN; the others use a plain MLP.
        self.pretrained = factories[name](
            img_size=518,
            patch_size=14,
            ffn_layer="swiglufused" if name == "vitg" else "mlp",
            alt_start=alt_start,
            qknorm_start=qknorm_start,
            rope_start=rope_start,
            cat_token=cat_token,
        )

    def forward(self, x, **kwargs):
        """Return features from the configured intermediate layers."""
        return self.pretrained.get_intermediate_layers(
            x,
            self.out_layers,
            **kwargs,
        )
src/depth_anything_3/model/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # from .attention import MemEffAttention
8
+ from .block import Block
9
+ from .layer_scale import LayerScale
10
+ from .mlp import Mlp
11
+ from .patch_embed import PatchEmbed
12
+ from .rope import PositionGetter, RotaryPositionEmbedding2D
13
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
14
+
15
+ __all__ = [
16
+ Mlp,
17
+ PatchEmbed,
18
+ SwiGLUFFN,
19
+ SwiGLUFFNFused,
20
+ Block,
21
+ # MemEffAttention,
22
+ LayerScale,
23
+ PositionGetter,
24
+ RotaryPositionEmbedding2D,
25
+ ]
src/depth_anything_3/model/dinov2/layers/attention.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
10
+
11
+ import logging
12
+ import torch.nn.functional as F
13
+ from torch import Tensor, nn
14
+
15
+ logger = logging.getLogger("dinov2")
16
+
17
+
18
class Attention(nn.Module):
    """Multi-head self-attention with optional QK-norm and rotary embeddings."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        fused_attn: bool = True,  # use F.scaled_dot_product_attention or not
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5
        self.fused_attn = fused_attn

        # Single projection producing Q, K and V in one pass.
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
        B, N, C = x.shape
        head_dim = C // self.num_heads
        # (B, N, 3C) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.rope is not None and pos is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)
        if self.fused_attn:
            mask = None
            if attn_mask is not None:
                # Broadcast the per-sample mask across heads.
                mask = attn_mask[:, None].repeat(1, self.num_heads, 1, 1)
            out = F.scaled_dot_product_attention(
                q,
                k,
                v,
                dropout_p=self.attn_drop.p if self.training else 0.0,
                attn_mask=mask,
            )
        else:
            scores = (q * self.scale) @ k.transpose(-2, -1)
            weights = self.attn_drop(scores.softmax(dim=-1))
            out = weights @ v

        out = out.transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)

    def _forward(self, x: Tensor) -> Tensor:
        # Plain (non-fused, no rope/mask) attention path.
        B, N, C = x.shape
        head_dim = C // self.num_heads
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        scores = (q * self.scale) @ k.transpose(-2, -1)
        weights = self.attn_drop(scores.softmax(dim=-1))
        out = (weights @ v).transpose(1, 2).reshape(B, N, C)
        out = self.proj(out)
        return self.proj_drop(out)
src/depth_anything_3/model/dinov2/layers/block.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa: F821
2
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
3
+ # All rights reserved.
4
+ #
5
+ # This source code is licensed under the license found in the
6
+ # LICENSE file in the root directory of this source tree.
7
+
8
+ # References:
9
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
10
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
11
+
12
+ import logging
13
+ from typing import Callable, Optional
14
+ import torch
15
+ from torch import Tensor, nn
16
+
17
+ from .attention import Attention
18
+ from .drop_path import DropPath
19
+ from .layer_scale import LayerScale
20
+ from .mlp import Mlp
21
+
22
+ logger = logging.getLogger("dinov2")
23
+ XFORMERS_AVAILABLE = True
24
+
25
+
26
class Block(nn.Module):
    """Pre-norm transformer block: self-attention + MLP residual branches,
    each with optional LayerScale and stochastic depth (drop path)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
        ln_eps: float = 1e-6,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim, eps=ln_eps)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope,
        )
        # LayerScale is enabled only when init_values is truthy.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim, eps=ln_eps)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
        def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor:
            return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # The batched-subset variant's overhead only pays off above ~0.1.
            # NOTE(review): attn_mask is not forwarded on this path — confirm
            # masks are never used together with high drop-path training.
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                pos=pos,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask))
            # BUGFIX: use drop_path2 for the FFN branch (was drop_path1,
            # flagged by the original FIXME comment).
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
            x = x + ffn_residual_func(x)
        return x
+ return x
104
+
105
+
106
def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
    pos: Optional[Tensor] = None,
) -> Tensor:
    """Batched stochastic depth: run the residual branch on a random subset of
    the batch only, then scatter-add the (rescaled) residual back.

    Compared to per-sample DropPath this avoids computing the branch for
    dropped samples. The residual is scaled by b / subset_size so its
    expectation matches running the branch on the full batch.

    Args:
        x: Input of shape (B, N, D).
        residual_func: Branch to apply; must accept ``pos=`` when pos is given.
        sample_drop_ratio: Fraction of the batch to drop (at least 1 sample kept).
        pos: Optional per-sample positions (e.g. for RoPE), subset alongside x.
    """
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    if pos is not None:
        # if necessary, apply rope to the subset
        pos = pos[brange]
        residual = residual_func(x_subset, pos=pos)
    else:
        residual = residual_func(x_subset)

    # Flatten to 2D so index_add can scatter whole samples at once.
    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    # Rescale so the expected residual magnitude matches the full batch.
    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(
        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
    )
    return x_plus_residual.view_as(x)
136
+
137
+
138
def get_branges_scales(x, sample_drop_ratio=0.0):
    """Sample a random batch subset and its residual rescale factor.

    Returns ``(brange, residual_scale_factor)``: the indices of the kept
    samples (at least one) and the factor b / subset_size that compensates
    for the dropped samples.
    """
    b, _, _ = x.shape
    keep = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = torch.randperm(b, device=x.device)[:keep]
    return brange, b / keep
src/depth_anything_3/model/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
10
+
11
+
12
+ from torch import nn
13
+
14
+
15
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth), rescaling survivors."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dim.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale survivors so the expected activation is unchanged.
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
+ return drop_path(x, self.drop_prob, self.training)
src/depth_anything_3/model/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa: E501
8
+
9
+ from typing import Union
10
+ import torch
11
+ from torch import Tensor, nn
12
+
13
+
14
class LayerScale(nn.Module):
    """Learnable per-channel scaling (gamma), optionally applied in place."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.dim = dim
        self.inplace = inplace
        self.init_values = init_values
        # gamma starts at init_values for every channel.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma

    def extra_repr(self) -> str:
        return f"{self.dim}, init_values={self.init_values}, inplace={self.inplace}"
src/depth_anything_3/model/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
10
+
11
+
12
+ from typing import Callable, Optional
13
+ from torch import Tensor, nn
14
+
15
+
16
class Mlp(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> dropout -> linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden/output widths default to the input width.
        hidden = hidden_features or in_features
        out = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden, out, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
src/depth_anything_3/model/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ # References:
8
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
9
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
10
+
11
+ from typing import Callable, Optional, Tuple, Union
12
+ import torch.nn as nn
13
+ from torch import Tensor
14
+
15
+
16
def make_2tuple(x):
    """Return ``x`` unchanged if it is a 2-tuple, else ``(x, x)`` for an int."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple)
    assert len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid = (image_hw[0] // patch_hw[0], image_hw[1] // patch_hw[1])

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid
        self.num_patches = grid[0] * grid[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping conv: one output column per patch.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert (
            H % patch_H == 0
        ), f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert (
            W % patch_W == 0
        ), f"Input image width {W} is not a multiple of patch width: {patch_W}"

        tokens = self.proj(x)  # B C H' W'
        grid_h, grid_w = tokens.size(2), tokens.size(3)
        tokens = self.norm(tokens.flatten(2).transpose(1, 2))  # B HW C
        if not self.flatten_embedding:
            tokens = tokens.reshape(-1, grid_h, grid_w, self.embed_dim)  # B H W C
        return tokens

    def flops(self) -> float:
        grid_h, grid_w = self.patches_resolution
        per_patch = self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        total = grid_h * grid_w * per_patch
        if self.norm is not None:
            total += grid_h * grid_w * self.embed_dim
        return total
src/depth_anything_3/model/dinov2/layers/rope.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ # Implementation of 2D Rotary Position Embeddings (RoPE).
8
+
9
+ # This module provides a clean implementation of 2D Rotary Position Embeddings,
10
+ # which extends the original RoPE concept to handle 2D spatial positions.
11
+
12
+ # Inspired by:
13
+ # https://github.com/meta-llama/codellama/blob/main/llama/model.py
14
+ # https://github.com/naver-ai/rope-vit
15
+
16
+
17
+ from typing import Dict, Tuple
18
+ import torch
19
+ import torch.nn as nn
20
+ import torch.nn.functional as F
21
+
22
+
23
class PositionGetter:
    """Produces (y, x) grid coordinates for patch tokens, with caching.

    Coordinate tensors are cached per (height, width) grid size so repeated
    calls for the same grid only pay for a view/expand/clone.
    """

    def __init__(self):
        """Start with an empty per-grid-size cache."""
        self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}

    def __call__(
        self, batch_size: int, height: int, width: int, device: torch.device
    ) -> torch.Tensor:
        """Return per-batch patch coordinates for a ``height`` x ``width`` grid.

        Args:
            batch_size: Number of samples in the batch.
            height: Grid height in patches.
            width: Grid width in patches.
            device: Device on which to create the coordinate tensor.

        Returns:
            Tensor of shape (batch_size, height * width, 2) holding (y, x)
            pairs in row-major order, identical for every batch item.
        """
        key = (height, width)
        if key not in self.position_cache:
            rows = torch.arange(height, device=device)
            cols = torch.arange(width, device=device)
            # cartesian_prod yields row-major (y, x) pairs.
            self.position_cache[key] = torch.cartesian_prod(rows, cols)

        grid = self.position_cache[key]
        return grid.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
61
+
62
+
63
class RotaryPositionEmbedding2D(nn.Module):
    """2D Rotary Position Embedding implementation.

    This module applies rotary position embeddings to input tokens based on their
    2D spatial positions. It handles the position-dependent rotation of features
    separately for vertical and horizontal dimensions.

    Args:
        frequency: Base frequency for the position embeddings. Default: 100.0
        scaling_factor: Scaling factor for frequency computation. Default: 1.0

    Attributes:
        base_frequency: Base frequency for computing position embeddings.
        scaling_factor: Factor to scale the computed frequencies.
        frequency_cache: Cache for storing precomputed frequency components.
    """

    def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
        """Initializes the 2D RoPE module."""
        super().__init__()
        self.base_frequency = frequency
        # NOTE(review): scaling_factor is stored but never read in this class's
        # methods — confirm whether it is consumed elsewhere or dead.
        self.scaling_factor = scaling_factor
        # Keyed by (dim, seq_len, device, dtype) so mixed-precision / multi-device
        # usage does not collide.
        self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}

    def _compute_frequency_components(
        self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Computes frequency components for rotary embeddings.

        Args:
            dim: Feature dimension (must be even).
            seq_len: Maximum sequence length.
            device: Target device for computations.
            dtype: Data type for the computed tensors.

        Returns:
            Tuple of (cosine, sine) tensors for frequency components.
        """
        cache_key = (dim, seq_len, device, dtype)
        if cache_key not in self.frequency_cache:
            # Compute frequency bands
            exponents = torch.arange(0, dim, 2, device=device).float() / dim
            inv_freq = 1.0 / (self.base_frequency**exponents)

            # Generate position-dependent frequencies
            positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
            angles = torch.einsum("i,j->ij", positions, inv_freq)

            # Compute and cache frequency components
            # Duplicating the angles along the last dim pairs each rotated
            # half of the features with the same angle (standard RoPE layout).
            angles = angles.to(dtype)
            angles = torch.cat((angles, angles), dim=-1)
            cos_components = angles.cos().to(dtype)
            sin_components = angles.sin().to(dtype)
            self.frequency_cache[cache_key] = (cos_components, sin_components)

        return self.frequency_cache[cache_key]

    @staticmethod
    def _rotate_features(x: torch.Tensor) -> torch.Tensor:
        """Performs feature rotation by splitting and recombining feature dimensions.

        Args:
            x: Input tensor to rotate.

        Returns:
            Rotated feature tensor.
        """
        # (x1, x2) -> (-x2, x1): the 90-degree rotation used by RoPE.
        feature_dim = x.shape[-1]
        x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
        return torch.cat((-x2, x1), dim=-1)

    def _apply_1d_rope(
        self,
        tokens: torch.Tensor,
        positions: torch.Tensor,
        cos_comp: torch.Tensor,
        sin_comp: torch.Tensor,
    ) -> torch.Tensor:
        """Applies 1D rotary position embeddings along one dimension.

        Args:
            tokens: Input token features.
            positions: Position indices.
            cos_comp: Cosine components for rotation.
            sin_comp: Sine components for rotation.

        Returns:
            Tokens with applied rotary position embeddings.
        """
        # Embed positions with frequency components
        # F.embedding gathers the per-position rows; [:, None] adds a head axis.
        cos = F.embedding(positions, cos_comp)[:, None, :, :]
        sin = F.embedding(positions, sin_comp)[:, None, :, :]
        # Apply rotation
        return (tokens * cos) + (self._rotate_features(tokens) * sin)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
        """Applies 2D rotary position embeddings to input tokens.

        Args:
            tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
                The feature dimension (dim) must be divisible by 4.
            positions: Position tensor of shape (batch_size, n_tokens, 2) containing
                the y and x coordinates for each token.

        Returns:
            Tensor of same shape as input with applied 2D rotary position embeddings.

        Raises:
            AssertionError: If input dimensions are invalid or positions are malformed.
        """
        # Validate inputs
        # NOTE(review): only evenness is asserted, though the docstring requires
        # divisibility by 4 (each half must itself split evenly in
        # _rotate_features) — confirm intended contract.
        assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
        assert (
            positions.ndim == 3 and positions.shape[-1] == 2
        ), "Positions must have shape (batch_size, n_tokens, 2)"

        # Compute feature dimension for each spatial direction
        feature_dim = tokens.size(-1) // 2

        # Get frequency components
        max_position = int(positions.max()) + 1
        cos_comp, sin_comp = self._compute_frequency_components(
            feature_dim, max_position, tokens.device, tokens.dtype
        )

        # Split features for vertical and horizontal processing
        vertical_features, horizontal_features = tokens.chunk(2, dim=-1)

        # Apply RoPE separately for each dimension
        vertical_features = self._apply_1d_rope(
            vertical_features, positions[..., 0], cos_comp, sin_comp
        )
        horizontal_features = self._apply_1d_rope(
            horizontal_features, positions[..., 1], cos_comp, sin_comp
        )

        # Combine processed features
        return torch.cat((vertical_features, horizontal_features), dim=-1)
src/depth_anything_3/model/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ from typing import Callable, Optional
8
+ import torch.nn.functional as F
9
+ from torch import Tensor, nn
10
+
11
+
12
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward network: a SiLU-gated linear unit followed by a projection."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """Build the gated FFN.

        Note: ``act_layer`` and ``drop`` are accepted for interface
        compatibility with other FFN classes but are unused here.
        """
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # One fused projection produces both the gate branch and the value branch.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        """Apply silu(gate) * value, then project back out."""
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
33
+
34
+
35
# Prefer the fused xformers SwiGLU implementation when the package is
# installed; otherwise fall back to the pure-PyTorch SwiGLUFFN defined above.
# XFORMERS_AVAILABLE records which path was taken.
try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False
42
+
43
+
44
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is scaled by 2/3 and rounded up to a multiple of 8."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        """Resolve widths, shrink the hidden dim, and delegate to the SwiGLU base.

        The 2/3 factor keeps the parameter count comparable to a plain MLP
        (SwiGLU uses two input projections); rounding up to a multiple of 8
        keeps the matmul shapes hardware-friendly.
        """
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        scaled = int(hidden_features * 2 / 3)
        hidden_features = ((scaled + 7) // 8) * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
src/depth_anything_3/model/dinov2/vision_transformer.py ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ import math
11
+ from typing import Callable, List, Sequence, Tuple, Union
12
+ import numpy as np
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.utils.checkpoint
16
+ from einops import rearrange
17
+
18
+ from depth_anything_3.utils.logger import logger
19
+
20
+ from .layers import LayerScale # noqa: F401
21
+ from .layers import Mlp # noqa: F401
22
+ from .layers import ( # noqa: F401
23
+ Block,
24
+ PatchEmbed,
25
+ PositionGetter,
26
+ RotaryPositionEmbedding2D,
27
+ SwiGLUFFNFused,
28
+ )
29
+
30
+ # logger = logging.getLogger("dinov2")
31
+
32
+
33
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Build 1D sine/cosine positional embeddings.

    Args:
        embed_dim: Output dimension per position (must be even).
        pos: Array-like of positions to encode; flattened to shape (M,).

    Returns:
        np.ndarray of shape (M, embed_dim): sines occupy the first half of
        the channels, cosines the second half.
    """
    assert embed_dim % 2 == 0
    half_dim = embed_dim // 2
    # Geometric frequency ladder: 1 / 10000^(2i / embed_dim), i = 0..D/2-1.
    freqs = 1.0 / 10000 ** (np.arange(half_dim, dtype=float) / (embed_dim / 2.0))

    angles = np.einsum("m,d->md", pos.reshape(-1), freqs)  # (M, D/2) outer product
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
52
+
53
+
54
def named_apply(
    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
) -> nn.Module:
    """Recursively invoke ``fn(module=..., name=...)`` over a module tree.

    Args:
        fn: Callback receiving each module and its dotted name.
        module: Root module of the traversal.
        name: Dotted-name prefix used for the root.
        depth_first: If True, visit children before the module itself.
        include_root: If True, also invoke ``fn`` on ``module`` itself.

    Returns:
        The (possibly mutated in place by ``fn``) root module.
    """
    pre_order = not depth_first
    if pre_order and include_root:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        full_name = f"{name}.{child_name}" if name else child_name
        # Children are always visited (include_root=True below).
        named_apply(
            fn=fn, module=child, name=full_name, depth_first=depth_first, include_root=True
        )
    if depth_first and include_root:
        fn(module=module, name=name)
    return module
67
+
68
+
69
class BlockChunk(nn.ModuleList):
    """A callable ModuleList: applies its contained blocks sequentially."""

    def forward(self, x):
        """Thread ``x`` through every block in order and return the result."""
        for block in self:
            x = block(x)
        return x
74
+
75
+
76
class DinoVisionTransformer(nn.Module):
    """DINOv2-style Vision Transformer operating on batches of frame sequences.

    Inputs are 5D tensors (B, S, C, H, W). From ``alt_start`` onward, blocks
    alternate between per-frame ("local") and cross-frame ("global")
    attention, a learned camera token replaces the cls token slot, and
    optional 2D rotary position embeddings are applied from ``rope_start``
    onward.
    """

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=1.0,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
        alt_start=-1,
        qknorm_start=-1,
        rope_start=-1,
        rope_freq=100,
        plus_cam_token=False,
        cat_token=True,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating
                positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating
                positional embeddings
            block_prompt: (bool) whether to add ray embeddings to the block input
        """
        # NOTE(review): ``block_chunks`` and ``plus_cam_token`` are accepted
        # but not used anywhere in this constructor — confirm whether they are
        # kept only for config compatibility.
        super().__init__()
        self.patch_start_idx = 1
        norm_layer = nn.LayerNorm
        self.num_features = self.embed_dim = (
            embed_dim  # num_features for consistency with other models
        )
        self.alt_start = alt_start
        self.qknorm_start = qknorm_start
        self.rope_start = rope_start
        self.cat_token = cat_token
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
        )
        num_patches = self.patch_embed.num_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # Camera tokens (one for the reference frame, one shared by source
        # frames) only exist when alternating attention is enabled.
        if self.alt_start != -1:
            self.camera_token = nn.Parameter(torch.randn(1, 2, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
            if num_register_tokens
            else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [
                x.item() for x in torch.linspace(0, drop_path_rate, depth)
            ]  # stochastic depth decay rule
        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        if self.rope_start != -1:
            self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
            self.position_getter = PositionGetter() if self.rope is not None else None
        else:
            self.rope = None
        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
                # qk-norm and RoPE are enabled only from their start index onward.
                qk_norm=i >= qknorm_start if qknorm_start != -1 else False,
                rope=self.rope if i >= rope_start and rope_start != -1 else None,
            )
            for i in range(depth)
        ]
        self.blocks = nn.ModuleList(blocks_list)
        self.norm = norm_layer(embed_dim)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resize the learned positional embedding to the current grid.

        ``x`` is the token tensor (cls token + patches); ``w``/``h`` are the
        input image width and height in pixels.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        # Fast path: token count matches and the image is square.
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if self.interpolate_offset:
            # Historical kludge: add a small number to avoid floating point error in the
            # interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using
            # both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sx, sy)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (w0, h0)
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )
        assert (w0, h0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_cls_token(self, B, S):
        """Expand the learned cls token to one token per frame, shaped (B*S, 1, C)."""
        cls_token = self.cls_token.expand(B, S, -1)
        cls_token = cls_token.reshape(B * S, -1, self.embed_dim)
        return cls_token

    def prepare_tokens_with_masks(self, x, masks=None, cls_token=None, **kwargs):
        """Patchify frames, prepend cls/register tokens, and add positional embeddings.

        Args:
            x: (B, S, C, H, W) frame batch.
            masks: Optional boolean mask selecting patches to replace with the
                mask token.  # NOTE(review): the mask path reads self.mask_token,
                which is not defined in this class — confirm it is only used by
                subclasses that add one.
            cls_token: Unused; overwritten by prepare_cls_token below.

        Returns:
            (B, S, N, C) token tensor.
        """
        B, S, nc, w, h = x.shape
        x = rearrange(x, "b s c h w -> (b s) c h w")
        x = self.patch_embed(x)
        if masks is not None:
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
        cls_token = self.prepare_cls_token(B, S)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.interpolate_pos_encoding(x, w, h)
        if self.register_tokens is not None:
            # Registers are inserted between the cls token and the patch tokens.
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )
        x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S)
        return x

    def _prepare_rope(self, B, S, H, W, device):
        """Build RoPE position grids.

        Returns (pos, pos_nodiff): per-patch (y, x) positions shifted by +1 so
        index 0 can be reserved for special (cls) tokens, and an all-ones
        variant used by global attention so cross-frame positions do not
        differ. Both are None when RoPE is disabled.
        """
        pos = None
        pos_nodiff = None
        if self.rope is not None:
            pos = self.position_getter(
                B * S, H // self.patch_size, W // self.patch_size, device=device
            )
            pos = rearrange(pos, "(b s) n c -> b s n c", b=B)
            pos_nodiff = torch.zeros_like(pos).to(pos.dtype)
            if self.patch_start_idx > 0:
                pos = pos + 1
                pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(device).to(pos.dtype)
                pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B)
                pos = torch.cat([pos_special, pos], dim=2)
                pos_nodiff = pos_nodiff + 1
                pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2)
        return pos, pos_nodiff

    def _get_intermediate_layers_not_chunked(self, x, n=1, export_feat_layers=[], **kwargs):
        """Run all blocks and collect outputs of the requested layers.

        Returns a list of (camera_token, tokens) pairs for the selected layers
        plus auxiliary per-layer features for ``export_feat_layers``.
        """
        # NOTE(review): mutable default ``export_feat_layers=[]`` is only
        # read (membership tests), never mutated, so it is safe here — but
        # non-idiomatic.
        B, S, _, H, W = x.shape
        x = self.prepare_tokens_with_masks(x)
        output, total_block_len, aux_output = [], len(self.blocks), []
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)

        for i, blk in enumerate(self.blocks):
            if i < self.rope_start or self.rope is None:
                g_pos, l_pos = None, None
            else:
                g_pos = pos_nodiff
                l_pos = pos
            if self.alt_start != -1 and i == self.alt_start:
                # At the switch-over layer, the cls-token slot is overwritten
                # with camera tokens (user-provided or learned ref/src tokens).
                if kwargs.get("cam_token", None) is not None:
                    logger.info("Using camera conditions provided by the user")
                    cam_token = kwargs.get("cam_token")
                else:
                    ref_token = self.camera_token[:, :1].expand(B, -1, -1)
                    src_token = self.camera_token[:, 1:].expand(B, S - 1, -1)
                    cam_token = torch.cat([ref_token, src_token], dim=1)
                x[:, :, 0] = cam_token

            # From alt_start onward, odd-indexed blocks attend globally across
            # frames; all other blocks attend locally within each frame.
            if self.alt_start != -1 and i >= self.alt_start and i % 2 == 1:
                x = self.process_attention(
                    x, blk, "global", pos=g_pos, attn_mask=kwargs.get("attn_mask", None)
                )
            else:
                x = self.process_attention(x, blk, "local", pos=l_pos)
                local_x = x

            if i in blocks_to_take:
                # cat_token pairs the most recent local output with the current
                # output (2*embed_dim channels).
                out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
                output.append((out_x[:, :, 0], out_x))
            if i in export_feat_layers:
                aux_output.append(x)
        return output, aux_output

    def process_attention(self, x, block, attn_type="global", pos=None, attn_mask=None):
        """Apply one block with frame-local or cross-frame ("global") attention.

        Local attention folds frames into the batch axis; global attention
        folds tokens of all frames into one sequence per batch item.
        """
        b, s, n = x.shape[:3]
        if attn_type == "local":
            x = rearrange(x, "b s n c -> (b s) n c")
            if pos is not None:
                pos = rearrange(pos, "b s n c -> (b s) n c")
        elif attn_type == "global":
            x = rearrange(x, "b s n c -> b (s n) c")
            if pos is not None:
                pos = rearrange(pos, "b s n c -> b (s n) c")
        else:
            raise ValueError(f"Invalid attention type: {attn_type}")

        x = block(x, pos=pos, attn_mask=attn_mask)

        if attn_type == "local":
            x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
        elif attn_type == "global":
            x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s)
        return x

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        export_feat_layers: List[int] = [],
        **kwargs,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return normalized features of selected layers, paired with camera tokens.

        Outputs are stripped of the cls and register tokens. When cat_token
        doubled the channel count, only the second (current-output) half is
        layer-normalized.
        """
        outputs, aux_outputs = self._get_intermediate_layers_not_chunked(
            x, n, export_feat_layers=export_feat_layers, **kwargs
        )
        camera_tokens = [out[0] for out in outputs]
        if outputs[0][1].shape[-1] == self.embed_dim:
            outputs = [self.norm(out[1]) for out in outputs]
        elif outputs[0][1].shape[-1] == (self.embed_dim * 2):
            outputs = [
                torch.cat(
                    [out[1][..., : self.embed_dim], self.norm(out[1][..., self.embed_dim :])],
                    dim=-1,
                )
                for out in outputs
            ]
        else:
            raise ValueError(f"Invalid output shape: {outputs[0][1].shape}")
        aux_outputs = [self.norm(out) for out in aux_outputs]
        # Drop the cls token and register tokens, keeping only patch tokens.
        outputs = [out[..., 1 + self.num_register_tokens :, :] for out in outputs]
        aux_outputs = [out[..., 1 + self.num_register_tokens :, :] for out in aux_outputs]
        return tuple(zip(outputs, camera_tokens)), aux_outputs
380
+
381
+
382
def vit_small(patch_size=16, num_register_tokens=0, depth=12, **kwargs):
    """Instantiate a ViT-Small DinoVisionTransformer (embed dim 384, 6 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=depth,
        num_heads=6,
        mlp_ratio=4,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
394
+
395
+
396
def vit_base(patch_size=16, num_register_tokens=0, depth=12, **kwargs):
    """Instantiate a ViT-Base DinoVisionTransformer (embed dim 768, 12 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=depth,
        num_heads=12,
        mlp_ratio=4,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
408
+
409
+
410
def vit_large(patch_size=16, num_register_tokens=0, depth=24, **kwargs):
    """Instantiate a ViT-Large DinoVisionTransformer (embed dim 1024, 16 heads)."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=depth,
        num_heads=16,
        mlp_ratio=4,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
422
+
423
+
424
def vit_giant2(patch_size=16, num_register_tokens=0, depth=40, **kwargs):
    """Instantiate a ViT-giant2 DinoVisionTransformer.

    Close to ViT-giant: embed dim 1536 with 24 heads => 64 dims per head.
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=depth,
        num_heads=24,
        mlp_ratio=4,
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
src/depth_anything_3/model/dpt.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa E501
2
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Dict as TyDict
17
+ from typing import List, Sequence, Tuple
18
+ import torch
19
+ import torch.nn as nn
20
+ from addict import Dict
21
+ from einops import rearrange
22
+
23
+ from depth_anything_3.model.utils.head_utils import (
24
+ Permute,
25
+ create_uv_grid,
26
+ custom_interpolate,
27
+ position_grid_to_embed,
28
+ )
29
+
30
+
31
+ class DPT(nn.Module):
32
+ """
33
+ DPT for dense prediction (main head + optional sky head, sky always 1 channel).
34
+
35
+ Returns:
36
+ - Main head:
37
+ * If output_dim>1: { head_name, f"{head_name}_conf" }
38
+ * If output_dim==1: { head_name }
39
+ - Sky head (if use_sky_head=True): { sky_name } # [B, S, 1, H/down_ratio, W/down_ratio]
40
+ """
41
+
42
    def __init__(
        self,
        dim_in: int,
        *,
        patch_size: int = 14,
        output_dim: int = 1,
        activation: str = "exp",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = False,
        down_ratio: int = 1,
        head_name: str = "depth",
        # ---- sky head (fixed 1 channel) ----
        use_sky_head: bool = True,
        sky_name: str = "sky",
        sky_activation: str = "relu",  # 'sigmoid' / 'relu' / 'linear'
        use_ln_for_heads: bool = False,  # If needed, apply LayerNorm on intermediate features of both heads
        norm_type: str = "idt",  # use to match legacy GS-DPT head, "idt" / "layer"
        fusion_block_inplace: bool = False,
    ) -> None:
        """Build the DPT head: per-stage projections, fusion chain, and output heads.

        Args:
            dim_in: Channel dimension of the incoming transformer tokens.
            patch_size: Patch size of the backbone (spatial scale of tokens).
            output_dim: Main-head output channels; >1 implies a confidence channel.
            activation: Activation name for the main output (consumed downstream).
            conf_activation: Activation name for the confidence output.
            features: Channel width of the fusion chain.
            out_channels: Per-stage projection widths for the 4 token stages.
            pos_embed: Whether positional embeddings are added to tokens.
            down_ratio: Output downsampling ratio relative to the input image.
            head_name: Key used for the main output in the result dict.
            use_sky_head: Whether to build the auxiliary 1-channel sky head.
            sky_name: Key used for the sky output in the result dict.
            sky_activation: Activation name for the sky output.
            use_ln_for_heads: Insert LayerNorm into both heads' conv stacks.
            norm_type: "layer" for LayerNorm token pre-norm, "idt" for none.
            fusion_block_inplace: Use in-place ops inside fusion blocks.

        Raises:
            Exception: If ``norm_type`` is neither "layer" nor "idt".

        Note: attribute names and module nesting below determine checkpoint
        state-dict keys; do not restructure without remapping weights.
        """
        super().__init__()

        # -------------------- configuration --------------------
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio

        # Names
        self.head_main = head_name
        self.sky_name = sky_name

        # Main head: output dimension and confidence switch
        self.out_dim = output_dim
        self.has_conf = output_dim > 1

        # Sky head parameters (always 1 channel)
        self.use_sky_head = use_sky_head
        self.sky_activation = sky_activation

        # Fixed 4 intermediate outputs
        self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)

        # -------------------- token pre-norm + per-stage projection --------------------
        if norm_type == "layer":
            self.norm = nn.LayerNorm(dim_in)
        elif norm_type == "idt":
            self.norm = nn.Identity()
        else:
            raise Exception(f"Unknown norm_type {norm_type}, should be 'layer' or 'idt'.")
        # 1x1 convs mapping tokens (dim_in) to each stage's channel width.
        self.projects = nn.ModuleList(
            [nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
        )

        # -------------------- Spatial re-size (align to common scale before fusion) --------------------
        # Design consistent with original: relative to patch grid (x4, x2, x1, /2)
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
            ]
        )

        # -------------------- scratch: stage adapters + main fusion chain --------------------
        self.scratch = _make_scratch(list(out_channels), features, expand=False)

        # Main fusion chain
        # refinenet4 consumes the deepest stage only, hence no residual input.
        self.scratch.refinenet1 = _make_fusion_block(features, inplace=fusion_block_inplace)
        self.scratch.refinenet2 = _make_fusion_block(features, inplace=fusion_block_inplace)
        self.scratch.refinenet3 = _make_fusion_block(features, inplace=fusion_block_inplace)
        self.scratch.refinenet4 = _make_fusion_block(
            features, has_residual=False, inplace=fusion_block_inplace
        )

        # Heads (shared neck1; then split into two heads)
        head_features_1 = features
        head_features_2 = 32
        self.scratch.output_conv1 = nn.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
        )

        # Optional LayerNorm (in channels-last layout via Permute) shared by
        # both heads' conv stacks.
        ln_seq = (
            [Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
            if use_ln_for_heads
            else []
        )

        # Main head
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            *ln_seq,
            nn.ReLU(inplace=True),
            nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
        )

        # Sky head (fixed 1 channel)
        if self.use_sky_head:
            self.scratch.sky_output_conv2 = nn.Sequential(
                nn.Conv2d(
                    head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
                ),
                *ln_seq,
                nn.ReLU(inplace=True),
                nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            )
155
+
156
+ # -------------------------------------------------------------------------
157
+ # Public forward (supports frame chunking to save memory)
158
+ # -------------------------------------------------------------------------
159
    def forward(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
        chunk_size: int = 8,
        **kwargs,
    ) -> Dict:
        """
        Run the head over a batch of per-frame features, optionally chunked along S.

        Args:
            feats: List of 4 entries, each entry is a tensor like [B, S, T, C] (or the 0th element of tuple/list is that tensor).
            H, W: Original image dimensions
            patch_start_idx: Starting index of patch tokens in sequence (for cropping non-patch tokens)
            chunk_size: Chunk size along time dimension S
            **kwargs: Only "images" is consumed here; it is flattened to
                (B*S, ...) and forwarded to `_forward_impl`.

        Returns:
            Dict[str, Tensor] (addict.Dict); every value is reshaped to [B, S, ...].
        """
        # Element 0 of each entry holds the token tensor; flatten batch and time.
        B, S, N, C = feats[0][0].shape
        feats = [feat[0].reshape(B * S, N, C) for feat in feats]

        # update image info, used by the GS-DPT head
        extra_kwargs = {}
        if "images" in kwargs:
            extra_kwargs.update({"images": rearrange(kwargs["images"], "B S ... -> (B S) ...")})

        # Fast path: whole sequence in a single pass.
        if chunk_size is None or chunk_size >= S:
            out_dict = self._forward_impl(feats, H, W, patch_start_idx, **extra_kwargs)
            out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
            return Dict(out_dict)

        # Chunked path: slice the flattened (B*S) axis with S-derived indices.
        # NOTE(review): [s0:s1] only lines up with frame boundaries when B == 1;
        # confirm callers never combine B > 1 with chunking.
        out_dicts: List[TyDict[str, torch.Tensor]] = []
        for s0 in range(0, S, chunk_size):
            s1 = min(s0 + chunk_size, S)
            kw = {}
            if "images" in extra_kwargs:
                kw.update({"images": extra_kwargs["images"][s0:s1]})
            out_dicts.append(
                self._forward_impl([f[s0:s1] for f in feats], H, W, patch_start_idx, **kw)
            )
        # Re-concatenate chunks along the flattened frame axis, then unflatten.
        out_dict = {k: torch.cat([od[k] for od in out_dicts], dim=0) for k in out_dicts[0].keys()}
        out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
        return Dict(out_dict)
203
+
204
+ # -------------------------------------------------------------------------
205
+ # Internal forward (single chunk)
206
+ # -------------------------------------------------------------------------
207
    def _forward_impl(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
    ) -> TyDict[str, torch.Tensor]:
        """
        Single-chunk forward over flattened (B*S) frame features.

        Args:
            feats: 4 token tensors, each [B*S, N, C].
            H, W: Original image size in pixels.
            patch_start_idx: First patch-token index; earlier (non-patch) tokens are dropped.

        Returns:
            Plain dict keyed by `self.head_main` (plus optional "<main>_conf"
            and sky entries), each value shaped per-frame.
        """
        B, _, C = feats[0].shape
        # Patch-grid size implied by the image size and ViT patch size.
        ph, pw = H // self.patch_size, W // self.patch_size
        resized_feats = []
        # 1) Per-stage: crop patch tokens, normalize, project, re-size to pyramid scale.
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats[take_idx][:, patch_start_idx:]  # [B*S, N_patch, C]
            x = self.norm(x)
            x = x.permute(0, 2, 1).reshape(B, C, ph, pw)  # [B*S, C, ph, pw]

            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = self._add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)  # Align scale
            resized_feats.append(x)

        # 2) Fusion pyramid (main branch only)
        fused = self._fuse(resized_feats)

        # 3) Upsample to target resolution, optionally add position encoding again
        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)

        fused = self.scratch.output_conv1(fused)
        fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
        if self.pos_embed:
            fused = self._add_pos_embed(fused, W, H)

        # 4) Shared neck1
        feat = fused

        # 5) Main head: logits -> activation
        main_logits = self.scratch.output_conv2(feat)
        outs: TyDict[str, torch.Tensor] = {}
        if self.has_conf:
            # Last channel is confidence; the rest is the value prediction.
            fmap = main_logits.permute(0, 2, 3, 1)
            pred = self._apply_activation_single(fmap[..., :-1], self.activation)
            conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)
            # NOTE(review): `pred` is channels-last here, so squeeze(1) squeezes a
            # spatial dim (usually a no-op) rather than the trailing channel —
            # confirm whether squeeze(-1) was intended.
            outs[self.head_main] = pred.squeeze(1)
            outs[f"{self.head_main}_conf"] = conf.squeeze(1)
        else:
            # Channels-first here, so squeeze(1) drops a single-channel axis.
            outs[self.head_main] = self._apply_activation_single(
                main_logits, self.activation
            ).squeeze(1)

        # 6) Sky head (fixed 1 channel)
        # NOTE(review): `self.sky_name` is assigned outside this view — verify it
        # is set whenever `use_sky_head` is True.
        if self.use_sky_head:
            sky_logits = self.scratch.sky_output_conv2(feat)
            outs[self.sky_name] = self._apply_sky_activation(sky_logits).squeeze(1)

        return outs
263
+
264
+ # -------------------------------------------------------------------------
265
+ # Subroutines
266
+ # -------------------------------------------------------------------------
267
+ def _fuse(self, feats: List[torch.Tensor]) -> torch.Tensor:
268
+ """
269
+ 4-layer top-down fusion, returns finest scale features (after fusion, before neck1).
270
+ """
271
+ l1, l2, l3, l4 = feats
272
+
273
+ l1_rn = self.scratch.layer1_rn(l1)
274
+ l2_rn = self.scratch.layer2_rn(l2)
275
+ l3_rn = self.scratch.layer3_rn(l3)
276
+ l4_rn = self.scratch.layer4_rn(l4)
277
+
278
+ # 4 -> 3 -> 2 -> 1
279
+ out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
280
+ out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
281
+ out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
282
+ out = self.scratch.refinenet1(out, l1_rn)
283
+ return out
284
+
285
+ def _apply_activation_single(
286
+ self, x: torch.Tensor, activation: str = "linear"
287
+ ) -> torch.Tensor:
288
+ """
289
+ Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case.
290
+ Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1
291
+ """
292
+ act = activation.lower() if isinstance(activation, str) else activation
293
+ if act == "exp":
294
+ return torch.exp(x)
295
+ if act == "expp1":
296
+ return torch.exp(x) + 1
297
+ if act == "expm1":
298
+ return torch.expm1(x)
299
+ if act == "relu":
300
+ return torch.relu(x)
301
+ if act == "sigmoid":
302
+ return torch.sigmoid(x)
303
+ if act == "softplus":
304
+ return torch.nn.functional.softplus(x)
305
+ if act == "tanh":
306
+ return torch.tanh(x)
307
+ # Default linear
308
+ return x
309
+
310
+ def _apply_sky_activation(self, x: torch.Tensor) -> torch.Tensor:
311
+ """
312
+ Sky head activation (fixed 1 channel):
313
+ * 'sigmoid' -> Sigmoid probability map
314
+ * 'relu' -> ReLU positive domain output
315
+ * 'linear' -> Original value (logits)
316
+ """
317
+ act = (
318
+ self.sky_activation.lower()
319
+ if isinstance(self.sky_activation, str)
320
+ else self.sky_activation
321
+ )
322
+ if act == "sigmoid":
323
+ return torch.sigmoid(x)
324
+ if act == "relu":
325
+ return torch.relu(x)
326
+ # 'linear'
327
+ return x
328
+
329
+ def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
330
+ """Simple UV position encoding directly added to feature map."""
331
+ pw, ph = x.shape[-1], x.shape[-2]
332
+ pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
333
+ pe = position_grid_to_embed(pe, x.shape[1]) * ratio
334
+ pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
335
+ return x + pe
336
+
337
+
338
+ # -----------------------------------------------------------------------------
339
+ # Building blocks (preserved, consistent with original)
340
+ # -----------------------------------------------------------------------------
341
def _make_fusion_block(
    features: int,
    size: Tuple[int, int] = None,
    has_residual: bool = True,
    groups: int = 1,
    inplace: bool = False,
) -> nn.Module:
    """Build a FeatureFusionBlock with the fixed settings this head uses."""
    fixed = dict(
        deconv=False,
        bn=False,
        expand=False,
        align_corners=True,
        size=size,
        has_residual=has_residual,
        groups=groups,
    )
    return FeatureFusionBlock(
        features=features,
        activation=nn.ReLU(inplace=inplace),
        **fixed,
    )
359
+
360
+
361
+ def _make_scratch(
362
+ in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
363
+ ) -> nn.Module:
364
+ scratch = nn.Module()
365
+ # Optional expansion by stage
366
+ c1 = out_shape
367
+ c2 = out_shape * (2 if expand else 1)
368
+ c3 = out_shape * (4 if expand else 1)
369
+ c4 = out_shape * (8 if expand else 1)
370
+
371
+ scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups)
372
+ scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups)
373
+ scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups)
374
+ scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups)
375
+ return scratch
376
+
377
+
378
class ResidualConvUnit(nn.Module):
    """Lightweight pre-activation residual conv block used inside fusion."""

    def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None:
        super().__init__()
        self.bn = bn
        self.groups = groups
        self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
        self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
        # Norm slots stay unset here; they may be attached externally.
        self.norm1 = None
        self.norm2 = None
        self.activation = activation
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """Compute x + conv2(act(conv1(act(x)))), with optional norms in between."""
        branch = self.conv1(self.activation(x))
        if self.norm1 is not None:
            branch = self.norm1(branch)

        branch = self.conv2(self.activation(branch))
        if self.norm2 is not None:
            branch = self.norm2(branch)

        return self.skip_add.add(branch, x)
404
+
405
+
406
class FeatureFusionBlock(nn.Module):
    """Top-down fusion block: (optional) residual merge + upsampling + 1x1 contraction."""

    def __init__(
        self,
        features: int,
        activation: nn.Module,
        deconv: bool = False,
        bn: bool = False,
        expand: bool = False,
        align_corners: bool = True,
        size: Tuple[int, int] = None,
        has_residual: bool = True,
        groups: int = 1,
    ) -> None:
        super().__init__()
        self.align_corners = align_corners
        self.size = size
        self.has_residual = has_residual

        # Lateral residual unit exists only when this block merges a skip input.
        if has_residual:
            self.resConfUnit1 = ResidualConvUnit(features, activation, bn, groups=groups)
        else:
            self.resConfUnit1 = None
        self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)

        out_features = features // 2 if expand else features
        self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups)
        self.skip_add = nn.quantized.FloatFunctional()

    def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor:  # type: ignore[override]
        """
        Args:
            xs[0]: top-down input.
            xs[1]: optional lateral input, residual-added before refinement.
            size: explicit upsampling target; overrides the configured size.
        """
        merged = xs[0]
        if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
            merged = self.skip_add.add(merged, self.resConfUnit1(xs[1]))

        merged = self.resConfUnit2(merged)

        # Resolve target: call-site size wins, then configured size, else 2x.
        target = size if size is not None else self.size
        up_kwargs = {"scale_factor": 2} if target is None else {"size": target}

        merged = custom_interpolate(
            merged, **up_kwargs, mode="bilinear", align_corners=self.align_corners
        )
        return self.out_conv(merged)
src/depth_anything_3/model/dualdpt.py ADDED
@@ -0,0 +1,488 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # flake8: noqa E501
2
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Sequence, Tuple
17
+ import torch
18
+ import torch.nn as nn
19
+ from addict import Dict
20
+
21
+ from depth_anything_3.model.dpt import _make_fusion_block, _make_scratch
22
+ from depth_anything_3.model.utils.head_utils import (
23
+ Permute,
24
+ create_uv_grid,
25
+ custom_interpolate,
26
+ position_grid_to_embed,
27
+ )
28
+
29
+
30
class DualDPT(nn.Module):
    """
    Dual-head DPT for dense prediction with an always-on auxiliary head.

    Architectural notes:
    - Sky/object branches are removed.
    - `intermediate_layer_idx` is fixed to (0, 1, 2, 3).
    - Auxiliary head has its **own** fusion blocks (no fusion_inplace / no sharing).
    - Auxiliary head is internally multi-level; **only the final level** is returned.
    - Returns a **dict** with keys from `head_names`, e.g.:
      { main_name, f"{main_name}_conf", aux_name, f"{aux_name}_conf" }
    - `feature_only` is fixed to False.
    """

    def __init__(
        self,
        dim_in: int,
        *,
        patch_size: int = 14,
        output_dim: int = 2,
        activation: str = "exp",
        conf_activation: str = "expp1",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        down_ratio: int = 1,
        aux_pyramid_levels: int = 4,
        aux_out1_conv_num: int = 5,
        head_names: Tuple[str, str] = ("depth", "ray"),
    ) -> None:
        super().__init__()

        # -------------------- configuration --------------------
        self.patch_size = patch_size
        self.activation = activation
        self.conf_activation = conf_activation
        self.pos_embed = pos_embed
        self.down_ratio = down_ratio

        self.aux_levels = aux_pyramid_levels
        self.aux_out1_conv_num = aux_out1_conv_num

        # names ONLY come from config (no hard-coded strings elsewhere)
        self.head_main, self.head_aux = head_names

        # Always expect 4 scales; enforce intermediate idx = (0, 1, 2, 3)
        self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)

        # -------------------- token pre-norm + per-stage projection --------------------
        self.norm = nn.LayerNorm(dim_in)
        self.projects = nn.ModuleList(
            [nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
        )

        # -------------------- spatial re-sizers (align to common scale before fusion) --------------------
        # design: stage strides (x4, x2, x1, /2) relative to patch grid to align to a common pivot scale
        self.resize_layers = nn.ModuleList(
            [
                nn.ConvTranspose2d(
                    out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
                ),
                nn.ConvTranspose2d(
                    out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
                ),
                nn.Identity(),
                nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
            ]
        )

        # -------------------- scratch: stage adapters + fusion (main & aux are separate) --------------------
        self.scratch = _make_scratch(list(out_channels), features, expand=False)

        # Main fusion chain (independent)
        self.scratch.refinenet1 = _make_fusion_block(features)
        self.scratch.refinenet2 = _make_fusion_block(features)
        self.scratch.refinenet3 = _make_fusion_block(features)
        self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)

        # Primary head neck + head (independent)
        head_features_1 = features
        head_features_2 = 32
        self.scratch.output_conv1 = nn.Conv2d(
            head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
        )
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
        )

        # Auxiliary fusion chain (completely separate; no sharing, i.e., "fusion_inplace=False")
        self.scratch.refinenet1_aux = _make_fusion_block(features)
        self.scratch.refinenet2_aux = _make_fusion_block(features)
        self.scratch.refinenet3_aux = _make_fusion_block(features)
        self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False)

        # Aux pre-head per level (we will only *return final level*)
        self.scratch.output_conv1_aux = nn.ModuleList(
            [self._make_aux_out1_block(head_features_1) for _ in range(self.aux_levels)]
        )

        # Aux final projection per level
        use_ln = True
        ln_seq = (
            [Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
            if use_ln
            else []
        )
        self.scratch.output_conv2_aux = nn.ModuleList(
            [
                nn.Sequential(
                    nn.Conv2d(
                        head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
                    ),
                    *ln_seq,
                    nn.ReLU(inplace=True),
                    nn.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0),
                )
                for _ in range(self.aux_levels)
            ]
        )

    # -------------------------------------------------------------------------
    # Public forward (supports frame chunking for memory)
    # -------------------------------------------------------------------------

    def forward(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
        chunk_size: int = 8,
    ) -> Dict[str, torch.Tensor]:
        """
        Args:
            feats: List of 4 entries from the transformer; each entry is
                indexable and element 0 is a token tensor [B, S, N, C].
            H, W: Original image size in pixels.
            patch_start_idx: Patch-token start in the token sequence (to drop non-patch tokens).
            chunk_size: Optional chunking along S for memory; None disables chunking.

        Returns:
            addict.Dict with keys based on `head_names`:
                self.head_main, f"{self.head_main}_conf",
                self.head_aux, f"{self.head_aux}_conf"
            Shapes (with the default output_dim == 2):
                main:    [B, S, H/down_ratio, W/down_ratio]
                main_cf: [B, S, H/down_ratio, W/down_ratio]
                aux:     [B, S, h_aux, w_aux, 6]
                aux_cf:  [B, S, h_aux, w_aux]
            NOTE(review): the aux branch is never resized to (h_out, w_out) like
            the main branch (see `_forward_impl`), so (h_aux, w_aux) follows the
            finest pyramid level — confirm consumers expect that.
            NOTE(review): chunking slices the flattened (B*S) axis with S-derived
            indices, which only matches frame boundaries when B == 1.
        """
        B, S, N, C = feats[0][0].shape
        feats = [feat[0].reshape(B * S, N, C) for feat in feats]
        if chunk_size is None or chunk_size >= S:
            out_dict = self._forward_impl(feats, H, W, patch_start_idx)
            out_dict = {k: v.reshape(B, S, *v.shape[1:]) for k, v in out_dict.items()}
            return Dict(out_dict)
        out_dicts = []
        for s0 in range(0, S, chunk_size):
            s1 = min(s0 + chunk_size, S)
            out_dict = self._forward_impl(
                [feat[s0:s1] for feat in feats],
                H,
                W,
                patch_start_idx,
            )
            out_dicts.append(out_dict)
        out_dict = {
            k: torch.cat([out_dict[k] for out_dict in out_dicts], dim=0)
            for k in out_dicts[0].keys()
        }
        out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
        return Dict(out_dict)

    # -------------------------------------------------------------------------
    # Internal forward (single chunk)
    # -------------------------------------------------------------------------

    def _forward_impl(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
    ) -> Dict[str, torch.Tensor]:
        """Single-chunk forward over flattened (B*S) features; returns a plain dict."""
        B, _, C = feats[0].shape
        ph, pw = H // self.patch_size, W // self.patch_size
        resized_feats = []
        # 1) Per-stage: crop patch tokens, normalize, project, re-size to pyramid scale.
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            x = feats[take_idx][:, patch_start_idx:]
            x = self.norm(x)
            x = x.permute(0, 2, 1).reshape(B, C, ph, pw)  # [B*S, C, ph, pw]

            x = self.projects[stage_idx](x)
            if self.pos_embed:
                x = self._add_pos_embed(x, W, H)
            x = self.resize_layers[stage_idx](x)  # align scales
            resized_feats.append(x)

        # 2) Fuse pyramid (main & aux are completely independent)
        fused_main, fused_aux_pyr = self._fuse(resized_feats)

        # 3) Upsample to target resolution and (optional) add pos-embed again
        h_out = int(ph * self.patch_size / self.down_ratio)
        w_out = int(pw * self.patch_size / self.down_ratio)

        fused_main = custom_interpolate(
            fused_main, (h_out, w_out), mode="bilinear", align_corners=True
        )
        if self.pos_embed:
            fused_main = self._add_pos_embed(fused_main, W, H)

        # Primary head: conv1 -> conv2 -> activate
        # fused_main = self.scratch.output_conv1(fused_main)
        main_logits = self.scratch.output_conv2(fused_main)
        fmap = main_logits.permute(0, 2, 3, 1)
        main_pred = self._apply_activation_single(fmap[..., :-1], self.activation)
        main_conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)

        # Auxiliary head (multi-level inside) -> only last level returned (after activation)
        # NOTE(review): unlike the main branch, `last_aux` is not interpolated to
        # (h_out, w_out) before projection.
        last_aux = fused_aux_pyr[-1]
        if self.pos_embed:
            last_aux = self._add_pos_embed(last_aux, W, H)
        # neck (per-level pre-conv) then final projection (only for last level)
        # last_aux = self.scratch.output_conv1_aux[-1](last_aux)
        last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
        fmap_last = last_aux_logits.permute(0, 2, 3, 1)
        aux_pred = self._apply_activation_single(fmap_last[..., :-1], "linear")
        aux_conf = self._apply_activation_single(fmap_last[..., -1], self.conf_activation)
        return {
            self.head_main: main_pred.squeeze(-1),
            f"{self.head_main}_conf": main_conf,
            self.head_aux: aux_pred,
            f"{self.head_aux}_conf": aux_conf,
        }

    # -------------------------------------------------------------------------
    # Subroutines
    # -------------------------------------------------------------------------

    def _fuse(self, feats: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
        """
        Feature pyramid fusion.
        Returns:
            fused_main: Tensor at finest scale (after refinenet1 and output_conv1)
            aux_pyr: List of aux tensors per kept level (each after its output_conv1_aux)
        """
        l1, l2, l3, l4 = feats

        l1_rn = self.scratch.layer1_rn(l1)
        l2_rn = self.scratch.layer2_rn(l2)
        l3_rn = self.scratch.layer3_rn(l3)
        l4_rn = self.scratch.layer4_rn(l4)

        # level 4 -> 3
        out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
        aux_out = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
        aux_list: List[torch.Tensor] = []
        if self.aux_levels >= 4:
            aux_list.append(aux_out)

        # level 3 -> 2
        out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
        aux_out = self.scratch.refinenet3_aux(aux_out, l3_rn, size=l2_rn.shape[2:])
        if self.aux_levels >= 3:
            aux_list.append(aux_out)

        # level 2 -> 1
        out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
        aux_out = self.scratch.refinenet2_aux(aux_out, l2_rn, size=l1_rn.shape[2:])
        if self.aux_levels >= 2:
            aux_list.append(aux_out)

        # level 1 (final)
        out = self.scratch.refinenet1(out, l1_rn)
        aux_out = self.scratch.refinenet1_aux(aux_out, l1_rn)
        aux_list.append(aux_out)

        out = self.scratch.output_conv1(out)
        aux_list = [self.scratch.output_conv1_aux[i](aux) for i, aux in enumerate(aux_list)]

        return out, aux_list

    def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
        """Simple UV positional embedding added to feature maps."""
        pw, ph = x.shape[-1], x.shape[-2]
        pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
        pe = position_grid_to_embed(pe, x.shape[1]) * ratio
        pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
        return x + pe

    def _make_aux_out1_block(self, in_ch: int) -> nn.Sequential:
        """Factory for the aux pre-head stack before the final 1x1 projection."""
        if self.aux_out1_conv_num == 5:
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
            )
        if self.aux_out1_conv_num == 3:
            return nn.Sequential(
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
                nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
                nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
            )
        if self.aux_out1_conv_num == 1:
            return nn.Sequential(nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1))
        raise ValueError(f"aux_out1_conv_num {self.aux_out1_conv_num} not supported")

    def _apply_activation_single(
        self, x: torch.Tensor, activation: str = "linear"
    ) -> torch.Tensor:
        """
        Apply activation to single channel output, maintaining semantic consistency with multi-channel value branch.
        Supports (case-insensitive): exp / expm1 / expp1 / relu / sigmoid / softplus / tanh / linear.
        Unrecognized names fall through to identity.
        """
        act = activation.lower() if isinstance(activation, str) else activation
        if act == "exp":
            return torch.exp(x)
        if act == "expm1":
            return torch.expm1(x)
        if act == "expp1":
            return torch.exp(x) + 1
        if act == "relu":
            return torch.relu(x)
        if act == "sigmoid":
            return torch.sigmoid(x)
        if act == "softplus":
            return torch.nn.functional.softplus(x)
        if act == "tanh":
            return torch.tanh(x)
        # Default linear
        return x
365
+
366
+
367
+ # # -----------------------------------------------------------------------------
368
+ # # Building blocks (tidy)
369
+ # # -----------------------------------------------------------------------------
370
+
371
+
372
+ # def _make_fusion_block(
373
+ # features: int,
374
+ # size: Tuple[int, int] = None,
375
+ # has_residual: bool = True,
376
+ # groups: int = 1,
377
+ # inplace: bool = False, # <- activation uses inplace=True by default; not related to "fusion_inplace"
378
+ # ) -> nn.Module:
379
+ # return FeatureFusionBlock(
380
+ # features=features,
381
+ # activation=nn.ReLU(inplace=inplace),
382
+ # deconv=False,
383
+ # bn=False,
384
+ # expand=False,
385
+ # align_corners=True,
386
+ # size=size,
387
+ # has_residual=has_residual,
388
+ # groups=groups,
389
+ # )
390
+
391
+
392
+ # def _make_scratch(
393
+ # in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
394
+ # ) -> nn.Module:
395
+ # scratch = nn.Module()
396
+ # # optionally expand widths by stage
397
+ # c1 = out_shape
398
+ # c2 = out_shape * (2 if expand else 1)
399
+ # c3 = out_shape * (4 if expand else 1)
400
+ # c4 = out_shape * (8 if expand else 1)
401
+
402
+ # scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups)
403
+ # scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups)
404
+ # scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups)
405
+ # scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups)
406
+ # return scratch
407
+
408
+
409
+ # class ResidualConvUnit(nn.Module):
410
+ # """Lightweight residual conv block used within fusion."""
411
+
412
+ # def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None:
413
+ # super().__init__()
414
+ # self.bn = bn
415
+ # self.groups = groups
416
+ # self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
417
+ # self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
418
+ # self.norm1 = None
419
+ # self.norm2 = None
420
+ # self.activation = activation
421
+ # self.skip_add = nn.quantized.FloatFunctional()
422
+
423
+ # def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
424
+ # out = self.activation(x)
425
+ # out = self.conv1(out)
426
+ # if self.norm1 is not None:
427
+ # out = self.norm1(out)
428
+
429
+ # out = self.activation(out)
430
+ # out = self.conv2(out)
431
+ # if self.norm2 is not None:
432
+ # out = self.norm2(out)
433
+
434
+ # return self.skip_add.add(out, x)
435
+
436
+
437
+ # class FeatureFusionBlock(nn.Module):
438
+ # """Top-down fusion block: (optional) residual merge + upsample + 1x1 shrink."""
439
+
440
+ # def __init__(
441
+ # self,
442
+ # features: int,
443
+ # activation: nn.Module,
444
+ # deconv: bool = False,
445
+ # bn: bool = False,
446
+ # expand: bool = False,
447
+ # align_corners: bool = True,
448
+ # size: Tuple[int, int] = None,
449
+ # has_residual: bool = True,
450
+ # groups: int = 1,
451
+ # ) -> None:
452
+ # super().__init__()
453
+ # self.align_corners = align_corners
454
+ # self.size = size
455
+ # self.has_residual = has_residual
456
+
457
+ # self.resConfUnit1 = (
458
+ # ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None
459
+ # )
460
+ # self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)
461
+
462
+ # out_features = (features // 2) if expand else features
463
+ # self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups)
464
+ # self.skip_add = nn.quantized.FloatFunctional()
465
+
466
+ # def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override]
467
+ # """
468
+ # xs:
469
+ # - xs[0]: top input
470
+ # - xs[1]: (optional) lateral (to be added with residual)
471
+ # """
472
+ # y = xs[0]
473
+ # if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
474
+ # y = self.skip_add.add(y, self.resConfUnit1(xs[1]))
475
+
476
+ # y = self.resConfUnit2(y)
477
+
478
+ # # upsample
479
+ # if (size is None) and (self.size is None):
480
+ # up_kwargs = {"scale_factor": 2}
481
+ # elif size is None:
482
+ # up_kwargs = {"size": self.size}
483
+ # else:
484
+ # up_kwargs = {"size": size}
485
+
486
+ # y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
487
+ # y = self.out_conv(y)
488
+ # return y
src/depth_anything_3/model/gs_adapter.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional
16
+ import torch
17
+ from einops import einsum, rearrange, repeat
18
+ from torch import nn
19
+
20
+ from depth_anything_3.model.utils.transform import cam_quat_xyzw_to_world_quat_wxyz
21
+ from depth_anything_3.specs import Gaussians
22
+ from depth_anything_3.utils.geometry import affine_inverse, get_world_rays, sample_image_grid
23
+ from depth_anything_3.utils.pose_align import batch_align_poses_umeyama
24
+ from depth_anything_3.utils.sh_helpers import rotate_sh
25
+
26
+
27
+ class GaussianAdapter(nn.Module):
28
+
29
+ def __init__(
30
+ self,
31
+ sh_degree: int = 0,
32
+ pred_color: bool = False,
33
+ pred_offset_depth: bool = False,
34
+ pred_offset_xy: bool = True,
35
+ gaussian_scale_min: float = 1e-5,
36
+ gaussian_scale_max: float = 30.0,
37
+ ):
38
+ super().__init__()
39
+ self.sh_degree = sh_degree
40
+ self.pred_color = pred_color
41
+ self.pred_offset_depth = pred_offset_depth
42
+ self.pred_offset_xy = pred_offset_xy
43
+ self.gaussian_scale_min = gaussian_scale_min
44
+ self.gaussian_scale_max = gaussian_scale_max
45
+
46
+ # Create a mask for the spherical harmonics coefficients. This ensures that at
47
+ # initialization, the coefficients are biased towards having a large DC
48
+ # component and small view-dependent components.
49
+ if not pred_color:
50
+ self.register_buffer(
51
+ "sh_mask",
52
+ torch.ones((self.d_sh,), dtype=torch.float32),
53
+ persistent=False,
54
+ )
55
+ for degree in range(1, sh_degree + 1):
56
+ self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree
57
+
58
+ def forward(
59
+ self,
60
+ extrinsics: torch.Tensor, # "*#batch 4 4"
61
+ intrinsics: torch.Tensor, # "*#batch 3 3"
62
+ depths: torch.Tensor, # "*#batch"
63
+ opacities: torch.Tensor, # "*#batch" | "*#batch _"
64
+ raw_gaussians: torch.Tensor, # "*#batch _"
65
+ image_shape: tuple[int, int],
66
+ eps: float = 1e-8,
67
+ gt_extrinsics: Optional[torch.Tensor] = None, # "*#batch 4 4"
68
+ **kwargs,
69
+ ) -> Gaussians:
70
+ device = extrinsics.device
71
+ dtype = raw_gaussians.dtype
72
+ H, W = image_shape
73
+ b, v = raw_gaussians.shape[:2]
74
+
75
+ # get cam2worlds and intr_normed to adapt to 3DGS codebase
76
+ cam2worlds = affine_inverse(extrinsics)
77
+ intr_normed = intrinsics.clone().detach()
78
+ intr_normed[..., 0, :] /= W
79
+ intr_normed[..., 1, :] /= H
80
+
81
+ # 1. compute 3DGS means
82
+ # 1.1) offset the predicted depth if needed
83
+ if self.pred_offset_depth:
84
+ gs_depths = depths + raw_gaussians[..., -1]
85
+ raw_gaussians = raw_gaussians[..., :-1]
86
+ else:
87
+ gs_depths = depths
88
+ # 1.2) align predicted poses with GT if needed
89
+ if gt_extrinsics is not None and extrinsics != gt_extrinsics:
90
+ try:
91
+ _, _, pose_scales = batch_align_poses_umeyama(
92
+ gt_extrinsics.detach().float(),
93
+ extrinsics.detach().float(),
94
+ )
95
+ except Exception:
96
+ pose_scales = torch.ones_like(extrinsics[:, 0, 0, 0])
97
+ pose_scales = torch.clamp(pose_scales, min=1 / 3.0, max=3.0)
98
+ cam2worlds[:, :, :3, 3] = cam2worlds[:, :, :3, 3] * rearrange(
99
+ pose_scales, "b -> b () ()"
100
+ )
101
+ gs_depths = gs_depths * rearrange(pose_scales, "b -> b () () () ()")
102
+ # 1.3) casting xy in image space
103
+ xy_ray, _ = sample_image_grid((H, W), device)
104
+ xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1) # b v h w xy
105
+ # offset xy if needed
106
+ if self.pred_offset_xy:
107
+ pixel_size = 1 / torch.tensor((W, H), dtype=xy_ray.dtype, device=device)
108
+ offset_xy = raw_gaussians[..., :2]
109
+ xy_ray = xy_ray + offset_xy * pixel_size
110
+ raw_gaussians = raw_gaussians[..., 2:] # skip the offset_xy
111
+ # 1.4) unproject depth + xy to world ray
112
+ origins, directions = get_world_rays(
113
+ xy_ray,
114
+ repeat(cam2worlds, "b v i j -> b v h w i j", h=H, w=W),
115
+ repeat(intr_normed, "b v i j -> b v h w i j", h=H, w=W),
116
+ )
117
+ gs_means_world = origins + directions * gs_depths[..., None]
118
+ gs_means_world = rearrange(gs_means_world, "b v h w d -> b (v h w) d")
119
+
120
+ # 2. compute other GS attributes
121
+ scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
122
+
123
+ # 2.1) 3DGS scales
124
+ # make the scale invarient to resolution
125
+ scale_min = self.gaussian_scale_min
126
+ scale_max = self.gaussian_scale_max
127
+ scales = scale_min + (scale_max - scale_min) * scales.sigmoid()
128
+ pixel_size = 1 / torch.tensor((W, H), dtype=dtype, device=device)
129
+ multiplier = self.get_scale_multiplier(intr_normed, pixel_size)
130
+ gs_scales = scales * gs_depths[..., None] * multiplier[..., None, None, None]
131
+ gs_scales = rearrange(gs_scales, "b v h w d -> b (v h w) d")
132
+
133
+ # 2.2) 3DGS quaternion (world space)
134
+ # due to historical issue, assume quaternion in order xyzw, not wxyz
135
+ # Normalize the quaternion features to yield a valid quaternion.
136
+ rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
137
+ # rotate them to world space
138
+ cam_quat_xyzw = rearrange(rotations, "b v h w c -> b (v h w) c")
139
+ c2w_mat = repeat(
140
+ cam2worlds,
141
+ "b v i j -> b (v h w) i j",
142
+ h=H,
143
+ w=W,
144
+ )
145
+ world_quat_wxyz = cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w_mat)
146
+ gs_rotations_world = world_quat_wxyz # b (v h w) c
147
+
148
+ # 2.3) 3DGS color / SH coefficient (world space)
149
+ sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
150
+ if not self.pred_color:
151
+ sh = sh * self.sh_mask
152
+
153
+ if self.pred_color or self.sh_degree == 0:
154
+ # predict pre-computed color or predict only DC band, no need to transform
155
+ gs_sh_world = sh
156
+ else:
157
+ gs_sh_world = rotate_sh(sh, cam2worlds[:, :, None, None, None, :3, :3])
158
+ gs_sh_world = rearrange(gs_sh_world, "b v h w xyz d_sh -> b (v h w) xyz d_sh")
159
+
160
+ # 2.4) 3DGS opacity
161
+ gs_opacities = rearrange(opacities, "b v h w ... -> b (v h w) ...")
162
+
163
+ return Gaussians(
164
+ means=gs_means_world,
165
+ harmonics=gs_sh_world,
166
+ opacities=gs_opacities,
167
+ scales=gs_scales,
168
+ rotations=gs_rotations_world,
169
+ )
170
+
171
+ def get_scale_multiplier(
172
+ self,
173
+ intrinsics: torch.Tensor, # "*#batch 3 3"
174
+ pixel_size: torch.Tensor, # "*#batch 2"
175
+ multiplier: float = 0.1,
176
+ ) -> torch.Tensor: # " *batch"
177
+ xy_multipliers = multiplier * einsum(
178
+ intrinsics[..., :2, :2].float().inverse().to(intrinsics),
179
+ pixel_size,
180
+ "... i j, j -> ... i",
181
+ )
182
+ return xy_multipliers.sum(dim=-1)
183
+
184
+ @property
185
+ def d_sh(self) -> int:
186
+ return 1 if self.pred_color else (self.sh_degree + 1) ** 2
187
+
188
+ @property
189
+ def d_in(self) -> int:
190
+ # provided as reference to the gs_dpt output dim
191
+ raw_gs_dim = 0
192
+ if self.pred_offset_xy:
193
+ raw_gs_dim += 2
194
+ raw_gs_dim += 3 # scales
195
+ raw_gs_dim += 4 # quaternion
196
+ raw_gs_dim += 3 * self.d_sh # color
197
+ if self.pred_offset_depth:
198
+ raw_gs_dim += 1
199
+
200
+ return raw_gs_dim
src/depth_anything_3/model/gsdpt.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Dict as TyDict
16
+ from typing import List, Sequence
17
+ import torch
18
+ import torch.nn as nn
19
+
20
+ from depth_anything_3.model.dpt import DPT
21
+ from depth_anything_3.model.utils.head_utils import activate_head_gs, custom_interpolate
22
+
23
+
24
class GSDPT(DPT):
    """DPT variant predicting raw 3D Gaussian parameters (head name "raw_gs").

    Identical to the base DPT pipeline except that the raw RGB image is lifted
    by a small conv stack and re-injected just before the prediction head.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        output_dim: int = 4,
        activation: str = "linear",
        conf_activation: str = "sigmoid",
        features: int = 256,
        out_channels: Sequence[int] = (256, 512, 1024, 1024),
        pos_embed: bool = True,
        feature_only: bool = False,
        down_ratio: int = 1,
        conf_dim: int = 1,
        norm_type: str = "idt",  # match legacy GS-DPT head: "idt" / "layer"
        fusion_block_inplace: bool = False,
    ) -> None:
        super().__init__(
            dim_in=dim_in,
            patch_size=patch_size,
            output_dim=output_dim,
            activation=activation,
            conf_activation=conf_activation,
            features=features,
            out_channels=out_channels,
            pos_embed=pos_embed,
            down_ratio=down_ratio,
            head_name="raw_gs",
            use_sky_head=False,
            norm_type=norm_type,
            fusion_block_inplace=fusion_block_inplace,
        )
        self.conf_dim = conf_dim
        if conf_dim and conf_dim > 1:
            assert (
                conf_activation == "linear"
            ), "use linear prediction when using view-dependent opacity"

        hidden = features if feature_only else features // 2
        # Shallow conv stack lifting RGB to `hidden` channels, widening gradually.
        self.images_merger = nn.Sequential(
            nn.Conv2d(3, hidden // 4, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(hidden // 4, hidden // 2, 3, 1, 1),
            nn.GELU(),
            nn.Conv2d(hidden // 2, hidden, 3, 1, 1),
            nn.GELU(),
        )

    # -------------------------------------------------------------------------
    # Internal forward (single chunk)
    # -------------------------------------------------------------------------
    def _forward_impl(
        self,
        feats: List[torch.Tensor],
        H: int,
        W: int,
        patch_start_idx: int,
        images: torch.Tensor,
    ) -> TyDict[str, torch.Tensor]:
        batch, _, dim = feats[0].shape
        ph, pw = H // self.patch_size, W // self.patch_size

        # 1) Build the multi-scale pyramid from the selected transformer layers.
        pyramid = []
        for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
            tokens = self.norm(feats[take_idx][:, patch_start_idx:])  # [B*S, N_patch, C]
            grid = tokens.permute(0, 2, 1).reshape(batch, dim, ph, pw)  # [B*S, C, ph, pw]
            grid = self.projects[stage_idx](grid)
            if self.pos_embed:
                grid = self._add_pos_embed(grid, W, H)
            pyramid.append(self.resize_layers[stage_idx](grid))  # align scale

        # 2) Fusion pyramid (main branch only) + neck conv
        merged = self.scratch.output_conv1(self._fuse(pyramid))

        # 3) Upsample to target resolution, optionally add position encoding again
        out_h = int(ph * self.patch_size / self.down_ratio)
        out_w = int(pw * self.patch_size / self.down_ratio)
        merged = custom_interpolate(merged, (out_h, out_w), mode="bilinear", align_corners=True)

        # Re-inject raw image information before the prediction head.
        merged = merged + self.images_merger(images)

        if self.pos_embed:
            merged = self._add_pos_embed(merged, W, H)

        # 4) Main head: logits -> activate_head_gs or single-channel activation
        logits = self.scratch.output_conv2(merged)
        result: TyDict[str, torch.Tensor] = {}
        if self.has_conf:
            pred, conf = activate_head_gs(
                logits,
                activation=self.activation,
                conf_activation=self.conf_activation,
                conf_dim=self.conf_dim,
            )
            result[self.head_main] = pred.squeeze(1)
            result[f"{self.head_main}_conf"] = conf.squeeze(1)
        else:
            result[self.head_main] = self._apply_activation_single(logits).squeeze(1)

        return result
src/depth_anything_3/model/utils/attention.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa
16
+
17
+ from typing import Callable, Optional, Union
18
+ import torch
19
+ import torch.nn.functional as F
20
+ from torch import Tensor, nn
21
+
22
+
23
class Attention(nn.Module):
    """Multi-head self-attention with optional QK normalization and RoPE.

    The attention kernel is torch's fused scaled_dot_product_attention.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
        norm_layer: nn.Module = nn.LayerNorm,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, "dim should be divisible by num_heads"
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # Kept for API compatibility; SDPA applies its own default scaling.
        self.scale = self.head_dim**-0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)
        self.rope = rope

    def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
        batch, tokens, channels = x.shape
        # Project once, then split into per-head q/k/v: [3, B, heads, N, head_dim]
        packed = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, self.head_dim)
        q, k, v = packed.permute(2, 0, 3, 1, 4).unbind(0)
        q = self.q_norm(q)
        k = self.k_norm(k)
        if self.rope is not None:
            q = self.rope(q, pos)
            k = self.rope(k, pos)
        out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            dropout_p=self.attn_drop.p if self.training else 0.0,
            attn_mask=attn_mask,
        )
        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))
68
+
69
+
70
class LayerScale(nn.Module):
    """Scale each channel by a learnable per-channel weight ``gamma`` (CaiT-style)."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # In-place multiply mutates (and returns) x; otherwise a new tensor.
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
83
+
84
+
85
class Mlp(nn.Module):
    """Two-layer feed-forward block: Linear -> act -> drop -> Linear -> drop."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Unspecified (falsy) hidden/output widths fall back to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
src/depth_anything_3/model/utils/block.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from typing import Callable
17
+ from torch import Tensor, nn
18
+
19
+ from .attention import Attention, LayerScale, Mlp
20
+
21
+
22
class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP, each on a residual path."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = True,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
        qk_norm: bool = False,
        rope=None,
    ) -> None:
        super().__init__()

        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
            qk_norm=qk_norm,
            rope=rope,
        )
        # Layer scale is only enabled when init_values is truthy.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        self.norm2 = norm_layer(dim)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=int(dim * mlp_ratio),
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()

        # Stochastic depth is disabled: `drop_path` is accepted only for
        # signature compatibility and never applied.
        self.sample_drop_ratio = 0.0

    def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
        # Attention residual branch.
        x = x + self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
        # Feed-forward residual branch.
        x = x + self.ls2(self.mlp(self.norm2(x)))
        return x
src/depth_anything_3/model/utils/gs_renderer.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import math
16
+ from math import isqrt
17
+ from typing import Literal, Optional
18
+ import torch
19
+ from einops import rearrange, repeat
20
+ from tqdm import tqdm
21
+
22
+ from depth_anything_3.specs import Gaussians
23
+ from depth_anything_3.utils.camera_trj_helpers import (
24
+ interpolate_extrinsics,
25
+ interpolate_intrinsics,
26
+ render_dolly_zoom_path,
27
+ render_stabilization_path,
28
+ render_wander_path,
29
+ render_wobble_inter_path,
30
+ )
31
+ from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, get_fov
32
+ from depth_anything_3.utils.logger import logger
33
+
34
+ try:
35
+ from gsplat import rasterization
36
+ except ImportError:
37
+ logger.warn(
38
+ "Dependency `gsplat` is required for rendering 3DGS. "
39
+ "Install via: pip install git+https://github.com/nerfstudio-project/"
40
+ "gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"
41
+ )
42
+
43
+
44
def render_3dgs(
    extrinsics: torch.Tensor,  # "batch_views 4 4", w2c
    intrinsics: torch.Tensor,  # "batch_views 3 3", normalized
    image_shape: tuple[int, int],
    gaussian: Gaussians,
    background_color: Optional[torch.Tensor] = None,  # "batch_views 3"
    use_sh: bool = True,
    num_view: int = 1,
    color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+D",
    **kwargs,
) -> tuple[
    torch.Tensor,  # "batch_views 3 height width"
    torch.Tensor,  # "batch_views height width"
]:
    """Rasterize a set of 3D Gaussians into color and depth images via gsplat.

    Args:
        extrinsics: World-to-camera matrices, flattened over (batch * views).
        intrinsics: Intrinsics normalized by image size, same leading dim.
        image_shape: (height, width) of the render target.
        gaussian: Scene Gaussians (means/scales/rotations/opacities/harmonics);
            attributes are assumed repeated along the view dimension so that
            index i selects one scene's parameters.
        background_color: Per-view RGB background; defaults to black.
        use_sh: If True, `harmonics` are SH coefficients (degree inferred from
            the last dim); otherwise they are raw colors passed through sigmoid.
        num_view: Number of views per scene; batch_views must be divisible by it.
        color_mode: gsplat render mode; the last rendered channel is depth.
        **kwargs: Ignored (accepted for caller convenience).

    Returns:
        (images, depths): stacked over all batch_views.
    """
    # extract gaussian params
    gaussian_means = gaussian.means
    gaussian_scales = gaussian.scales
    gaussian_quats = gaussian.rotations
    gaussian_opacities = gaussian.opacities
    gaussian_sh_coefficients = gaussian.harmonics
    b, _, _ = extrinsics.shape

    if background_color is None:
        background_color = repeat(torch.tensor([0.0, 0.0, 0.0]), "c -> b c", b=b).to(
            gaussian_sh_coefficients
        )

    if use_sh:
        # SH band count n = (degree+1)^2, so degree = isqrt(n) - 1.
        _, _, _, n = gaussian_sh_coefficients.shape
        degree = isqrt(n) - 1
        shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()
    else:  # use color
        shs = (
            gaussian_sh_coefficients.squeeze(-1).sigmoid().contiguous()
        )  # (b, g, c), normed to (0, 1)

    h, w = image_shape

    # Recover pixel focal lengths from the normalized intrinsics via FoV.
    fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
    tan_fov_x = (0.5 * fov_x).tan()
    tan_fov_y = (0.5 * fov_y).tan()
    focal_length_x = w / (2 * tan_fov_x)
    focal_length_y = h / (2 * tan_fov_y)

    view_matrix = extrinsics.float()

    all_images = []
    all_radii = []  # NOTE: collected below but not returned by this function
    all_depths = []
    # render view in a batch based, each batch contains one scene
    # assume the Gaussian parameters are originally repeated along the view dim
    batch_scene = b // num_view

    def index_i_gs_attr(full_attr, idx):
        # Select scene idx's Gaussian attribute (first dim indexes scenes).
        # return rearrange(full_attr, "(b v) ... -> b v ...", v=num_view)[idx, 0]
        return full_attr[idx]

    for i in range(batch_scene):
        # Unnormalized pixel-space K for this scene's views; principal point
        # is fixed at the image center.
        K = repeat(
            torch.tensor(
                [
                    [0, 0, w / 2.0],
                    [0, 0, h / 2.0],
                    [0, 0, 1],
                ]
            ),
            "i j -> v i j",
            v=num_view,
        ).to(gaussian_means)
        K[:, 0, 0] = focal_length_x.reshape(batch_scene, num_view)[i]
        K[:, 1, 1] = focal_length_y.reshape(batch_scene, num_view)[i]

        i_means = index_i_gs_attr(gaussian_means, i)  # [N, 3]
        i_scales = index_i_gs_attr(gaussian_scales, i)
        i_quats = index_i_gs_attr(gaussian_quats, i)
        i_opacities = index_i_gs_attr(gaussian_opacities, i)  # [N,]
        i_colors = index_i_gs_attr(shs, i)  # [N, K, 3]
        i_viewmats = rearrange(view_matrix, "(b v) ... -> b v ...", v=num_view)[i]  # [v, 4, 4]
        i_backgrounds = rearrange(background_color, "(b v) ... -> b v ...", v=num_view)[
            i
        ]  # [v, 3]

        render_colors, render_alphas, info = rasterization(
            means=i_means,
            quats=i_quats,  # [N, 4]
            scales=i_scales,  # [N, 3]
            opacities=i_opacities,
            colors=i_colors,
            viewmats=i_viewmats,  # [v, 4, 4]
            Ks=K,  # [v, 3, 3]
            backgrounds=i_backgrounds,
            render_mode=color_mode,
            width=w,
            height=h,
            packed=False,
            sh_degree=degree if use_sh else None,
        )
        # Depth is the trailing channel in "RGB+D"/"RGB+ED" render modes.
        depth = render_colors[..., -1].unbind(dim=0)

        image = rearrange(render_colors[..., :3], "v h w c -> v c h w").unbind(dim=0)
        radii = info["radii"].unbind(dim=0)
        try:
            # Keep screen-space gradients accessible for densification-style
            # consumers; harmless best-effort if grads are disabled.
            info["means2d"].retain_grad()  # [1, N, 2]
        except Exception:
            pass
        all_images.extend(image)
        all_depths.extend(depth)
        all_radii.extend(radii)

    return torch.stack(all_images), torch.stack(all_depths)
154
+
155
+
156
def run_renderer_in_chunk_w_trj_mode(
    gaussians: Gaussians,
    extrinsics: torch.Tensor,  # world2cam, "batch view 4 4" | "batch view 3 4"
    intrinsics: torch.Tensor,  # unnormed intrinsics, "batch view 3 3"
    image_shape: tuple[int, int],
    chunk_size: Optional[int] = 8,
    trj_mode: Literal[
        "original",
        "smooth",
        "interpolate",
        "interpolate_smooth",
        "wander",
        "dolly_zoom",
        "extend",
        "wobble_inter",
    ] = "smooth",
    input_shape: Optional[tuple[int, int]] = None,
    enable_tqdm: Optional[bool] = False,
    **kwargs,
) -> tuple[
    torch.Tensor,  # color, "batch view 3 height width"
    torch.Tensor,  # depth, "batch view height width"
]:
    """Render Gaussians along a camera trajectory synthesized per `trj_mode`.

    Builds a target camera path from the input cameras, then renders it in
    chunks of `chunk_size` views through `render_3dgs`.

    Args:
        gaussians: Scene Gaussians to render.
        extrinsics: World-to-camera poses per batch/view.
        intrinsics: Unnormalized intrinsics; normalized here using
            `input_shape` (or `image_shape` if not given).
        image_shape: (H, W) of the rendered output.
        chunk_size: Views rendered per rasterization call; None = all at once.
        trj_mode: "original" reuses input cameras; "smooth" stabilizes them;
            "interpolate"/"interpolate_smooth" insert in-between views with a
            cosine easing; "extend" additionally splices wander + dolly-zoom
            segments at the middle frame (batch_size 1 only); "wander" /
            "dolly_zoom" synthesize a path around the first view (required
            when only one input view exists); "wobble_inter" interpolates
            with a wobble.
        input_shape: (H, W) the intrinsics refer to, if different from output.
        enable_tqdm: Show a progress bar over render chunks.
        **kwargs: Forwarded to `render_3dgs`.

    Returns:
        (colors, depths) concatenated over all trajectory views.
    """
    cam2world = affine_inverse(as_homogeneous(extrinsics))
    if input_shape is not None:
        in_h, in_w = input_shape
    else:
        in_h, in_w = image_shape
    # Normalize intrinsics by the input resolution (row 0 by width, row 1 by height).
    intr_normed = intrinsics.clone().detach()
    intr_normed[..., 0, :] /= in_w
    intr_normed[..., 1, :] /= in_h
    if extrinsics.shape[1] <= 1:
        assert trj_mode in [
            "wander",
            "dolly_zoom",
        ], "Please set trj_mode to 'wander' or 'dolly_zoom' when n_views=1"

    def _smooth_trj_fn_batch(raw_c2ws, k_size=50):
        # Best-effort per-batch stabilization; falls back to the raw path on failure.
        try:
            smooth_c2ws = torch.stack(
                [render_stabilization_path(c2w_i, k_size) for c2w_i in raw_c2ws],
                dim=0,
            )
        except Exception as e:
            print(f"[DEBUG] Path smoothing failed with error: {e}.")
            smooth_c2ws = raw_c2ws
        return smooth_c2ws

    # get rendered trj
    if trj_mode == "original":
        tgt_c2w = cam2world
        tgt_intr = intr_normed
    elif trj_mode == "smooth":
        tgt_c2w = _smooth_trj_fn_batch(cam2world)
        tgt_intr = intr_normed
    elif trj_mode in ["interpolate", "interpolate_smooth", "extend"]:
        # Pick an interpolation density targeting roughly 2-18 seconds at 24 fps.
        inter_len = 8
        total_len = (cam2world.shape[1] - 1) * inter_len
        if total_len > 24 * 18:  # no more than 18s
            inter_len = max(1, 24 * 10 // (cam2world.shape[1] - 1))
        if total_len < 24 * 2:  # no less than 2s
            inter_len = max(1, 24 * 2 // (cam2world.shape[1] - 1))

        if inter_len > 2:
            # Cosine easing of the interpolation parameter in [0, 1].
            t = torch.linspace(0, 1, inter_len, dtype=torch.float32, device=cam2world.device)
            t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
            tgt_c2w_b = []
            tgt_intr_b = []
            for b_idx in range(cam2world.shape[0]):
                tgt_c2w = []
                tgt_intr = []
                for cur_idx in range(cam2world.shape[1] - 1):
                    # Skip the first interpolated sample after the first segment
                    # so consecutive segments don't duplicate the shared keyframe.
                    tgt_c2w.append(
                        interpolate_extrinsics(
                            cam2world[b_idx, cur_idx], cam2world[b_idx, cur_idx + 1], t
                        )[(0 if cur_idx == 0 else 1) :]
                    )
                    tgt_intr.append(
                        interpolate_intrinsics(
                            intr_normed[b_idx, cur_idx], intr_normed[b_idx, cur_idx + 1], t
                        )[(0 if cur_idx == 0 else 1) :]
                    )
                tgt_c2w_b.append(torch.cat(tgt_c2w))
                tgt_intr_b.append(torch.cat(tgt_intr))
            tgt_c2w = torch.stack(tgt_c2w_b)  # b v 4 4
            tgt_intr = torch.stack(tgt_intr_b)  # b v 3 3
        else:
            tgt_c2w = cam2world
            tgt_intr = intr_normed
        if trj_mode in ["interpolate_smooth", "extend"]:
            tgt_c2w = _smooth_trj_fn_batch(tgt_c2w)
        if trj_mode == "extend":
            # apply dolly_zoom and wander in the middle frame
            assert cam2world.shape[0] == 1, "extend only supports for batch_size=1 currently."
            mid_idx = tgt_c2w.shape[1] // 2
            c2w_wd, intr_wd = render_wander_path(
                tgt_c2w[0, mid_idx],
                tgt_intr[0, mid_idx],
                h=in_h,
                w=in_w,
                num_frames=max(36, min(60, mid_idx // 2)),
                max_disp=24.0,
            )
            c2w_dz, intr_dz = render_dolly_zoom_path(
                tgt_c2w[0, mid_idx],
                tgt_intr[0, mid_idx],
                h=in_h,
                w=in_w,
                num_frames=max(36, min(60, mid_idx // 2)),
            )
            # Splice the wander then dolly-zoom segments into the path at mid_idx.
            tgt_c2w = torch.cat(
                [
                    tgt_c2w[:, :mid_idx],
                    c2w_wd.unsqueeze(0),
                    c2w_dz.unsqueeze(0),
                    tgt_c2w[:, mid_idx:],
                ],
                dim=1,
            )
            tgt_intr = torch.cat(
                [
                    tgt_intr[:, :mid_idx],
                    intr_wd.unsqueeze(0),
                    intr_dz.unsqueeze(0),
                    tgt_intr[:, mid_idx:],
                ],
                dim=1,
            )
    elif trj_mode in ["wander", "dolly_zoom"]:
        if trj_mode == "wander":
            render_fn = render_wander_path
            extra_kwargs = {"max_disp": 24.0}
        else:
            render_fn = render_dolly_zoom_path
            extra_kwargs = {"D_focus": 30.0, "max_disp": 2.0}
        tgt_c2w = []
        tgt_intr = []
        for b_idx in range(cam2world.shape[0]):
            # Both single-view modes synthesize the path around view 0.
            c2w_i, intr_i = render_fn(
                cam2world[b_idx, 0], intr_normed[b_idx, 0], h=in_h, w=in_w, **extra_kwargs
            )
            tgt_c2w.append(c2w_i)
            tgt_intr.append(intr_i)
        tgt_c2w = torch.stack(tgt_c2w)
        tgt_intr = torch.stack(tgt_intr)
    elif trj_mode == "wobble_inter":
        tgt_c2w, tgt_intr = render_wobble_inter_path(
            cam2world=cam2world,
            intr_normed=intr_normed,
            inter_len=10,
            n_skip=3,
        )
    else:
        raise Exception(f"trj mode [{trj_mode}] is not implemented.")

    # Render the synthesized path chunk by chunk to bound peak memory.
    _, v = tgt_c2w.shape[:2]
    tgt_extr = affine_inverse(tgt_c2w)
    if chunk_size is None:
        chunk_size = v
    chunk_size = min(v, chunk_size)
    all_colors = []
    all_depths = []
    for chunk_idx in tqdm(
        range(math.ceil(v / chunk_size)),
        desc="Rendering novel views",
        disable=(not enable_tqdm),
        leave=False,
    ):
        s = int(chunk_idx * chunk_size)
        e = int((chunk_idx + 1) * chunk_size)
        cur_n_view = tgt_extr[:, s:e].shape[1]
        color, depth = render_3dgs(
            extrinsics=rearrange(tgt_extr[:, s:e], "b v ... -> (b v) ..."),  # w2c
            intrinsics=rearrange(tgt_intr[:, s:e], "b v ... -> (b v) ..."),  # normed
            image_shape=image_shape,
            gaussian=gaussians,
            num_view=cur_n_view,
            **kwargs,
        )
        all_colors.append(rearrange(color, "(b v) ... -> b v ...", v=cur_n_view))
        all_depths.append(rearrange(depth, "(b v) ... -> b v ...", v=cur_n_view))
    all_colors = torch.cat(all_colors, dim=1)
    all_depths = torch.cat(all_depths, dim=1)

    return all_colors, all_depths
src/depth_anything_3/model/utils/head_utils.py ADDED
@@ -0,0 +1,230 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Tuple, Union
16
+ import torch
17
+ import torch.nn as nn
18
+ import torch.nn.functional as F
19
+
20
+ # -----------------------------------------------------------------------------
21
+ # Activation functions
22
+ # -----------------------------------------------------------------------------
23
+
24
+
25
def activate_head_gs(out, activation="norm_exp", conf_activation="expp1", conf_dim=None):
    """
    Split a network output map into activated GS point parameters and
    confidence/density values. Density may be view-dependent (SH coefficients),
    in which case ``conf_dim`` selects more than one trailing channel.

    Args:
        out: Network output tensor of shape (B, C, H, W).
        activation: Activation applied to the point channels.
        conf_activation: Activation applied to the confidence channel(s).
        conf_dim: Number of trailing channels treated as confidence (default 1).

    Returns:
        Tuple of (activated points (B, H, W, C - conf_dim), confidence tensor).
    """
    # Move to channels-last layout: (B, H, W, C).
    fmap = out.permute(0, 2, 3, 1)

    conf_dim = 1 if conf_dim is None else conf_dim
    xyz = fmap[..., :-conf_dim]
    # A single confidence channel is squeezed; multiple channels keep a trailing dim.
    conf = fmap[..., -1] if conf_dim == 1 else fmap[..., -conf_dim:]

    simple_point_fns = {
        "exp": torch.exp,
        "relu": F.relu,
        "sigmoid": torch.sigmoid,
        "linear": lambda t: t,
    }
    if activation == "norm_exp":
        # Direction * expm1(length): keeps the ray direction, expands magnitude.
        length = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
        pts3d = (xyz / length) * torch.expm1(length)
    elif activation == "norm":
        pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
    elif activation in simple_point_fns:
        pts3d = simple_point_fns[activation](xyz)
    else:
        raise ValueError(f"Unknown activation: {activation}")

    conf_fns = {
        "expp1": lambda t: 1 + t.exp(),
        "expp0": lambda t: t.exp(),
        "sigmoid": torch.sigmoid,
        "linear": lambda t: t,
    }
    if conf_activation not in conf_fns:
        raise ValueError(f"Unknown conf_activation: {conf_activation}")
    conf_out = conf_fns[conf_activation](conf)

    return pts3d, conf_out
76
+
77
+
78
+ # -----------------------------------------------------------------------------
79
+ # Other utilities
80
+ # -----------------------------------------------------------------------------
81
+
82
+
83
class Permute(nn.Module):
    """Wrap Tensor.permute in an nn.Module so axis reordering can sit
    inside an nn.Sequential pipeline."""

    # Axis permutation applied in forward(), e.g. (0, 2, 3, 1).
    dims: Tuple[int, ...]

    def __init__(self, dims: Tuple[int, ...]) -> None:
        super().__init__()
        self.dims = dims

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        # Reorder the tensor axes according to the configured permutation.
        return x.permute(*self.dims)
94
+
95
+
96
def position_grid_to_embed(
    pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100
) -> torch.Tensor:
    """
    Convert a 2D position grid (H, W, 2) into sinusoidal embeddings (H, W, embed_dim).

    Args:
        pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates.
        embed_dim: Output channel dimension for the embeddings.
        omega_0: Base frequency forwarded to make_sincos_pos_embed.

    Returns:
        Tensor of shape (H, W, embed_dim) with positional embeddings.
    """
    H, W, grid_dim = pos_grid.shape
    assert grid_dim == 2
    coords = pos_grid.reshape(-1, grid_dim)  # (H*W, 2)

    # Embed x and y independently, each taking half of the channels.
    half = embed_dim // 2
    emb_x = make_sincos_pos_embed(half, coords[:, 0], omega_0=omega_0)  # (H*W, D/2)
    emb_y = make_sincos_pos_embed(half, coords[:, 1], omega_0=omega_0)  # (H*W, D/2)

    # Concatenate the halves and restore the spatial layout.
    return torch.cat([emb_x, emb_y], dim=-1).view(H, W, embed_dim)
121
+
122
+
123
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
    """
    Generate a 1D sine/cosine positional embedding for the given positions.

    Args:
        embed_dim: Embedding dimension (must be even).
        pos: Positions to embed; flattened to shape (M,).
        omega_0: Base of the geometric frequency ladder.

    Returns:
        Float tensor of shape (M, embed_dim), laid out as [sin | cos] halves.
    """
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    # Frequencies computed in double precision for accuracy: omega_0^(-d / half).
    idx = torch.arange(half, dtype=torch.double, device=pos.device)
    omega = 1.0 / omega_0 ** (idx / half)  # (D/2,)

    # Outer product of positions and frequencies -> phase angles (M, D/2).
    angles = torch.einsum("m,d->md", pos.reshape(-1), omega)

    # Sin half followed by cos half, cast back to float32.
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=1).float()
147
+
148
+
149
+ # Inspired by https://github.com/microsoft/moge
150
+
151
+
152
def create_uv_grid(
    width: int,
    height: int,
    aspect_ratio: float = None,
    dtype: torch.dtype = None,
    device: torch.device = None,
) -> torch.Tensor:
    """
    Create a normalized UV grid of shape (height, width, 2).

    The plane spans [-span_x, span_x] x [-span_y, span_y], where the spans are
    the half-extents of a width:height rectangle normalized by its diagonal.
    Samples sit at cell centers, i.e. half a cell in from each edge.

    Args:
        width (int): Number of points horizontally.
        height (int): Number of points vertically.
        aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
        dtype (torch.dtype, optional): Data type of the resulting tensor.
        device (torch.device, optional): Device on which the tensor is created.

    Returns:
        torch.Tensor: A (height, width, 2) tensor of UV coordinates.
    """
    if aspect_ratio is None:
        aspect_ratio = float(width) / float(height)

    # Half-extents of the plane, normalized by the diagonal length.
    diag = (aspect_ratio**2 + 1.0) ** 0.5
    span_x = aspect_ratio / diag
    span_y = 1.0 / diag

    # Endpoints pulled in by half a cell so samples land on cell centers.
    xs = torch.linspace(
        -span_x * (width - 1) / width,
        span_x * (width - 1) / width,
        steps=width,
        dtype=dtype,
        device=device,
    )
    ys = torch.linspace(
        -span_y * (height - 1) / height,
        span_y * (height - 1) / height,
        steps=height,
        dtype=dtype,
        device=device,
    )

    # "xy" indexing yields (height, width) maps; stack into (height, width, 2).
    grid_u, grid_v = torch.meshgrid(xs, ys, indexing="xy")
    return torch.stack((grid_u, grid_v), dim=-1)
200
+
201
+
202
+ # -----------------------------------------------------------------------------
203
+ # Interpolation (safe interpolation, avoid INT_MAX overflow)
204
+ # -----------------------------------------------------------------------------
205
def custom_interpolate(
    x: torch.Tensor,
    size: Union[Tuple[int, int], None] = None,
    scale_factor: Union[float, None] = None,
    mode: str = "bilinear",
    align_corners: bool = True,
) -> torch.Tensor:
    """
    Drop-in replacement for torch.nn.functional.interpolate that avoids the
    INT_MAX element-count overflow by splitting oversized inputs along dim 0.

    Args:
        x: Input tensor of shape (N, C, H, W).
        size: Target (H, W); derived from scale_factor when None.
        scale_factor: Spatial scale, used only when size is None.
        mode: Interpolation mode passed through to F.interpolate.
        align_corners: Corner-alignment flag passed through to F.interpolate.

    Returns:
        The interpolated tensor.
    """
    if size is None:
        assert scale_factor is not None, "Either size or scale_factor must be provided."
        size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))

    INT_MAX = 1610612736  # conservative cap below 2**31 output elements
    n_elems = size[0] * size[1] * x.shape[0] * x.shape[1]
    if n_elems <= INT_MAX:
        return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)

    # Too large for a single kernel launch: process batch chunks separately.
    pieces = torch.chunk(x, chunks=(n_elems // INT_MAX) + 1, dim=0)
    resized = [
        F.interpolate(piece, size=size, mode=mode, align_corners=align_corners)
        for piece in pieces
    ]
    return torch.cat(resized, dim=0).contiguous()
src/depth_anything_3/model/utils/transform.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import torch
16
+ import torch.nn.functional as F
17
+
18
+
19
def extri_intri_to_pose_encoding(
    extrinsics,
    intrinsics,
    image_size_hw=None,
):
    """Convert camera extrinsics and intrinsics to a compact pose encoding.

    Args:
        extrinsics: (B, S, 3, 4) world-to-camera matrices.
        intrinsics: (B, S, 3, 3) pinhole intrinsics in pixels.
        image_size_hw: (H, W) of the images the intrinsics refer to.

    Returns:
        (B, S, 9) float tensor: [tx, ty, tz, qx, qy, qz, qw, fov_h, fov_w].
    """
    rotation = extrinsics[:, :, :3, :3]  # BxSx3x3
    translation = extrinsics[:, :, :3, 3]  # BxSx3

    quat = mat_to_quat(rotation)

    # Note the order of h and w here: fields of view from the focal lengths.
    height, width = image_size_hw
    fov_h = 2 * torch.atan((height / 2) / intrinsics[..., 1, 1])
    fov_w = 2 * torch.atan((width / 2) / intrinsics[..., 0, 0])

    parts = [translation, quat, fov_h.unsqueeze(-1), fov_w.unsqueeze(-1)]
    return torch.cat(parts, dim=-1).float()
39
+
40
+
41
def pose_encoding_to_extri_intri(
    pose_encoding,
    image_size_hw=None,
):
    """Convert a pose encoding back to camera extrinsics and intrinsics.

    Args:
        pose_encoding: (B, S, 9) tensor [tx, ty, tz, qx, qy, qz, qw, fov_h, fov_w].
        image_size_hw: (H, W) of the target images.

    Returns:
        Tuple of ((B, S, 3, 4) extrinsics, (B, S, 3, 3) intrinsics).
    """
    translation = pose_encoding[..., :3]
    quat = pose_encoding[..., 3:7]
    fov_h = pose_encoding[..., 7]
    fov_w = pose_encoding[..., 8]

    # Rebuild the [R | t] extrinsic matrix.
    rotation = quat_to_mat(quat)
    extrinsics = torch.cat([rotation, translation.unsqueeze(-1)], dim=-1)

    # Recover focal lengths; clamp tan(fov/2) away from zero for stability.
    height, width = image_size_hw
    fy = (height / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
    fx = (width / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)

    intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
    intrinsics[..., 0, 0] = fx
    intrinsics[..., 1, 1] = fy
    intrinsics[..., 0, 2] = width / 2  # principal point at the image center
    intrinsics[..., 1, 2] = height / 2
    intrinsics[..., 2, 2] = 1.0  # homogeneous coordinate

    return extrinsics, intrinsics
66
+
67
+
68
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Convert scalar-last quaternions (XYZW, i.e. ijkr order) to rotation matrices.

    Args:
        quaternions: Tensor of shape (..., 4) with the real part last.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    i, j, k, r = torch.unbind(quaternions, -1)
    # Normalization factor; equals 2 exactly for unit quaternions.
    two_s = 2.0 / (quaternions * quaternions).sum(-1)

    # Row-major entries of the rotation matrix.
    entries = (
        1 - two_s * (j * j + k * k),
        two_s * (i * j - k * r),
        two_s * (i * k + j * r),
        two_s * (i * j + k * r),
        1 - two_s * (i * i + k * k),
        two_s * (j * k - i * r),
        two_s * (i * k - j * r),
        two_s * (j * k + i * r),
        1 - two_s * (i * i + j * j),
    )
    flat = torch.stack(entries, -1)
    return flat.reshape(quaternions.shape[:-1] + (3, 3))
98
+
99
+
100
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to quaternions.

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).

    Returns:
        quaternions with real part last, as tensor of shape (..., 4).
        Quaternion Order: XYZW or say ijkr, scalar-last

    Raises:
        ValueError: If the trailing two dimensions are not (3, 3).
    """
    if matrix.size(-1) != 3 or matrix.size(-2) != 3:
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    batch_dim = matrix.shape[:-2]
    # Unpack the nine matrix entries as individual (...,) tensors.
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
        matrix.reshape(batch_dim + (9,)), dim=-1
    )

    # Trace identities: each stacked value equals 4*(component)^2 of the
    # quaternion in (r, i, j, k) order, so q_abs[..., c] = 2*|component c|.
    # _sqrt_positive_part guards against tiny negatives from rounding error.
    q_abs = _sqrt_positive_part(
        torch.stack(
            [
                1.0 + m00 + m11 + m22,
                1.0 + m00 - m11 - m22,
                1.0 - m00 + m11 - m22,
                1.0 - m00 - m11 + m22,
            ],
            dim=-1,
        )
    )

    # Four candidate quaternions (rows, each in (r, i, j, k) order); row c is
    # numerically stable when component c has the largest magnitude.
    quat_by_rijk = torch.stack(
        [
            torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
            torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
            torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
            torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
        ],
        dim=-2,
    )

    # Clamp the divisor so near-zero q_abs entries cannot blow up the division.
    flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
    quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))

    # Per batch element, select the candidate row whose q_abs is largest
    # (the numerically best-conditioned one) via a one-hot boolean mask.
    out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
        batch_dim + (4,)
    )

    # Reorder from (r, i, j, k) to scalar-last (i, j, k, r) == XYZW.
    out = out[..., [1, 2, 3, 0]]

    # Canonical sign: make the real part non-negative.
    out = standardize_quaternion(out)

    return out
153
+
154
+
155
+ def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
156
+ """
157
+ Returns torch.sqrt(torch.max(0, x))
158
+ but with a zero subgradient where x is 0.
159
+ """
160
+ ret = torch.zeros_like(x)
161
+ positive_mask = x > 0
162
+ if torch.is_grad_enabled():
163
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
164
+ else:
165
+ ret = torch.where(positive_mask, torch.sqrt(x), ret)
166
+ return ret
167
+
168
+
169
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
    """
    Map each unit quaternion to the canonical sign convention in which the
    real part is non-negative (q and -q encode the same rotation).

    Args:
        quaternions: Quaternions with real part last,
            as tensor of shape (..., 4).

    Returns:
        Standardized quaternions as tensor of shape (..., 4).
    """
    # Real part is the last component (scalar-last layout).
    needs_flip = quaternions[..., 3:4] < 0
    return torch.where(needs_flip, -quaternions, quaternions)
182
+
183
+
184
def cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w):
    """
    Rotate per-view camera-space quaternions into world space.

    Args:
        cam_quat_xyzw: (b, n, 4) quaternions, scalar-last (XYZW).
        c2w: (b, n, 4, 4) camera-to-world transforms.

    Returns:
        (b, n, 4) world-space quaternions, labeled scalar-first (WXYZ) —
        but see the NOTE(review) comments below about the ordering.
    """
    # cam_quat_xyzw: (b, n, 4) in xyzw
    # c2w: (b, n, 4, 4)
    b, n = cam_quat_xyzw.shape[:2]
    # 1. xyzw -> wxyz
    cam_quat_wxyz = torch.cat(
        [
            cam_quat_xyzw[..., 3:4],  # w
            cam_quat_xyzw[..., 0:1],  # x
            cam_quat_xyzw[..., 1:2],  # y
            cam_quat_xyzw[..., 2:3],  # z
        ],
        dim=-1,
    )
    # 2. Quaternion to matrix
    # NOTE(review): quat_to_mat documents a scalar-LAST (xyzw) input, yet it is
    # fed the scalar-first tensor built above — confirm this is intentional.
    cam_quat_wxyz_flat = cam_quat_wxyz.reshape(-1, 4)
    rotmat_cam = quat_to_mat(cam_quat_wxyz_flat).reshape(b, n, 3, 3)
    # 3. Transform to world space: compose the c2w rotation with the camera rotation.
    rotmat_c2w = c2w[..., :3, :3]
    rotmat_world = torch.matmul(rotmat_c2w, rotmat_cam)
    # 4. Matrix to quaternion (wxyz)
    # NOTE(review): mat_to_quat returns scalar-last (xyzw) per its docstring, so
    # the "wxyz" naming below may be inaccurate — verify against callers.
    rotmat_world_flat = rotmat_world.reshape(-1, 3, 3)
    world_quat_wxyz_flat = mat_to_quat(rotmat_world_flat)
    world_quat_wxyz = world_quat_wxyz_flat.reshape(b, n, 4)
    return world_quat_wxyz
src/depth_anything_3/registry.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import OrderedDict
16
+ from pathlib import Path
17
+
18
+
19
def get_all_models() -> OrderedDict:
    """
    Scan the package's ``configs`` directory for YAML files and return a
    dictionary sorted by model name, where:
    - Keys are model names (YAML filenames without the .yaml extension)
    - Values are absolute paths to the corresponding YAML files
    """
    # Resolve the configs directory relative to this module so the lookup
    # works both in development checkouts and after pip installation.
    configs_dir = Path(__file__).resolve().parent / "configs"

    # (name, absolute path) pairs for every YAML file; resolve() follows symlinks.
    entries = [
        (item.stem, str(item.resolve()))
        for item in configs_dir.iterdir()
        if item.is_file() and item.suffix == ".yaml"
    ]

    # Sort by model name for a stable, deterministic registry order.
    return OrderedDict(sorted(entries, key=lambda pair: pair[0]))


# Global registry for external imports
MODEL_REGISTRY = get_all_models()
src/depth_anything_3/services/__init__.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ """
16
+ Services module for Depth Anything 3.
17
+ """
18
+
19
+ from depth_anything_3.services.backend import create_app, start_server
20
+
21
+ __all__ = [
22
+ start_server,
23
+ create_app,
24
+ ]