Spaces:
Runtime error
Runtime error
Commit
·
0c4c32b
0
Parent(s):
Initial import
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +5 -0
- ARCHITECTURE.md +90 -0
- gradio_app.py +1321 -0
- src/depth_anything_3/__init__.py +21 -0
- src/depth_anything_3/api.py +416 -0
- src/depth_anything_3/app/__init__.py +1 -0
- src/depth_anything_3/app/css_and_html.py +594 -0
- src/depth_anything_3/app/gradio_app.py +747 -0
- src/depth_anything_3/app/modules/__init__.py +45 -0
- src/depth_anything_3/app/modules/event_handlers.py +629 -0
- src/depth_anything_3/app/modules/file_handlers.py +304 -0
- src/depth_anything_3/app/modules/model_inference.py +286 -0
- src/depth_anything_3/app/modules/ui_components.py +474 -0
- src/depth_anything_3/app/modules/utils.py +211 -0
- src/depth_anything_3/app/modules/visualization.py +434 -0
- src/depth_anything_3/cfg.py +144 -0
- src/depth_anything_3/cli.py +742 -0
- src/depth_anything_3/configs/da3-base.yaml +45 -0
- src/depth_anything_3/configs/da3-giant.yaml +71 -0
- src/depth_anything_3/configs/da3-large.yaml +45 -0
- src/depth_anything_3/configs/da3-small.yaml +45 -0
- src/depth_anything_3/configs/da3metric-large.yaml +28 -0
- src/depth_anything_3/configs/da3mono-large.yaml +28 -0
- src/depth_anything_3/configs/da3nested-giant-large.yaml +10 -0
- src/depth_anything_3/model/__init__.py +20 -0
- src/depth_anything_3/model/cam_dec.py +45 -0
- src/depth_anything_3/model/cam_enc.py +80 -0
- src/depth_anything_3/model/da3.py +378 -0
- src/depth_anything_3/model/dinov2/dinov2.py +64 -0
- src/depth_anything_3/model/dinov2/layers/__init__.py +25 -0
- src/depth_anything_3/model/dinov2/layers/attention.py +100 -0
- src/depth_anything_3/model/dinov2/layers/block.py +143 -0
- src/depth_anything_3/model/dinov2/layers/drop_path.py +35 -0
- src/depth_anything_3/model/dinov2/layers/layer_scale.py +31 -0
- src/depth_anything_3/model/dinov2/layers/mlp.py +40 -0
- src/depth_anything_3/model/dinov2/layers/patch_embed.py +94 -0
- src/depth_anything_3/model/dinov2/layers/rope.py +200 -0
- src/depth_anything_3/model/dinov2/layers/swiglu_ffn.py +62 -0
- src/depth_anything_3/model/dinov2/vision_transformer.py +437 -0
- src/depth_anything_3/model/dpt.py +457 -0
- src/depth_anything_3/model/dualdpt.py +488 -0
- src/depth_anything_3/model/gs_adapter.py +200 -0
- src/depth_anything_3/model/gsdpt.py +133 -0
- src/depth_anything_3/model/utils/attention.py +109 -0
- src/depth_anything_3/model/utils/block.py +81 -0
- src/depth_anything_3/model/utils/gs_renderer.py +340 -0
- src/depth_anything_3/model/utils/head_utils.py +230 -0
- src/depth_anything_3/model/utils/transform.py +208 -0
- src/depth_anything_3/registry.py +50 -0
- src/depth_anything_3/services/__init__.py +24 -0
.gitignore
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__pycache__/
|
| 2 |
+
.DS_Store
|
| 3 |
+
.python-version
|
| 4 |
+
data/
|
| 5 |
+
*.pyc
|
ARCHITECTURE.md
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Landing Site Safety Analyzer – Architecture and Calculations
|
| 2 |
+
|
| 3 |
+
This document describes the full computation flow in `gradio_app.py`, from input handling through model inference, safety scoring, and UI composition.
|
| 4 |
+
|
| 5 |
+
## Data and Models
|
| 6 |
+
- **Inputs**: Images from `data/Image/VISLOC` (dropdown populated via `list_visloc_images()` and a 5% border crop via `crop_nonblack` to eliminate black padding). Supported extensions: jpg/jpeg/png (any case).
|
| 7 |
+
- **Depth model**: Depth Anything 3, loaded lazily and cached (`get_model`). Inference is invoked with `process_res` and `process_res_method="upper_bound_resize"` to resize the long image side to the chosen resolution before predicting.
|
| 8 |
+
- **Segmentation model(s)**: Mask2Former ADE20K (`facebook/mask2former-swin-large-ade-semantic`) for optional water and road masking. Loaded lazily and cached per model id (`get_water_segmenter`). Masks are cached per `(model_id, source_path)` to avoid recomputation.
|
| 9 |
+
|
| 10 |
+
## Constants and Assumptions
|
| 11 |
+
- Image crop: 5% on each border before processing.
|
| 12 |
+
- Safety thresholds: user-controlled sliders for flatness (`std_thresh`) and gradient (`grad_thresh`).
|
| 13 |
+
- Altitude and FOV are user-provided (defaults 450 m, 90°).
|
| 14 |
+
- Roof mask mode defaults to “Buildings only” (optionally includes “roof” labels).
|
| 15 |
+
- Flatness detail slider scales the flatness visualization window.
|
| 16 |
+
- Clearance factor dilates hazards to enforce buffer distance.
|
| 17 |
+
|
| 18 |
+
## Per-Image Processing Pipeline
|
| 19 |
+
1. **Load and crop** the selected image to RGB, removing 5% border padding.
|
| 20 |
+
2. **Depth inference**: Run Depth Anything 3 at the input image’s native resolution; get a depth map (`depth_raw`).
|
| 21 |
+
3. **Plane removal**: Fit a best-fit plane over the depth map via least squares (`remove_global_plane`) and subtract it to obtain a detrended depth (`depth`). This prevents large, tilted planes (e.g., concrete yards) from being mislabeled as sloped.
|
| 22 |
+
4. **Footprint to pixels**:
|
| 23 |
+
- Compute focal length in pixels: `fx = (W/2) / tan(FOV/2)`, where `W` is the depth map width and `FOV` is user-set.
|
| 24 |
+
- Convert user footprint (meters) to depth pixels: `patch_px = footprint_m * fx / altitude_m`, clamped to the depth map bounds and forced to be odd.
|
| 25 |
+
5. **Flatness visualization scale**: Compute a separate window size `vis_patch` (odd), scaled by the flatness detail slider and capped to 1/10 of the smallest map dimension, for a visualization-only std map (sharper view).
|
| 26 |
+
6. **Optional masks**:
|
| 27 |
+
- Water and road masks are computed on the (possibly downscaled) RGB using Mask2Former, resized to depth resolution with nearest-neighbor, and converted to boolean arrays. Cached per input path.
|
| 28 |
+
- Roof mask is depth-based: pixels significantly closer than the median depth (per MAD threshold) are treated as raised structures; morphology smooths the mask.
|
| 29 |
+
7. **Flat region search (`pick_flat_patch`)**:
|
| 30 |
+
- Normalize depth to [0,1].
|
| 31 |
+
- Compute local mean and variance via reflective padding and torch avg pooling over `patch` to obtain `std_map`.
|
| 32 |
+
- Compute gradient magnitude via `np.gradient`; normalize by the 95th percentile to get `grad_norm`; derive `grad_mask = grad_norm < grad_thresh`.
|
| 33 |
+
- Combine masks: start with `landing_mask = grad_mask`; if a water mask is available, exclude water pixels; return the lowest-variance window inside this mask as a candidate patch (`box`).
|
| 34 |
+
8. **Safe mask construction**:
|
| 35 |
+
- Initial safe pixels: `(std_map < std_thresh) & (grad_norm < grad_thresh) & landing_mask`.
|
| 36 |
+
- Clearance buffer: build a hazard mask as the inverse of safe pixels, dilate it with an ellipse kernel sized from `clearance_factor * patch_px`, and subtract from safe_mask to enforce distance from hazards.
|
| 37 |
+
- Enforce full-footprint coverage by box filtering the safe mask with a `patch_px` window; keep pixels where coverage >= 0.999.
|
| 38 |
+
- Remove small components: require at least one footprint area (`patch_px * patch_px`) per connected component.
|
| 39 |
+
9. **Landing spot selection**:
|
| 40 |
+
- Prefer centers where the full footprint is safe: find pixels with full coverage, choose the largest connected component, and pick its flattest point (lowest `std_map`).
|
| 41 |
+
- Fallback: center of the flattest window from step 7.
|
| 42 |
+
- Convert the chosen depth-space center to image-space using `scale_x = W_img / W_depth` and draw a square whose side is `patch_px * scale_x` (clamped, min 3 px) on the original image. Also draw a center dot.
|
| 43 |
+
10. **Visualization layers**:
|
| 44 |
+
- Depth visualization uses the original depth (`depth_raw`) rendered with `visualize_depth` and resized to the input image size.
|
| 45 |
+
- Flatness map (std), gradient magnitude, gradient mask, water mask, road mask are resized to image size for display.
|
| 46 |
+
- Safety heatmap overlay and grayscale score are built from the final `safe_mask`.
|
| 47 |
+
|
| 48 |
+
## Safety Heatmap Calculation
|
| 49 |
+
- Input: boolean `safe_mask` (post-coverage + component filtering).
|
| 50 |
+
- Convert to `score` in [0,1].
|
| 51 |
+
- Color mapping (per-pixel):
|
| 52 |
+
- Red channel: `red = (1 - score) ** 0.9 * 255`.
|
| 53 |
+
- Green channel: `green = score ** 1.2 * 200` (gamma and cap to prevent overpowering red).
|
| 54 |
+
- Resize to image size with nearest-neighbor for overlays. A grayscale safety score is also produced for debugging/alternative views.
|
| 55 |
+
|
| 56 |
+
## Overlay Composition (`compose_view`)
|
| 57 |
+
- Start from the selected base view (RGB/Depth/Flatness/Gradient/Mask/Heatmap/Score/Landing spot).
|
| 58 |
+
- Add overlays conditionally with per-layer alpha:
|
| 59 |
+
- Safety heatmap (RGBA alpha from slider).
|
| 60 |
+
- Gradient, flatness maps (alpha from sliders).
|
| 61 |
+
- Water and road masks: convert mask to luminance, modulate alpha by mask pixels and slider, tint red (water) or orange (road).
|
| 62 |
+
- Landing spot overlay is composited last.
|
| 63 |
+
- Output is converted back to RGB.
|
| 64 |
+
|
| 65 |
+
## Caching and State
|
| 66 |
+
- **Model cache**: depth model loaded once per process; segmentation model loaded once per model id.
|
| 67 |
+
- **Mask cache**: water/road masks cached per `(model_id, source_path)` to avoid recomputation across UI tweaks.
|
| 68 |
+
- **Image state**: `images_state` (Gradio State) holds the latest computed layers; overlay-only UI controls recompute compositions without rerunning inference unless base parameters change.
|
| 69 |
+
|
| 70 |
+
## User Controls and Their Effects
|
| 71 |
+
- `process_res`: sets inference resize; affects depth resolution and subsequent pixel scaling.
|
| 72 |
+
- `footprint_m`: desired landing square (meters); converted to pixels via FOV/altitude.
|
| 73 |
+
- `altitude_m`, `fov_deg`: camera parameters for footprint sizing.
|
| 74 |
+
- `flatness_detail`: scales the visualization window for the flatness map.
|
| 75 |
+
- `clearance_factor`: scales the hazard dilation kernel.
|
| 76 |
+
- `std_thresh`, `grad_thresh`: safety criteria for flatness and slope.
|
| 77 |
+
- `use_water_mask` / `use_road_mask` (segmentation-based) and `use_roof_mask` (depth-based): enable masking in safety logic; overlays are separately toggled via `*_on` and alpha sliders.
|
| 78 |
+
- `base_view`, overlay toggles, and alphas: affect only display composition, not safety computation (except when needing masks to be computed for display).
|
| 79 |
+
- `model_id`: selects the Depth Anything 3 checkpoint.
|
| 80 |
+
|
| 81 |
+
## Error Handling
|
| 82 |
+
- Missing/unsupported input paths raise Gradio errors.
|
| 83 |
+
- Water/road masking failures log a warning and fall back to no mask.
|
| 84 |
+
- Coverage/boxFilter failures fall back to looser checks; if no finite std values are available within masks, the flattest region overall is used.
|
| 85 |
+
|
| 86 |
+
## Outputs
|
| 87 |
+
The pipeline returns a dictionary of PIL Images keyed by view name:
|
| 88 |
+
- RGB, Depth, Flatness map (std), Depth gradient, Gradient mask, Water mask, Road mask, Safety heatmap overlay, Safety score (grayscale), Landing spot overlay.
|
| 89 |
+
|
| 90 |
+
These are then composed into the preview according to UI choices.
|
gradio_app.py
ADDED
|
@@ -0,0 +1,1321 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Gradio demo: depth overlays on VISLOC imagery using Depth Anything 3.
|
| 4 |
+
|
| 5 |
+
Run:
|
| 6 |
+
python gradio_app.py
|
| 7 |
+
|
| 8 |
+
Then open the printed local URL. Requires: gradio, pillow, torch, transformers (for water mask).
|
| 9 |
+
"""
|
| 10 |
+
|
| 11 |
+
import cv2
|
| 12 |
+
import functools
|
| 13 |
+
import math
|
| 14 |
+
import os
|
| 15 |
+
from pathlib import Path
|
| 16 |
+
|
| 17 |
+
import gradio as gr
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
from PIL import Image, ImageDraw, ImageFilter
|
| 21 |
+
import matplotlib.cm as cm
|
| 22 |
+
|
| 23 |
+
# Prefer installed package; fall back to local src for dev runs.
|
| 24 |
+
try:
|
| 25 |
+
from depth_anything_3.api import DepthAnything3 # type: ignore
|
| 26 |
+
from depth_anything_3.utils.visualize import visualize_depth # type: ignore
|
| 27 |
+
except ModuleNotFoundError:
|
| 28 |
+
import sys
|
| 29 |
+
|
| 30 |
+
ROOT = Path(__file__).resolve().parent
|
| 31 |
+
sys.path.append(str(ROOT / "src"))
|
| 32 |
+
from depth_anything_3.api import DepthAnything3 # noqa: E402
|
| 33 |
+
from depth_anything_3.utils.visualize import visualize_depth # noqa: E402
|
| 34 |
+
|
| 35 |
+
VISLOC_DIR = Path("data/Image/VISLOC")
|
| 36 |
+
HAGDAVS_DIR = Path("data/Image/HAGDAVS")
|
| 37 |
+
VIDEO_DIR = Path("data/Video")
|
| 38 |
+
IMAGE_EXTS = (".jpg", ".jpeg", ".png", ".JPG", ".JPEG", ".PNG")
|
| 39 |
+
VIDEO_EXTS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
|
| 40 |
+
DEFAULT_ALTITUDE_M = 450.0
|
| 41 |
+
ASSUMED_FOV_DEG = 90.0
|
| 42 |
+
WATER_MODEL_ID = "facebook/mask2former-swin-large-ade-semantic"
|
| 43 |
+
ROAD_MODEL_ID = "facebook/mask2former-swin-large-ade-semantic"
|
| 44 |
+
|
| 45 |
+
def crop_nonblack(img: Image.Image, frac: float = 0.05) -> Image.Image:
|
| 46 |
+
"""Naively crop a fixed fraction off each border (to drop black padding)."""
|
| 47 |
+
w, h = img.size
|
| 48 |
+
dx = int(round(w * frac))
|
| 49 |
+
dy = int(round(h * frac))
|
| 50 |
+
return img.crop((dx, dy, w - dx, h - dy))
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@functools.lru_cache(maxsize=1)
|
| 54 |
+
def get_water_segmenter(model_id: str):
|
| 55 |
+
"""Load Mask2Former for water masking (kept on CPU to avoid OOM)."""
|
| 56 |
+
try:
|
| 57 |
+
from transformers import AutoImageProcessor, Mask2FormerForUniversalSegmentation
|
| 58 |
+
except ImportError as e:
|
| 59 |
+
raise ImportError("transformers is required for water masking; install with `pip install transformers`") from e
|
| 60 |
+
device = torch.device("cpu")
|
| 61 |
+
try:
|
| 62 |
+
processor = AutoImageProcessor.from_pretrained(model_id, use_fast=True)
|
| 63 |
+
except TypeError:
|
| 64 |
+
processor = AutoImageProcessor.from_pretrained(model_id)
|
| 65 |
+
model = Mask2FormerForUniversalSegmentation.from_pretrained(model_id).to(device)
|
| 66 |
+
model.eval()
|
| 67 |
+
return processor, model, device
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def compute_water_mask(img: Image.Image, model_id: str, max_side: int = 640) -> np.ndarray | None:
|
| 71 |
+
"""Return boolean mask for water-like classes using Mask2Former ADE weights."""
|
| 72 |
+
processor, model, device = get_water_segmenter(model_id)
|
| 73 |
+
try:
|
| 74 |
+
img_proc = img
|
| 75 |
+
if max(img.size) > max_side:
|
| 76 |
+
scale = max_side / max(img.size)
|
| 77 |
+
new_size = (int(round(img.size[0] * scale)), int(round(img.size[1] * scale)))
|
| 78 |
+
img_proc = img.resize(new_size, resample=Image.BILINEAR)
|
| 79 |
+
try:
|
| 80 |
+
inputs = processor(images=img_proc, return_tensors="pt", use_fast=True).to(device)
|
| 81 |
+
except TypeError:
|
| 82 |
+
inputs = processor(images=img_proc, return_tensors="pt").to(device)
|
| 83 |
+
with torch.inference_mode():
|
| 84 |
+
outputs = model(**inputs)
|
| 85 |
+
seg = processor.post_process_semantic_segmentation(outputs, target_sizes=[img_proc.size[::-1]])[0]
|
| 86 |
+
if torch.is_tensor(seg):
|
| 87 |
+
seg = seg.cpu()
|
| 88 |
+
labels = model.config.id2label
|
| 89 |
+
keywords = ["water", "sea", "lake", "river", "ocean", "pond"]
|
| 90 |
+
water_ids = {i for i, name in labels.items() if any(k in name.lower() for k in keywords)}
|
| 91 |
+
seg_np = np.array(seg)
|
| 92 |
+
mask_small = np.isin(seg_np, list(water_ids)).astype(np.uint8) * 255
|
| 93 |
+
mask_img = Image.fromarray(mask_small).resize(img.size, resample=Image.NEAREST)
|
| 94 |
+
return np.array(mask_img) > 0
|
| 95 |
+
except RuntimeError as e:
|
| 96 |
+
print(f"[WARN] Water masking failed (fallback to no water mask): {e}")
|
| 97 |
+
return None
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def compute_road_mask(img: Image.Image, model_id: str, max_side: int = 640) -> np.ndarray | None:
|
| 101 |
+
"""Return boolean mask for road/highway classes using Mask2Former ADE weights."""
|
| 102 |
+
processor, model, device = get_water_segmenter(model_id)
|
| 103 |
+
try:
|
| 104 |
+
img_proc = img
|
| 105 |
+
if max(img.size) > max_side:
|
| 106 |
+
scale = max_side / max(img.size)
|
| 107 |
+
new_size = (int(round(img.size[0] * scale)), int(round(img.size[1] * scale)))
|
| 108 |
+
img_proc = img.resize(new_size, resample=Image.BILINEAR)
|
| 109 |
+
try:
|
| 110 |
+
inputs = processor(images=img_proc, return_tensors="pt", use_fast=True).to(device)
|
| 111 |
+
except TypeError:
|
| 112 |
+
inputs = processor(images=img_proc, return_tensors="pt").to(device)
|
| 113 |
+
with torch.inference_mode():
|
| 114 |
+
outputs = model(**inputs)
|
| 115 |
+
seg = processor.post_process_semantic_segmentation(outputs, target_sizes=[img_proc.size[::-1]])[0]
|
| 116 |
+
if torch.is_tensor(seg):
|
| 117 |
+
seg = seg.cpu()
|
| 118 |
+
labels = model.config.id2label
|
| 119 |
+
keywords = ["highway", "road", "street", "runway"]
|
| 120 |
+
blocklist = ["field", "park", "grass", "lawn", "garden", "court", "yard", "green"]
|
| 121 |
+
road_ids = {
|
| 122 |
+
i
|
| 123 |
+
for i, name in labels.items()
|
| 124 |
+
if any(k in name.lower() for k in keywords) and not any(b in name.lower() for b in blocklist)
|
| 125 |
+
}
|
| 126 |
+
seg_np = np.array(seg)
|
| 127 |
+
mask_small = np.isin(seg_np, list(road_ids)).astype(np.uint8) * 255
|
| 128 |
+
mask_img = Image.fromarray(mask_small).resize(img.size, resample=Image.NEAREST)
|
| 129 |
+
return np.array(mask_img) > 0
|
| 130 |
+
except RuntimeError as e:
|
| 131 |
+
print(f"[WARN] Road masking failed (fallback to no road mask): {e}")
|
| 132 |
+
return None
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def compute_roof_mask_depth(depth: np.ndarray, aggressiveness: float = 1.3, morph_kernel: int = 5) -> np.ndarray:
|
| 136 |
+
"""Depth-based roof/structure mask: flag pixels significantly closer than the median (raised surfaces)."""
|
| 137 |
+
d = depth.astype(np.float32)
|
| 138 |
+
med = np.median(d)
|
| 139 |
+
mad = np.median(np.abs(d - med)) + 1e-6
|
| 140 |
+
threshold = med - aggressiveness * mad
|
| 141 |
+
mask = d < threshold
|
| 142 |
+
mask = mask.astype(np.uint8)
|
| 143 |
+
k = max(1, int(morph_kernel))
|
| 144 |
+
if k % 2 == 0:
|
| 145 |
+
k += 1
|
| 146 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (k, k))
|
| 147 |
+
try:
|
| 148 |
+
mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
|
| 149 |
+
mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
|
| 150 |
+
except Exception:
|
| 151 |
+
pass
|
| 152 |
+
return mask > 0
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def remove_global_plane(depth: np.ndarray) -> np.ndarray:
|
| 156 |
+
"""Remove best-fit global plane from depth to avoid penalizing large flat areas viewed at an angle."""
|
| 157 |
+
if depth.ndim != 2:
|
| 158 |
+
return depth
|
| 159 |
+
h, w = depth.shape
|
| 160 |
+
yy, xx = np.mgrid[0:h, 0:w].astype(np.float32)
|
| 161 |
+
A = np.stack((xx, yy, np.ones_like(xx)), axis=-1).reshape(-1, 3)
|
| 162 |
+
b = depth.astype(np.float32).reshape(-1, 1)
|
| 163 |
+
try:
|
| 164 |
+
coef, _, _, _ = np.linalg.lstsq(A, b, rcond=None)
|
| 165 |
+
plane = (A @ coef).reshape(h, w)
|
| 166 |
+
return depth - plane
|
| 167 |
+
except np.linalg.LinAlgError:
|
| 168 |
+
return depth
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def pick_flat_patch(
|
| 172 |
+
depth: np.ndarray,
|
| 173 |
+
patch: int = 96,
|
| 174 |
+
std_thresh: float = 0.03,
|
| 175 |
+
grad_thresh: float = 0.35,
|
| 176 |
+
water_mask: np.ndarray | None = None,
|
| 177 |
+
):
|
| 178 |
+
"""Find a low-variance depth window as a proxy for flat landing area."""
|
| 179 |
+
depth = depth.astype(np.float32)
|
| 180 |
+
if depth.ndim != 2:
|
| 181 |
+
raise ValueError("Depth map must be 2D (H, W)")
|
| 182 |
+
|
| 183 |
+
patch = max(3, min(patch, min(depth.shape)))
|
| 184 |
+
if patch % 2 == 0:
|
| 185 |
+
patch += 1 # keeps pooling output same size
|
| 186 |
+
depth_norm = (depth - depth.min()) / (depth.ptp() + 1e-6)
|
| 187 |
+
|
| 188 |
+
# Efficient box std via torch avg pooling
|
| 189 |
+
import torch.nn.functional as F
|
| 190 |
+
|
| 191 |
+
def box_mean(arr, k):
|
| 192 |
+
pad = k // 2
|
| 193 |
+
t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
|
| 194 |
+
# Reflective padding avoids dark/bright rims in the std map
|
| 195 |
+
t = F.pad(t, (pad, pad, pad, pad), mode="reflect")
|
| 196 |
+
mean = F.avg_pool2d(t, kernel_size=k, stride=1, padding=0, count_include_pad=False)
|
| 197 |
+
return mean.squeeze(0).squeeze(0).numpy()
|
| 198 |
+
|
| 199 |
+
mean = box_mean(depth_norm, patch)
|
| 200 |
+
mean_sq = box_mean(depth_norm * depth_norm, patch)
|
| 201 |
+
var = np.maximum(mean_sq - mean * mean, 0.0)
|
| 202 |
+
std_map = np.sqrt(var)
|
| 203 |
+
|
| 204 |
+
# Gradient mask to down-weight slopes/edges
|
| 205 |
+
dy, dx = np.gradient(depth_norm)
|
| 206 |
+
grad = np.sqrt(dx * dx + dy * dy)
|
| 207 |
+
grad_ref = np.percentile(grad, 95) + 1e-6
|
| 208 |
+
grad_norm = np.clip(grad / grad_ref, 0.0, 1.0)
|
| 209 |
+
grad_mask = grad_norm < grad_thresh
|
| 210 |
+
|
| 211 |
+
landing_mask = grad_mask
|
| 212 |
+
if water_mask is not None and water_mask.shape == grad_mask.shape:
|
| 213 |
+
landing_mask = landing_mask & (~water_mask)
|
| 214 |
+
|
| 215 |
+
masked_std = np.where(landing_mask, std_map, np.inf)
|
| 216 |
+
if not np.isfinite(masked_std).any():
|
| 217 |
+
masked_std = std_map # fallback: just take the flattest spot
|
| 218 |
+
y, x = np.unravel_index(np.argmin(masked_std), masked_std.shape)
|
| 219 |
+
half = patch // 2
|
| 220 |
+
y0, y1 = max(y - half, 0), min(y + half, depth.shape[0] - 1)
|
| 221 |
+
x0, x1 = max(x - half, 0), min(x + half, depth.shape[1] - 1)
|
| 222 |
+
return (x0, y0, x1, y1), std_map, grad_norm, grad_mask, landing_mask
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
def make_safety_heatmap(
|
| 226 |
+
rgb: Image.Image,
|
| 227 |
+
safe_mask: np.ndarray,
|
| 228 |
+
):
|
| 229 |
+
"""Produce a safety heatmap overlay on RGB from a provided safe mask."""
|
| 230 |
+
score = np.clip(safe_mask.astype(np.float32), 0.0, 1.0)
|
| 231 |
+
|
| 232 |
+
# Color: red (unsafe) -> green (safe). Gamma the green channel and cap its max
|
| 233 |
+
# so bright green does not overpower red when blended on the base image.
|
| 234 |
+
green = np.power(score, 1.2) * 200.0
|
| 235 |
+
red = np.power(1.0 - score, 0.9) * 255.0
|
| 236 |
+
heat = np.zeros((*score.shape, 3), dtype=np.uint8)
|
| 237 |
+
heat[..., 0] = red
|
| 238 |
+
heat[..., 1] = green
|
| 239 |
+
heat_img = Image.fromarray(heat).resize(rgb.size, resample=Image.NEAREST)
|
| 240 |
+
score_gray = Image.fromarray((score * 255).astype(np.uint8)).resize(rgb.size, resample=Image.NEAREST)
|
| 241 |
+
return heat_img, score_gray
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
@functools.lru_cache(maxsize=1)
|
| 245 |
+
def get_model(model_id: str = "depth-anything/DA3METRIC-LARGE"):
|
| 246 |
+
"""Load model once and cache."""
|
| 247 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 248 |
+
model = DepthAnything3.from_pretrained(model_id).to(device)
|
| 249 |
+
model.eval()
|
| 250 |
+
return model, device
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
@functools.lru_cache(maxsize=1)
|
| 254 |
+
def list_visloc_images() -> list[Path]:
|
| 255 |
+
"""Return sorted VISLOC image paths from data/Image/VISLOC."""
|
| 256 |
+
if not VISLOC_DIR.exists():
|
| 257 |
+
return []
|
| 258 |
+
files = [p for p in VISLOC_DIR.iterdir() if p.suffix in IMAGE_EXTS]
|
| 259 |
+
return sorted(files)
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
@functools.lru_cache(maxsize=1)
|
| 263 |
+
def list_hagdavs_images() -> list[Path]:
|
| 264 |
+
"""Return sorted HAGDAVS image paths from data/Image/HAGDAVS."""
|
| 265 |
+
if not HAGDAVS_DIR.exists():
|
| 266 |
+
return []
|
| 267 |
+
files = [p for p in HAGDAVS_DIR.iterdir() if p.suffix in IMAGE_EXTS]
|
| 268 |
+
return sorted(files)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
@functools.lru_cache(maxsize=1)
|
| 272 |
+
def list_videos() -> list[Path]:
|
| 273 |
+
if not VIDEO_DIR.exists():
|
| 274 |
+
return []
|
| 275 |
+
files = [p for p in VIDEO_DIR.iterdir() if p.suffix.lower() in VIDEO_EXTS]
|
| 276 |
+
return sorted(files)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
@functools.lru_cache(maxsize=1)
|
| 280 |
+
def list_all_data_inputs() -> list[str]:
|
| 281 |
+
"""Collect VISLOC image files for selection."""
|
| 282 |
+
return [str(p) for p in list_visloc_images()]
|
| 283 |
+
|
| 284 |
+
|
| 285 |
+
# Simple cache for water/road masks keyed by (model_id, path)
|
| 286 |
+
WATER_MASK_CACHE: dict[tuple[str, str], np.ndarray] = {}
|
| 287 |
+
ROAD_MASK_CACHE: dict[tuple[str, str], np.ndarray] = {}
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def run_on_image(
|
| 291 |
+
image: Image.Image,
|
| 292 |
+
footprint_m: float,
|
| 293 |
+
std_thresh: float,
|
| 294 |
+
grad_thresh: float,
|
| 295 |
+
use_water_mask: bool,
|
| 296 |
+
use_road_mask: bool,
|
| 297 |
+
use_roof_mask: bool,
|
| 298 |
+
altitude_m: float,
|
| 299 |
+
fov_deg: float,
|
| 300 |
+
flatness_detail: float,
|
| 301 |
+
clearance_factor: float,
|
| 302 |
+
process_res_cap: int,
|
| 303 |
+
roof_aggressiveness: float,
|
| 304 |
+
roof_morph_frac: float,
|
| 305 |
+
segmentation_max_side: int,
|
| 306 |
+
depth_smoothing_base: float,
|
| 307 |
+
coverage_strictness: float,
|
| 308 |
+
min_component_multiplier: float,
|
| 309 |
+
model_id: str,
|
| 310 |
+
source_path: str | None = None,
|
| 311 |
+
) -> dict:
|
| 312 |
+
rgb_np = np.array(image)
|
| 313 |
+
|
| 314 |
+
model, device = get_model(model_id)
|
| 315 |
+
# Fixed upper-bound resolution (cap) while avoiding upscaling small images.
|
| 316 |
+
process_res = min(max(image.size), int(process_res_cap))
|
| 317 |
+
with torch.inference_mode():
|
| 318 |
+
pred = model.inference(
|
| 319 |
+
image=[rgb_np],
|
| 320 |
+
process_res=process_res,
|
| 321 |
+
process_res_method="upper_bound_resize",
|
| 322 |
+
export_dir=None,
|
| 323 |
+
)
|
| 324 |
+
depth_raw = np.array(pred.depth[0])
|
| 325 |
+
depth = remove_global_plane(depth_raw)
|
| 326 |
+
# Smooth depth for resolution-invariant flatness/gradient (higher res -> slightly more smoothing)
|
| 327 |
+
res_scale = max(0.5, min(2.5, process_res / 1024))
|
| 328 |
+
sigma = max(0.0, depth_smoothing_base) * res_scale
|
| 329 |
+
k = max(3, int(round(sigma * 3)) * 2 + 1)
|
| 330 |
+
try:
|
| 331 |
+
depth = cv2.GaussianBlur(depth, (k, k), sigmaX=sigma, sigmaY=sigma)
|
| 332 |
+
except Exception:
|
| 333 |
+
pass
|
| 334 |
+
# Convert landing footprint (meters) to pixels at current processed resolution
|
| 335 |
+
fov = max(10.0, min(170.0, float(fov_deg)))
|
| 336 |
+
altitude = max(1.0, float(altitude_m))
|
| 337 |
+
fx = (depth.shape[1] / 2.0) / math.tan(math.radians(fov) / 2.0)
|
| 338 |
+
patch_px = footprint_m * fx / altitude
|
| 339 |
+
patch_px = max(3, min(int(round(patch_px)), min(depth.shape) - 1))
|
| 340 |
+
if patch_px % 2 == 0:
|
| 341 |
+
patch_px += 1 # keep pooling symmetric
|
| 342 |
+
|
| 343 |
+
# For visualization, compute a flatness map with a smaller, sharper window (decoupled from footprint)
|
| 344 |
+
depth_norm = (depth - depth.min()) / (depth.ptp() + 1e-6)
|
| 345 |
+
vis_patch = max(
|
| 346 |
+
5,
|
| 347 |
+
min(
|
| 348 |
+
int(max(1.0, flatness_detail) * patch_px),
|
| 349 |
+
min(depth.shape) // 10,
|
| 350 |
+
min(depth.shape) - 1,
|
| 351 |
+
),
|
| 352 |
+
)
|
| 353 |
+
if vis_patch % 2 == 0:
|
| 354 |
+
vis_patch += 1
|
| 355 |
+
import torch.nn.functional as F
|
| 356 |
+
|
| 357 |
+
def box_mean_np(arr: np.ndarray, k: int):
|
| 358 |
+
pad = k // 2
|
| 359 |
+
t = torch.from_numpy(arr).unsqueeze(0).unsqueeze(0)
|
| 360 |
+
t = F.pad(t, (pad, pad, pad, pad), mode="reflect")
|
| 361 |
+
mean = F.avg_pool2d(t, kernel_size=k, stride=1, padding=0, count_include_pad=False)
|
| 362 |
+
return mean.squeeze(0).squeeze(0).numpy()
|
| 363 |
+
|
| 364 |
+
std_map_vis = np.sqrt(
|
| 365 |
+
np.maximum(box_mean_np(depth_norm * depth_norm, vis_patch) - box_mean_np(depth_norm, vis_patch) ** 2, 0.0)
|
| 366 |
+
)
|
| 367 |
+
|
| 368 |
+
# Optional water mask (resized to depth resolution)
|
| 369 |
+
water_mask_resized = None
|
| 370 |
+
water_mask_img = None
|
| 371 |
+
if use_water_mask:
|
| 372 |
+
cache_key = (WATER_MODEL_ID, source_path or "", int(segmentation_max_side))
|
| 373 |
+
if cache_key in WATER_MASK_CACHE:
|
| 374 |
+
water_mask_img = WATER_MASK_CACHE[cache_key]
|
| 375 |
+
else:
|
| 376 |
+
water_mask_img = compute_water_mask(image, WATER_MODEL_ID, max_side=segmentation_max_side)
|
| 377 |
+
if source_path is not None and water_mask_img is not None:
|
| 378 |
+
WATER_MASK_CACHE[cache_key] = water_mask_img
|
| 379 |
+
if water_mask_img is not None:
|
| 380 |
+
water_mask_resized = (
|
| 381 |
+
np.array(water_mask_img)
|
| 382 |
+
if isinstance(water_mask_img, np.ndarray)
|
| 383 |
+
else np.array(water_mask_img)
|
| 384 |
+
)
|
| 385 |
+
water_mask_resized = (
|
| 386 |
+
Image.fromarray(water_mask_resized.astype(np.uint8) * 255)
|
| 387 |
+
.resize((depth.shape[1], depth.shape[0]), resample=Image.NEAREST)
|
| 388 |
+
)
|
| 389 |
+
water_mask_resized = np.array(water_mask_resized) > 0
|
| 390 |
+
|
| 391 |
+
road_mask_resized = None
|
| 392 |
+
road_mask_img = None
|
| 393 |
+
if use_road_mask:
|
| 394 |
+
cache_key_r = (ROAD_MODEL_ID, source_path or "", int(segmentation_max_side))
|
| 395 |
+
if cache_key_r in ROAD_MASK_CACHE:
|
| 396 |
+
road_mask_img = ROAD_MASK_CACHE[cache_key_r]
|
| 397 |
+
else:
|
| 398 |
+
road_mask_img = compute_road_mask(image, ROAD_MODEL_ID, max_side=segmentation_max_side)
|
| 399 |
+
if source_path is not None and road_mask_img is not None:
|
| 400 |
+
ROAD_MASK_CACHE[cache_key_r] = road_mask_img
|
| 401 |
+
if road_mask_img is not None:
|
| 402 |
+
road_mask_resized = (
|
| 403 |
+
np.array(road_mask_img)
|
| 404 |
+
if isinstance(road_mask_img, np.ndarray)
|
| 405 |
+
else np.array(road_mask_img)
|
| 406 |
+
)
|
| 407 |
+
road_mask_resized = (
|
| 408 |
+
Image.fromarray(road_mask_resized.astype(np.uint8) * 255)
|
| 409 |
+
.resize((depth.shape[1], depth.shape[0]), resample=Image.NEAREST)
|
| 410 |
+
)
|
| 411 |
+
road_mask_resized = np.array(road_mask_resized) > 0
|
| 412 |
+
roof_mask_resized = None
|
| 413 |
+
if use_roof_mask:
|
| 414 |
+
# Depth-based elevation mask: closer-than-median surfaces are treated as roofs/structures.
|
| 415 |
+
aggressiveness = max(0.5, min(3.0, roof_aggressiveness))
|
| 416 |
+
morph_k = max(3, int(round(patch_px * roof_morph_frac)))
|
| 417 |
+
roof_mask_resized = compute_roof_mask_depth(depth, aggressiveness=aggressiveness, morph_kernel=morph_k)
|
| 418 |
+
|
| 419 |
+
box, std_map, grad_norm, grad_mask, landing_mask = pick_flat_patch(
|
| 420 |
+
depth,
|
| 421 |
+
patch=patch_px,
|
| 422 |
+
std_thresh=std_thresh,
|
| 423 |
+
grad_thresh=grad_thresh,
|
| 424 |
+
water_mask=water_mask_resized,
|
| 425 |
+
)
|
| 426 |
+
if road_mask_resized is not None:
|
| 427 |
+
landing_mask = landing_mask & (~road_mask_resized)
|
| 428 |
+
if roof_mask_resized is not None:
|
| 429 |
+
landing_mask = landing_mask & (~roof_mask_resized)
|
| 430 |
+
safe_mask = (std_map < std_thresh) & (grad_norm < grad_thresh) & landing_mask
|
| 431 |
+
# Clearance: dilate hazards to enforce buffer around unsafe regions
|
| 432 |
+
try:
|
| 433 |
+
clearance_px = max(1, int(round(clearance_factor * patch_px)))
|
| 434 |
+
if clearance_px % 2 == 0:
|
| 435 |
+
clearance_px += 1
|
| 436 |
+
kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (clearance_px, clearance_px))
|
| 437 |
+
hazard = (~safe_mask).astype(np.uint8)
|
| 438 |
+
buffered = cv2.dilate(hazard, kernel, iterations=1).astype(bool)
|
| 439 |
+
safe_mask = safe_mask & (~buffered)
|
| 440 |
+
except Exception:
|
| 441 |
+
pass
|
| 442 |
+
# Strict footprint coverage: a center is safe only if the full footprint is safe
|
| 443 |
+
try:
|
| 444 |
+
coverage = cv2.boxFilter(
|
| 445 |
+
safe_mask.astype(np.float32),
|
| 446 |
+
ddepth=-1,
|
| 447 |
+
ksize=(patch_px, patch_px),
|
| 448 |
+
normalize=True,
|
| 449 |
+
anchor=(patch_px // 2, patch_px // 2),
|
| 450 |
+
)
|
| 451 |
+
safe_mask = coverage >= max(0.0, min(1.0, coverage_strictness))
|
| 452 |
+
except Exception:
|
| 453 |
+
pass
|
| 454 |
+
|
| 455 |
+
# Drop tiny components: require at least one footprint area
|
| 456 |
+
area_thresh = max(1, int(patch_px * patch_px * max(0.1, min_component_multiplier)))
|
| 457 |
+
num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(safe_mask.astype(np.uint8), connectivity=8)
|
| 458 |
+
if num_labels > 1:
|
| 459 |
+
keep = np.zeros_like(labels, dtype=bool)
|
| 460 |
+
for i in range(1, num_labels):
|
| 461 |
+
if stats[i, cv2.CC_STAT_AREA] >= area_thresh:
|
| 462 |
+
keep |= labels == i
|
| 463 |
+
safe_mask = keep
|
| 464 |
+
|
| 465 |
+
# Recommended landing spot overlay (scaled to input image size)
|
| 466 |
+
# Prefer centers where the full footprint is safe; fall back to best flat spot
|
| 467 |
+
safe_fit = safe_mask.astype(np.float32)
|
| 468 |
+
try:
|
| 469 |
+
coverage = cv2.boxFilter(
|
| 470 |
+
safe_fit.astype(np.float32),
|
| 471 |
+
ddepth=-1,
|
| 472 |
+
ksize=(patch_px, patch_px),
|
| 473 |
+
normalize=True,
|
| 474 |
+
anchor=(patch_px // 2, patch_px // 2),
|
| 475 |
+
)
|
| 476 |
+
valid_centers = coverage >= 1.0
|
| 477 |
+
except Exception:
|
| 478 |
+
valid_centers = safe_fit > 0.5
|
| 479 |
+
|
| 480 |
+
if valid_centers.any():
|
| 481 |
+
cc_mask = valid_centers.astype(np.uint8)
|
| 482 |
+
num_c, labels_c, stats_c, _ = cv2.connectedComponentsWithStats(cc_mask, connectivity=8)
|
| 483 |
+
target_mask = valid_centers
|
| 484 |
+
if num_c > 1:
|
| 485 |
+
# Pick largest safe component by area (skip background)
|
| 486 |
+
areas = stats_c[1:, cv2.CC_STAT_AREA]
|
| 487 |
+
largest_idx = 1 + int(np.argmax(areas))
|
| 488 |
+
target_mask = labels_c == largest_idx
|
| 489 |
+
cand = np.where(target_mask)
|
| 490 |
+
std_cand = std_map[cand]
|
| 491 |
+
idx = np.argmin(std_cand)
|
| 492 |
+
cy, cx = cand[0][idx], cand[1][idx]
|
| 493 |
+
else:
|
| 494 |
+
y0, x0, y1, x1 = box[1], box[0], box[3], box[2]
|
| 495 |
+
cy, cx = (y0 + y1) // 2, (x0 + x1) // 2
|
| 496 |
+
|
| 497 |
+
half = patch_px // 2
|
| 498 |
+
x0 = max(int(cx - half), 0)
|
| 499 |
+
x1 = min(int(cx + half), depth.shape[1] - 1)
|
| 500 |
+
y0 = max(int(cy - half), 0)
|
| 501 |
+
y1 = min(int(cy + half), depth.shape[0] - 1)
|
| 502 |
+
|
| 503 |
+
scale_x = image.width / depth.shape[1]
|
| 504 |
+
scale_y = image.height / depth.shape[0]
|
| 505 |
+
# Draw a box whose side length matches the footprint in input-image pixels
|
| 506 |
+
side_img = max(3, int(round(patch_px * scale_x)))
|
| 507 |
+
cx_img = int(round(cx * scale_x))
|
| 508 |
+
cy_img = int(round(cy * scale_y))
|
| 509 |
+
half_img = side_img // 2
|
| 510 |
+
bx0 = max(cx_img - half_img, 0)
|
| 511 |
+
bx1 = min(cx_img + half_img, image.width - 1)
|
| 512 |
+
by0 = max(cy_img - half_img, 0)
|
| 513 |
+
by1 = min(cy_img + half_img, image.height - 1)
|
| 514 |
+
spot_overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
|
| 515 |
+
draw = ImageDraw.Draw(spot_overlay)
|
| 516 |
+
draw.rectangle((bx0, by0, bx1, by1), outline=(0, 255, 0, 255), width=4)
|
| 517 |
+
cx, cy = (bx0 + bx1) // 2, (by0 + by1) // 2
|
| 518 |
+
draw.ellipse((cx - 5, cy - 5, cx + 5, cy + 5), fill=(0, 255, 0, 255))
|
| 519 |
+
|
| 520 |
+
depth_vis = Image.fromarray(visualize_depth(depth_raw, cmap="Spectral")).resize(
|
| 521 |
+
image.size, resample=Image.BILINEAR
|
| 522 |
+
)
|
| 523 |
+
flatness_img = Image.fromarray((std_map_vis / (std_map_vis.max() + 1e-6) * 255).astype(np.uint8)).resize(
|
| 524 |
+
image.size, resample=Image.NEAREST
|
| 525 |
+
)
|
| 526 |
+
grad_img = Image.fromarray((grad_norm * 255).astype(np.uint8)).resize(
|
| 527 |
+
image.size, resample=Image.BILINEAR
|
| 528 |
+
)
|
| 529 |
+
grad_mask_img = Image.fromarray(((grad_norm < grad_thresh).astype(np.uint8) * 255)).resize(
|
| 530 |
+
image.size, resample=Image.NEAREST
|
| 531 |
+
)
|
| 532 |
+
water_mask_view = None
|
| 533 |
+
if use_water_mask and water_mask_img is not None:
|
| 534 |
+
water_mask_view = Image.fromarray((np.array(water_mask_img).astype(np.uint8) * 255))
|
| 535 |
+
water_mask_view = water_mask_view.resize(image.size, resample=Image.NEAREST)
|
| 536 |
+
road_mask_view = None
|
| 537 |
+
if use_road_mask and road_mask_img is not None:
|
| 538 |
+
road_mask_view = Image.fromarray((np.array(road_mask_img).astype(np.uint8) * 255))
|
| 539 |
+
road_mask_view = road_mask_view.resize(image.size, resample=Image.NEAREST)
|
| 540 |
+
roof_mask_view = None
|
| 541 |
+
if use_roof_mask and roof_mask_resized is not None:
|
| 542 |
+
roof_mask_view = Image.fromarray((roof_mask_resized.astype(np.uint8) * 255))
|
| 543 |
+
roof_mask_view = roof_mask_view.resize(image.size, resample=Image.NEAREST)
|
| 544 |
+
|
| 545 |
+
heat_overlay, heat_gray = make_safety_heatmap(image, safe_mask)
|
| 546 |
+
|
| 547 |
+
images = {
|
| 548 |
+
"RGB": image,
|
| 549 |
+
"Depth": depth_vis,
|
| 550 |
+
"Flatness map (std)": flatness_img,
|
| 551 |
+
"Depth gradient": grad_img,
|
| 552 |
+
"Gradient mask": grad_mask_img,
|
| 553 |
+
"Water mask": water_mask_view if water_mask_view is not None else Image.new("L", image.size, 0),
|
| 554 |
+
"Road mask": road_mask_view if road_mask_view is not None else Image.new("L", image.size, 0),
|
| 555 |
+
"Roof mask": roof_mask_view if roof_mask_view is not None else Image.new("L", image.size, 0),
|
| 556 |
+
"Safety heatmap overlay": heat_overlay,
|
| 557 |
+
"Safety score": heat_gray,
|
| 558 |
+
"Landing spot overlay": spot_overlay,
|
| 559 |
+
}
|
| 560 |
+
return images
|
| 561 |
+
|
| 562 |
+
|
| 563 |
+
def process_image(
|
| 564 |
+
input_path: str,
|
| 565 |
+
footprint_m: float,
|
| 566 |
+
std_thresh: float,
|
| 567 |
+
grad_thresh: float,
|
| 568 |
+
use_water_mask: bool,
|
| 569 |
+
use_road_mask: bool,
|
| 570 |
+
use_roof_mask: bool,
|
| 571 |
+
altitude_m: float,
|
| 572 |
+
fov_deg: float,
|
| 573 |
+
flatness_detail: float,
|
| 574 |
+
clearance_factor: float,
|
| 575 |
+
process_res_cap: int,
|
| 576 |
+
roof_aggressiveness: float,
|
| 577 |
+
roof_morph_frac: float,
|
| 578 |
+
segmentation_max_side: int,
|
| 579 |
+
depth_smoothing_base: float,
|
| 580 |
+
coverage_strictness: float,
|
| 581 |
+
min_component_multiplier: float,
|
| 582 |
+
model_id: str,
|
| 583 |
+
source_path: str | None = None,
|
| 584 |
+
) -> dict:
|
| 585 |
+
path = Path(input_path)
|
| 586 |
+
if not path.exists():
|
| 587 |
+
raise gr.Error(f"Input path not found: {path}")
|
| 588 |
+
if path.suffix.lower() not in IMAGE_EXTS:
|
| 589 |
+
raise gr.Error(f"Unsupported image type for path: {path}")
|
| 590 |
+
image = crop_nonblack(Image.open(path).convert("RGB"))
|
| 591 |
+
return run_on_image(
|
| 592 |
+
image=image,
|
| 593 |
+
footprint_m=footprint_m,
|
| 594 |
+
std_thresh=std_thresh,
|
| 595 |
+
grad_thresh=grad_thresh,
|
| 596 |
+
use_water_mask=use_water_mask,
|
| 597 |
+
use_road_mask=use_road_mask,
|
| 598 |
+
use_roof_mask=use_roof_mask,
|
| 599 |
+
altitude_m=altitude_m,
|
| 600 |
+
fov_deg=fov_deg,
|
| 601 |
+
flatness_detail=flatness_detail,
|
| 602 |
+
clearance_factor=clearance_factor,
|
| 603 |
+
process_res_cap=process_res_cap,
|
| 604 |
+
roof_aggressiveness=roof_aggressiveness,
|
| 605 |
+
roof_morph_frac=roof_morph_frac,
|
| 606 |
+
segmentation_max_side=segmentation_max_side,
|
| 607 |
+
depth_smoothing_base=depth_smoothing_base,
|
| 608 |
+
coverage_strictness=coverage_strictness,
|
| 609 |
+
min_component_multiplier=min_component_multiplier,
|
| 610 |
+
model_id=model_id,
|
| 611 |
+
source_path=str(path),
|
| 612 |
+
)
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
def compose_view(
|
| 616 |
+
images_dict: dict,
|
| 617 |
+
base_view: str,
|
| 618 |
+
heat_on: bool,
|
| 619 |
+
heat_alpha: float,
|
| 620 |
+
grad_on: bool,
|
| 621 |
+
grad_alpha: float,
|
| 622 |
+
flat_on: bool,
|
| 623 |
+
flat_alpha: float,
|
| 624 |
+
water_on: bool,
|
| 625 |
+
water_alpha: float,
|
| 626 |
+
water_enabled: bool,
|
| 627 |
+
spot_on: bool,
|
| 628 |
+
road_on: bool,
|
| 629 |
+
road_alpha: float,
|
| 630 |
+
road_enabled: bool,
|
| 631 |
+
roof_on: bool,
|
| 632 |
+
roof_alpha: float,
|
| 633 |
+
roof_enabled: bool,
|
| 634 |
+
) -> Image.Image:
|
| 635 |
+
"""Return a composited view with per-layer alpha controls."""
|
| 636 |
+
if not images_dict:
|
| 637 |
+
raise gr.Error("Run inference first, then select a view.")
|
| 638 |
+
if base_view not in images_dict:
|
| 639 |
+
raise gr.Error(f"Unknown view: {base_view}")
|
| 640 |
+
|
| 641 |
+
base = images_dict.get(base_view)
|
| 642 |
+
if base is None:
|
| 643 |
+
raise gr.Error(f"No image for view: {base_view}")
|
| 644 |
+
out = base.convert("RGBA")
|
| 645 |
+
|
| 646 |
+
if heat_on and "Safety heatmap overlay" in images_dict:
|
| 647 |
+
heat = images_dict["Safety heatmap overlay"]
|
| 648 |
+
if heat is not None:
|
| 649 |
+
heat_rgba = heat.convert("RGBA")
|
| 650 |
+
alpha = int(min(max(heat_alpha, 0.0), 1.0) * 255)
|
| 651 |
+
heat_rgba.putalpha(alpha)
|
| 652 |
+
out = Image.alpha_composite(out, heat_rgba)
|
| 653 |
+
|
| 654 |
+
if grad_on and "Depth gradient" in images_dict:
|
| 655 |
+
grad_img = images_dict["Depth gradient"]
|
| 656 |
+
if grad_img is not None:
|
| 657 |
+
grad_rgba = grad_img.convert("RGBA")
|
| 658 |
+
grad_rgba.putalpha(int(min(max(grad_alpha, 0.0), 1.0) * 255))
|
| 659 |
+
out = Image.alpha_composite(out, grad_rgba)
|
| 660 |
+
|
| 661 |
+
if flat_on and "Flatness map (std)" in images_dict:
|
| 662 |
+
flat_img = images_dict["Flatness map (std)"]
|
| 663 |
+
if flat_img is not None:
|
| 664 |
+
flat_rgba = flat_img.convert("RGBA")
|
| 665 |
+
flat_rgba.putalpha(int(min(max(flat_alpha, 0.0), 1.0) * 255))
|
| 666 |
+
out = Image.alpha_composite(out, flat_rgba)
|
| 667 |
+
|
| 668 |
+
if water_on and water_enabled and "Water mask" in images_dict:
|
| 669 |
+
wm = images_dict["Water mask"]
|
| 670 |
+
if wm is not None:
|
| 671 |
+
m = wm.convert("L")
|
| 672 |
+
overlay = Image.new("RGBA", wm.size, (255, 0, 0, 0))
|
| 673 |
+
alpha = int(min(max(water_alpha, 0.0), 1.0) * 255)
|
| 674 |
+
overlay.putalpha(Image.eval(m, lambda px: int(px * (alpha / 255.0))))
|
| 675 |
+
out = Image.alpha_composite(out, overlay)
|
| 676 |
+
|
| 677 |
+
if road_on and road_enabled and "Road mask" in images_dict:
|
| 678 |
+
rm = images_dict["Road mask"]
|
| 679 |
+
if rm is not None:
|
| 680 |
+
m = rm.convert("L")
|
| 681 |
+
overlay = Image.new("RGBA", rm.size, (255, 165, 0, 0)) # orange
|
| 682 |
+
alpha = int(min(max(road_alpha, 0.0), 1.0) * 255)
|
| 683 |
+
overlay.putalpha(Image.eval(m, lambda px: int(px * (alpha / 255.0))))
|
| 684 |
+
out = Image.alpha_composite(out, overlay)
|
| 685 |
+
|
| 686 |
+
if roof_on and roof_enabled and "Roof mask" in images_dict:
|
| 687 |
+
rf = images_dict["Roof mask"]
|
| 688 |
+
if rf is not None:
|
| 689 |
+
m = rf.convert("L")
|
| 690 |
+
overlay = Image.new("RGBA", rf.size, (255, 0, 255, 0)) # magenta tint for roofs
|
| 691 |
+
alpha = int(min(max(roof_alpha, 0.0), 1.0) * 255)
|
| 692 |
+
overlay.putalpha(Image.eval(m, lambda px: int(px * (alpha / 255.0))))
|
| 693 |
+
out = Image.alpha_composite(out, overlay)
|
| 694 |
+
|
| 695 |
+
if spot_on and "Landing spot overlay" in images_dict:
|
| 696 |
+
spot = images_dict["Landing spot overlay"]
|
| 697 |
+
if spot is not None:
|
| 698 |
+
out = Image.alpha_composite(out, spot.convert("RGBA"))
|
| 699 |
+
|
| 700 |
+
return out.convert("RGB")
|
| 701 |
+
|
| 702 |
+
|
| 703 |
+
def build_ui():
|
| 704 |
+
with gr.Blocks(title="Landing Site Safety Analyzer (VISLOC)") as demo:
|
| 705 |
+
gr.Markdown(
|
| 706 |
+
"## Landing Site Safety Analyzer\n"
|
| 707 |
+
"Run DepthAnything3 on VISLOC images under `data/Image/VISLOC` to evaluate landing zones: depth, safety heatmap, gradients, flatness, and water masks. Toggle layers, footprint, and opacity to assess safety."
|
| 708 |
+
)
|
| 709 |
+
with gr.Row():
|
| 710 |
+
with gr.Column(scale=1, min_width=320):
|
| 711 |
+
gr.Markdown("### Input")
|
| 712 |
+
all_choices = list_all_data_inputs()
|
| 713 |
+
input_path = gr.Dropdown(
|
| 714 |
+
label="Input file",
|
| 715 |
+
choices=all_choices,
|
| 716 |
+
value=all_choices[0] if all_choices else "",
|
| 717 |
+
info="Pick any VISLOC image under data/Image/VISLOC/.",
|
| 718 |
+
)
|
| 719 |
+
footprint_m = gr.Slider(
|
| 720 |
+
label="Landing footprint (meters)",
|
| 721 |
+
value=10,
|
| 722 |
+
minimum=1,
|
| 723 |
+
maximum=150,
|
| 724 |
+
step=1,
|
| 725 |
+
info="Side length (meters) of the clear area required for landing (assumes ~450m altitude, 90° FOV).",
|
| 726 |
+
)
|
| 727 |
+
std_thresh = gr.Slider(
|
| 728 |
+
label="Flatness threshold",
|
| 729 |
+
value=0.01,
|
| 730 |
+
minimum=0.001,
|
| 731 |
+
maximum=0.08,
|
| 732 |
+
step=0.001,
|
| 733 |
+
info="Lower values favor flatter regions when computing the heatmap.",
|
| 734 |
+
)
|
| 735 |
+
grad_thresh = gr.Slider(
|
| 736 |
+
label="Gradient threshold",
|
| 737 |
+
value=0.1,
|
| 738 |
+
minimum=0.02,
|
| 739 |
+
maximum=1.0,
|
| 740 |
+
step=0.01,
|
| 741 |
+
info="Lower values suppress sloped/edgy areas in the heatmap.",
|
| 742 |
+
)
|
| 743 |
+
flatness_detail = gr.Slider(
|
| 744 |
+
label="Flatness detail (relative)",
|
| 745 |
+
value=1.0,
|
| 746 |
+
minimum=0.5,
|
| 747 |
+
maximum=2.5,
|
| 748 |
+
step=0.1,
|
| 749 |
+
info="Scales the window for the flatness visualization; lower = finer detail.",
|
| 750 |
+
)
|
| 751 |
+
clearance_factor = gr.Slider(
|
| 752 |
+
label="Clearance factor",
|
| 753 |
+
value=0.5,
|
| 754 |
+
minimum=0.0,
|
| 755 |
+
maximum=2.0,
|
| 756 |
+
step=0.05,
|
| 757 |
+
info="How much to dilate unsafe regions relative to the footprint to enforce buffer distance.",
|
| 758 |
+
)
|
| 759 |
+
process_res_cap = gr.Slider(
|
| 760 |
+
label="Processing resolution cap",
|
| 761 |
+
value=1024,
|
| 762 |
+
minimum=512,
|
| 763 |
+
maximum=2048,
|
| 764 |
+
step=32,
|
| 765 |
+
info="Upper bound on the longest side fed to the depth model; avoids oversized, noisy inference.",
|
| 766 |
+
)
|
| 767 |
+
depth_smoothing_base = gr.Slider(
|
| 768 |
+
label="Depth smoothing base",
|
| 769 |
+
value=0.8,
|
| 770 |
+
minimum=0.0,
|
| 771 |
+
maximum=2.0,
|
| 772 |
+
step=0.05,
|
| 773 |
+
info="Base Gaussian sigma multiplier for depth smoothing (scaled by resolution).",
|
| 774 |
+
)
|
| 775 |
+
coverage_strictness = gr.Slider(
|
| 776 |
+
label="Coverage strictness",
|
| 777 |
+
value=0.999,
|
| 778 |
+
minimum=0.8,
|
| 779 |
+
maximum=1.0,
|
| 780 |
+
step=0.001,
|
| 781 |
+
info="Minimum fraction of a footprint that must be safe to count a center as safe.",
|
| 782 |
+
)
|
| 783 |
+
min_component_multiplier = gr.Slider(
|
| 784 |
+
label="Min safe area (x footprint)",
|
| 785 |
+
value=1.0,
|
| 786 |
+
minimum=0.1,
|
| 787 |
+
maximum=5.0,
|
| 788 |
+
step=0.1,
|
| 789 |
+
info="Minimum safe component area in multiples of footprint^2.",
|
| 790 |
+
)
|
| 791 |
+
segmentation_max_side = gr.Slider(
|
| 792 |
+
label="Segmentation max side",
|
| 793 |
+
value=640,
|
| 794 |
+
minimum=256,
|
| 795 |
+
maximum=1024,
|
| 796 |
+
step=32,
|
| 797 |
+
info="Resize longest image side to this for water/road segmentation.",
|
| 798 |
+
)
|
| 799 |
+
with gr.Accordion("Camera settings", open=False):
|
| 800 |
+
altitude_m = gr.Slider(
|
| 801 |
+
label="Camera altitude (m)",
|
| 802 |
+
value=450,
|
| 803 |
+
minimum=10,
|
| 804 |
+
maximum=1500,
|
| 805 |
+
step=5,
|
| 806 |
+
info="Altitude used to convert footprint meters to pixels.",
|
| 807 |
+
)
|
| 808 |
+
fov_deg = gr.Slider(
|
| 809 |
+
label="Camera FOV (deg)",
|
| 810 |
+
value=90,
|
| 811 |
+
minimum=30,
|
| 812 |
+
maximum=150,
|
| 813 |
+
step=1,
|
| 814 |
+
info="Horizontal field of view used for footprint sizing.",
|
| 815 |
+
)
|
| 816 |
+
model_id = gr.Dropdown(
|
| 817 |
+
label="Model",
|
| 818 |
+
value="depth-anything/DA3MONO-LARGE",
|
| 819 |
+
choices=[
|
| 820 |
+
"depth-anything/DA3MONO-LARGE",
|
| 821 |
+
"depth-anything/DA3METRIC-LARGE",
|
| 822 |
+
"depth-anything/DA3-BASE",
|
| 823 |
+
"depth-anything/DA3NESTED-GIANT-LARGE",
|
| 824 |
+
],
|
| 825 |
+
info="Which pretrained DepthAnything3 checkpoint to use.",
|
| 826 |
+
)
|
| 827 |
+
with gr.Accordion("Masking", open=True):
|
| 828 |
+
with gr.Row():
|
| 829 |
+
use_water_mask = gr.Checkbox(
|
| 830 |
+
label="Exclude water (segmentation)", value=True, info="Apply water segmentation to down-weight water regions."
|
| 831 |
+
)
|
| 832 |
+
use_road_mask = gr.Checkbox(
|
| 833 |
+
label="Exclude roads (segmentation)", value=True, info="Apply road segmentation to avoid roads/highways."
|
| 834 |
+
)
|
| 835 |
+
use_roof_mask = gr.Checkbox(
|
| 836 |
+
label="Exclude rooftops (depth)", value=True, info="Use depth (closer-than-median) to avoid rooftops/raised structures."
|
| 837 |
+
)
|
| 838 |
+
roof_aggressiveness = gr.Slider(
|
| 839 |
+
label="Rooftop aggressiveness (MAD multiplier)",
|
| 840 |
+
value=1.3,
|
| 841 |
+
minimum=0.5,
|
| 842 |
+
maximum=3.0,
|
| 843 |
+
step=0.05,
|
| 844 |
+
info="Higher = more aggressive exclusion of raised areas in the depth-based rooftop mask.",
|
| 845 |
+
)
|
| 846 |
+
roof_morph_frac = gr.Slider(
|
| 847 |
+
label="Rooftop morph kernel (fraction of footprint px)",
|
| 848 |
+
value=0.15,
|
| 849 |
+
minimum=0.05,
|
| 850 |
+
maximum=0.5,
|
| 851 |
+
step=0.01,
|
| 852 |
+
info="Controls smoothing/merging of rooftop mask relative to footprint size.",
|
| 853 |
+
)
|
| 854 |
+
with gr.Row():
|
| 855 |
+
run_btn = gr.Button("Run", variant="primary", scale=1)
|
| 856 |
+
stop_btn = gr.Button("Stop", variant="stop", scale=1)
|
| 857 |
+
images_state = gr.State({})
|
| 858 |
+
with gr.Column(scale=3):
|
| 859 |
+
gr.Markdown("### Preview")
|
| 860 |
+
main_view = gr.Image(
|
| 861 |
+
label="Preview",
|
| 862 |
+
height=800,
|
| 863 |
+
elem_id="main-preview",
|
| 864 |
+
show_fullscreen_button=False,
|
| 865 |
+
)
|
| 866 |
+
gr.HTML(
|
| 867 |
+
"""
|
| 868 |
+
<style>
|
| 869 |
+
#main-preview img,
|
| 870 |
+
#main-preview canvas { cursor: zoom-in; }
|
| 871 |
+
#main-preview-zoom-overlay {
|
| 872 |
+
position: fixed;
|
| 873 |
+
inset: 0;
|
| 874 |
+
z-index: 1000;
|
| 875 |
+
display: none;
|
| 876 |
+
align-items: center;
|
| 877 |
+
justify-content: center;
|
| 878 |
+
background: rgba(0, 0, 0, 0.85);
|
| 879 |
+
}
|
| 880 |
+
#main-preview-zoom-overlay img {
|
| 881 |
+
max-width: 95vw;
|
| 882 |
+
max-height: 95vh;
|
| 883 |
+
box-shadow: 0 0 24px rgba(0, 0, 0, 0.6);
|
| 884 |
+
}
|
| 885 |
+
</style>
|
| 886 |
+
<div id="main-preview-zoom-overlay"></div>
|
| 887 |
+
<script>
|
| 888 |
+
(() => {
|
| 889 |
+
const containerId = "main-preview";
|
| 890 |
+
const overlayId = "main-preview-zoom-overlay";
|
| 891 |
+
|
| 892 |
+
const ensureOverlay = () => {
|
| 893 |
+
let overlay = document.getElementById(overlayId);
|
| 894 |
+
if (!overlay) {
|
| 895 |
+
overlay = document.createElement("div");
|
| 896 |
+
overlay.id = overlayId;
|
| 897 |
+
document.body.appendChild(overlay);
|
| 898 |
+
}
|
| 899 |
+
overlay.onclick = () => {
|
| 900 |
+
overlay.style.display = "none";
|
| 901 |
+
overlay.innerHTML = "";
|
| 902 |
+
};
|
| 903 |
+
return overlay;
|
| 904 |
+
};
|
| 905 |
+
|
| 906 |
+
const getMedia = (container) => {
|
| 907 |
+
if (!container) return null;
|
| 908 |
+
const img = container.querySelector("img");
|
| 909 |
+
if (img) return { type: "img", el: img, getSrc: () => img.currentSrc || img.src };
|
| 910 |
+
const canvas = container.querySelector("canvas");
|
| 911 |
+
if (canvas) return { type: "canvas", el: canvas, getSrc: () => canvas.toDataURL("image/png") };
|
| 912 |
+
return null;
|
| 913 |
+
};
|
| 914 |
+
|
| 915 |
+
const bind = () => {
|
| 916 |
+
const container = document.getElementById(containerId);
|
| 917 |
+
if (!container || container.dataset.zoomBound) return;
|
| 918 |
+
container.dataset.zoomBound = "1";
|
| 919 |
+
container.addEventListener("click", (ev) => {
|
| 920 |
+
const media = getMedia(container);
|
| 921 |
+
if (!media) return;
|
| 922 |
+
const src = media.getSrc();
|
| 923 |
+
if (!src) return;
|
| 924 |
+
const overlay = ensureOverlay();
|
| 925 |
+
overlay.innerHTML = "";
|
| 926 |
+
const zoomed = document.createElement("img");
|
| 927 |
+
zoomed.src = src;
|
| 928 |
+
overlay.appendChild(zoomed);
|
| 929 |
+
overlay.style.display = "flex";
|
| 930 |
+
ev.stopPropagation();
|
| 931 |
+
});
|
| 932 |
+
};
|
| 933 |
+
|
| 934 |
+
// Poll because Gradio swaps the image element on updates.
|
| 935 |
+
const interval = setInterval(() => {
|
| 936 |
+
const media = getMedia(document.getElementById(containerId));
|
| 937 |
+
if (media && media.el && !media.el.dataset.cursorSet) {
|
| 938 |
+
media.el.dataset.cursorSet = "1";
|
| 939 |
+
media.el.style.cursor = "zoom-in";
|
| 940 |
+
}
|
| 941 |
+
bind();
|
| 942 |
+
}, 500);
|
| 943 |
+
window.addEventListener("beforeunload", () => clearInterval(interval));
|
| 944 |
+
})();
|
| 945 |
+
</script>
|
| 946 |
+
""",
|
| 947 |
+
elem_id="main-preview-zoom-helper",
|
| 948 |
+
)
|
| 949 |
+
with gr.Column(scale=1, min_width=260):
|
| 950 |
+
gr.Markdown("### Overlays")
|
| 951 |
+
base_view = gr.Dropdown(
|
| 952 |
+
label="Base view",
|
| 953 |
+
value="RGB",
|
| 954 |
+
choices=[
|
| 955 |
+
"RGB",
|
| 956 |
+
"Depth",
|
| 957 |
+
"Flatness map (std)",
|
| 958 |
+
"Depth gradient",
|
| 959 |
+
"Gradient mask",
|
| 960 |
+
"Water mask",
|
| 961 |
+
"Safety score",
|
| 962 |
+
"Safety heatmap overlay",
|
| 963 |
+
],
|
| 964 |
+
)
|
| 965 |
+
heat_on = gr.Checkbox(label="Heatmap", value=True, info="Show the safety heatmap overlay.")
|
| 966 |
+
heat_alpha = gr.Slider(
|
| 967 |
+
label="Heatmap alpha", value=0.15, minimum=0.0, maximum=1.0, step=0.05, info="Heatmap opacity."
|
| 968 |
+
)
|
| 969 |
+
grad_on = gr.Checkbox(label="Depth gradient", value=False, info="Overlay the depth gradient magnitude.")
|
| 970 |
+
grad_alpha = gr.Slider(
|
| 971 |
+
label="Gradient alpha", value=0.35, minimum=0.0, maximum=1.0, step=0.05, info="Gradient overlay opacity."
|
| 972 |
+
)
|
| 973 |
+
flat_on = gr.Checkbox(label="Flatness map", value=False, info="Overlay per-pixel flatness (std).")
|
| 974 |
+
flat_alpha = gr.Slider(
|
| 975 |
+
label="Flatness alpha", value=0.25, minimum=0.0, maximum=1.0, step=0.05, info="Flatness overlay opacity."
|
| 976 |
+
)
|
| 977 |
+
spot_on = gr.Checkbox(label="Show landing spot", value=True, info="Overlay the recommended landing box.")
|
| 978 |
+
with gr.Accordion("Mask overlays", open=True):
|
| 979 |
+
water_on = gr.Checkbox(label="Water mask overlay", value=False, info="Overlay detected water regions.")
|
| 980 |
+
water_alpha = gr.Slider(
|
| 981 |
+
label="Water mask alpha",
|
| 982 |
+
value=0.5,
|
| 983 |
+
minimum=0.0,
|
| 984 |
+
maximum=1.0,
|
| 985 |
+
step=0.05,
|
| 986 |
+
info="Water overlay opacity.",
|
| 987 |
+
)
|
| 988 |
+
road_on = gr.Checkbox(label="Road mask overlay", value=False, info="Overlay detected road regions.")
|
| 989 |
+
road_alpha = gr.Slider(
|
| 990 |
+
label="Road mask alpha",
|
| 991 |
+
value=0.5,
|
| 992 |
+
minimum=0.0,
|
| 993 |
+
maximum=1.0,
|
| 994 |
+
step=0.05,
|
| 995 |
+
info="Road overlay opacity.",
|
| 996 |
+
)
|
| 997 |
+
roof_on = gr.Checkbox(label="Roof mask overlay", value=False, info="Overlay detected roof regions.")
|
| 998 |
+
roof_alpha = gr.Slider(
|
| 999 |
+
label="Roof mask alpha",
|
| 1000 |
+
value=0.5,
|
| 1001 |
+
minimum=0.0,
|
| 1002 |
+
maximum=1.0,
|
| 1003 |
+
step=0.05,
|
| 1004 |
+
info="Roof overlay opacity.",
|
| 1005 |
+
)
|
| 1006 |
+
|
| 1007 |
+
def process_any(
|
| 1008 |
+
input_path,
|
| 1009 |
+
footprint_m,
|
| 1010 |
+
std_thresh,
|
| 1011 |
+
grad_thresh,
|
| 1012 |
+
use_water_mask,
|
| 1013 |
+
use_road_mask,
|
| 1014 |
+
use_roof_mask,
|
| 1015 |
+
altitude_m,
|
| 1016 |
+
fov_deg,
|
| 1017 |
+
flatness_detail,
|
| 1018 |
+
clearance_factor,
|
| 1019 |
+
model_id,
|
| 1020 |
+
base_view,
|
| 1021 |
+
heat_on,
|
| 1022 |
+
heat_alpha,
|
| 1023 |
+
grad_on,
|
| 1024 |
+
grad_alpha,
|
| 1025 |
+
flat_on,
|
| 1026 |
+
flat_alpha,
|
| 1027 |
+
water_on,
|
| 1028 |
+
water_alpha,
|
| 1029 |
+
spot_on,
|
| 1030 |
+
road_on,
|
| 1031 |
+
road_alpha,
|
| 1032 |
+
roof_on,
|
| 1033 |
+
roof_alpha,
|
| 1034 |
+
):
|
| 1035 |
+
if not input_path:
|
| 1036 |
+
raise gr.Error("Select an input image first.")
|
| 1037 |
+
path = Path(input_path)
|
| 1038 |
+
if not path.exists():
|
| 1039 |
+
raise gr.Error(f"Input not found: {path}")
|
| 1040 |
+
if path.suffix.lower() in IMAGE_EXTS:
|
| 1041 |
+
imgs = process_image(
|
| 1042 |
+
input_path=str(path),
|
| 1043 |
+
footprint_m=footprint_m,
|
| 1044 |
+
std_thresh=std_thresh,
|
| 1045 |
+
grad_thresh=grad_thresh,
|
| 1046 |
+
use_water_mask=use_water_mask,
|
| 1047 |
+
use_road_mask=use_road_mask,
|
| 1048 |
+
use_roof_mask=use_roof_mask,
|
| 1049 |
+
altitude_m=altitude_m,
|
| 1050 |
+
fov_deg=fov_deg,
|
| 1051 |
+
flatness_detail=flatness_detail,
|
| 1052 |
+
clearance_factor=clearance_factor,
|
| 1053 |
+
model_id=model_id,
|
| 1054 |
+
source_path=str(path),
|
| 1055 |
+
)
|
| 1056 |
+
composed = compose_view(
|
| 1057 |
+
imgs,
|
| 1058 |
+
base_view,
|
| 1059 |
+
heat_on,
|
| 1060 |
+
heat_alpha,
|
| 1061 |
+
grad_on,
|
| 1062 |
+
grad_alpha,
|
| 1063 |
+
flat_on,
|
| 1064 |
+
flat_alpha,
|
| 1065 |
+
water_on,
|
| 1066 |
+
water_alpha,
|
| 1067 |
+
water_enabled=use_water_mask,
|
| 1068 |
+
road_on=road_on,
|
| 1069 |
+
road_alpha=road_alpha,
|
| 1070 |
+
road_enabled=use_road_mask,
|
| 1071 |
+
roof_on=roof_on,
|
| 1072 |
+
roof_alpha=roof_alpha,
|
| 1073 |
+
roof_enabled=use_roof_mask,
|
| 1074 |
+
spot_on=spot_on,
|
| 1075 |
+
)
|
| 1076 |
+
yield imgs, composed
|
| 1077 |
+
else:
|
| 1078 |
+
raise gr.Error(f"Unsupported input type for path: {path} (images only)")
|
| 1079 |
+
|
| 1080 |
+
run_event = run_btn.click(
|
| 1081 |
+
fn=process_any,
|
| 1082 |
+
inputs=[
|
| 1083 |
+
input_path,
|
| 1084 |
+
footprint_m,
|
| 1085 |
+
std_thresh,
|
| 1086 |
+
grad_thresh,
|
| 1087 |
+
use_water_mask,
|
| 1088 |
+
use_road_mask,
|
| 1089 |
+
use_roof_mask,
|
| 1090 |
+
altitude_m,
|
| 1091 |
+
fov_deg,
|
| 1092 |
+
flatness_detail,
|
| 1093 |
+
clearance_factor,
|
| 1094 |
+
model_id,
|
| 1095 |
+
base_view,
|
| 1096 |
+
heat_on,
|
| 1097 |
+
heat_alpha,
|
| 1098 |
+
grad_on,
|
| 1099 |
+
grad_alpha,
|
| 1100 |
+
flat_on,
|
| 1101 |
+
flat_alpha,
|
| 1102 |
+
water_on,
|
| 1103 |
+
water_alpha,
|
| 1104 |
+
spot_on,
|
| 1105 |
+
road_on,
|
| 1106 |
+
road_alpha,
|
| 1107 |
+
roof_on,
|
| 1108 |
+
roof_alpha,
|
| 1109 |
+
],
|
| 1110 |
+
outputs=[images_state, main_view],
|
| 1111 |
+
)
|
| 1112 |
+
stop_btn.click(fn=None, inputs=None, outputs=None, cancels=[run_event])
|
| 1113 |
+
def update_preview_ui(
|
| 1114 |
+
images_state_val,
|
| 1115 |
+
input_path_val,
|
| 1116 |
+
footprint_m_val,
|
| 1117 |
+
std_thresh_val,
|
| 1118 |
+
grad_thresh_val,
|
| 1119 |
+
use_water_mask_val,
|
| 1120 |
+
use_road_mask_val,
|
| 1121 |
+
use_roof_mask_val,
|
| 1122 |
+
altitude_m_val,
|
| 1123 |
+
fov_deg_val,
|
| 1124 |
+
flatness_detail_val,
|
| 1125 |
+
clearance_factor_val,
|
| 1126 |
+
model_id_val,
|
| 1127 |
+
base_view_val,
|
| 1128 |
+
heat_on_val,
|
| 1129 |
+
heat_alpha_val,
|
| 1130 |
+
grad_on_val,
|
| 1131 |
+
grad_alpha_val,
|
| 1132 |
+
flat_on_val,
|
| 1133 |
+
flat_alpha_val,
|
| 1134 |
+
water_on_val,
|
| 1135 |
+
water_alpha_val,
|
| 1136 |
+
spot_on_val,
|
| 1137 |
+
road_on_val,
|
| 1138 |
+
road_alpha_val,
|
| 1139 |
+
roof_on_val,
|
| 1140 |
+
roof_alpha_val,
|
| 1141 |
+
):
|
| 1142 |
+
path = Path(str(input_path_val))
|
| 1143 |
+
imgs_val = images_state_val
|
| 1144 |
+
# If current input is an image, re-run processing to reflect new settings
|
| 1145 |
+
if path.exists() and path.suffix.lower() in IMAGE_EXTS:
|
| 1146 |
+
try:
|
| 1147 |
+
imgs_val = process_image(
|
| 1148 |
+
input_path=str(path),
|
| 1149 |
+
footprint_m=footprint_m_val,
|
| 1150 |
+
std_thresh=std_thresh_val,
|
| 1151 |
+
grad_thresh=grad_thresh_val,
|
| 1152 |
+
use_water_mask=use_water_mask_val,
|
| 1153 |
+
use_road_mask=use_road_mask_val,
|
| 1154 |
+
use_roof_mask=use_roof_mask_val,
|
| 1155 |
+
altitude_m=altitude_m_val,
|
| 1156 |
+
fov_deg=fov_deg_val,
|
| 1157 |
+
flatness_detail=flatness_detail_val,
|
| 1158 |
+
clearance_factor=clearance_factor_val,
|
| 1159 |
+
model_id=model_id_val,
|
| 1160 |
+
)
|
| 1161 |
+
except Exception:
|
| 1162 |
+
imgs_val = images_state_val
|
| 1163 |
+
if not imgs_val:
|
| 1164 |
+
return images_state_val, gr.update()
|
| 1165 |
+
composed = compose_view(
|
| 1166 |
+
imgs_val,
|
| 1167 |
+
base_view_val,
|
| 1168 |
+
heat_on_val,
|
| 1169 |
+
heat_alpha_val,
|
| 1170 |
+
grad_on_val,
|
| 1171 |
+
grad_alpha_val,
|
| 1172 |
+
flat_on_val,
|
| 1173 |
+
flat_alpha_val,
|
| 1174 |
+
water_on_val,
|
| 1175 |
+
water_alpha_val,
|
| 1176 |
+
use_water_mask_val,
|
| 1177 |
+
spot_on_val,
|
| 1178 |
+
road_on_val,
|
| 1179 |
+
road_alpha_val,
|
| 1180 |
+
use_road_mask_val,
|
| 1181 |
+
roof_on_val,
|
| 1182 |
+
roof_alpha_val,
|
| 1183 |
+
use_roof_mask_val,
|
| 1184 |
+
)
|
| 1185 |
+
return imgs_val, composed
|
| 1186 |
+
|
| 1187 |
+
overlay_inputs = [
|
| 1188 |
+
images_state,
|
| 1189 |
+
base_view,
|
| 1190 |
+
heat_on,
|
| 1191 |
+
heat_alpha,
|
| 1192 |
+
grad_on,
|
| 1193 |
+
grad_alpha,
|
| 1194 |
+
flat_on,
|
| 1195 |
+
flat_alpha,
|
| 1196 |
+
water_on,
|
| 1197 |
+
water_alpha,
|
| 1198 |
+
spot_on,
|
| 1199 |
+
use_water_mask,
|
| 1200 |
+
road_on,
|
| 1201 |
+
road_alpha,
|
| 1202 |
+
use_road_mask,
|
| 1203 |
+
roof_on,
|
| 1204 |
+
roof_alpha,
|
| 1205 |
+
use_roof_mask,
|
| 1206 |
+
]
|
| 1207 |
+
|
| 1208 |
+
def update_overlays_only(
|
| 1209 |
+
images_state_val,
|
| 1210 |
+
base_view_val,
|
| 1211 |
+
heat_on_val,
|
| 1212 |
+
heat_alpha_val,
|
| 1213 |
+
grad_on_val,
|
| 1214 |
+
grad_alpha_val,
|
| 1215 |
+
flat_on_val,
|
| 1216 |
+
flat_alpha_val,
|
| 1217 |
+
water_on_val,
|
| 1218 |
+
water_alpha_val,
|
| 1219 |
+
spot_on_val,
|
| 1220 |
+
use_water_mask_val,
|
| 1221 |
+
road_on_val,
|
| 1222 |
+
road_alpha_val,
|
| 1223 |
+
use_road_mask_val,
|
| 1224 |
+
roof_on_val,
|
| 1225 |
+
roof_alpha_val,
|
| 1226 |
+
use_roof_mask_val,
|
| 1227 |
+
):
|
| 1228 |
+
if not images_state_val:
|
| 1229 |
+
return images_state_val, gr.update()
|
| 1230 |
+
return images_state_val, compose_view(
|
| 1231 |
+
images_state_val,
|
| 1232 |
+
base_view_val,
|
| 1233 |
+
heat_on_val,
|
| 1234 |
+
heat_alpha_val,
|
| 1235 |
+
grad_on_val,
|
| 1236 |
+
grad_alpha_val,
|
| 1237 |
+
flat_on_val,
|
| 1238 |
+
flat_alpha_val,
|
| 1239 |
+
water_on_val,
|
| 1240 |
+
water_alpha_val,
|
| 1241 |
+
use_water_mask_val,
|
| 1242 |
+
spot_on_val,
|
| 1243 |
+
road_on_val,
|
| 1244 |
+
road_alpha_val,
|
| 1245 |
+
use_road_mask_val,
|
| 1246 |
+
roof_on_val,
|
| 1247 |
+
roof_alpha_val,
|
| 1248 |
+
use_roof_mask_val,
|
| 1249 |
+
)
|
| 1250 |
+
|
| 1251 |
+
base_view.change(fn=update_overlays_only, inputs=overlay_inputs, outputs=[images_state, main_view])
|
| 1252 |
+
for control in (
|
| 1253 |
+
heat_on,
|
| 1254 |
+
heat_alpha,
|
| 1255 |
+
grad_on,
|
| 1256 |
+
grad_alpha,
|
| 1257 |
+
flat_on,
|
| 1258 |
+
flat_alpha,
|
| 1259 |
+
water_on,
|
| 1260 |
+
water_alpha,
|
| 1261 |
+
spot_on,
|
| 1262 |
+
use_water_mask,
|
| 1263 |
+
road_on,
|
| 1264 |
+
road_alpha,
|
| 1265 |
+
use_road_mask,
|
| 1266 |
+
roof_on,
|
| 1267 |
+
roof_alpha,
|
| 1268 |
+
use_roof_mask,
|
| 1269 |
+
):
|
| 1270 |
+
control.change(fn=update_overlays_only, inputs=overlay_inputs, outputs=[images_state, main_view])
|
| 1271 |
+
|
| 1272 |
+
model_inputs = [
|
| 1273 |
+
images_state,
|
| 1274 |
+
input_path,
|
| 1275 |
+
footprint_m,
|
| 1276 |
+
std_thresh,
|
| 1277 |
+
grad_thresh,
|
| 1278 |
+
use_water_mask,
|
| 1279 |
+
use_road_mask,
|
| 1280 |
+
use_roof_mask,
|
| 1281 |
+
altitude_m,
|
| 1282 |
+
fov_deg,
|
| 1283 |
+
flatness_detail,
|
| 1284 |
+
clearance_factor,
|
| 1285 |
+
model_id,
|
| 1286 |
+
base_view,
|
| 1287 |
+
heat_on,
|
| 1288 |
+
heat_alpha,
|
| 1289 |
+
grad_on,
|
| 1290 |
+
grad_alpha,
|
| 1291 |
+
flat_on,
|
| 1292 |
+
flat_alpha,
|
| 1293 |
+
water_on,
|
| 1294 |
+
water_alpha,
|
| 1295 |
+
spot_on,
|
| 1296 |
+
road_on,
|
| 1297 |
+
road_alpha,
|
| 1298 |
+
roof_on,
|
| 1299 |
+
roof_alpha,
|
| 1300 |
+
]
|
| 1301 |
+
for control in (
|
| 1302 |
+
input_path,
|
| 1303 |
+
footprint_m,
|
| 1304 |
+
std_thresh,
|
| 1305 |
+
grad_thresh,
|
| 1306 |
+
use_water_mask,
|
| 1307 |
+
use_road_mask,
|
| 1308 |
+
use_roof_mask,
|
| 1309 |
+
altitude_m,
|
| 1310 |
+
fov_deg,
|
| 1311 |
+
flatness_detail,
|
| 1312 |
+
clearance_factor,
|
| 1313 |
+
model_id,
|
| 1314 |
+
):
|
| 1315 |
+
control.change(fn=update_preview_ui, inputs=model_inputs, outputs=[images_state, main_view])
|
| 1316 |
+
return demo
|
| 1317 |
+
|
| 1318 |
+
|
| 1319 |
+
if __name__ == "__main__":
|
| 1320 |
+
demo = build_ui()
|
| 1321 |
+
demo.queue().launch()
|
src/depth_anything_3/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Depth Anything 3 Python package entrypoint.
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
from depth_anything_3.api import DepthAnything3
|
| 19 |
+
|
| 20 |
+
__all__ = ["DepthAnything3"]
|
| 21 |
+
__version__ = "0.1.0"
|
src/depth_anything_3/api.py
ADDED
|
@@ -0,0 +1,416 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
"""
|
| 15 |
+
Depth Anything 3 API module.
|
| 16 |
+
|
| 17 |
+
This module provides the main API for Depth Anything 3, including model loading,
|
| 18 |
+
inference, and export capabilities. It supports both single and nested model architectures.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from __future__ import annotations
|
| 22 |
+
|
| 23 |
+
import time
|
| 24 |
+
from typing import Optional, Sequence
|
| 25 |
+
import numpy as np
|
| 26 |
+
import torch
|
| 27 |
+
import torch.nn as nn
|
| 28 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 29 |
+
from PIL import Image
|
| 30 |
+
|
| 31 |
+
from depth_anything_3.cfg import create_object, load_config
|
| 32 |
+
from depth_anything_3.registry import MODEL_REGISTRY
|
| 33 |
+
from depth_anything_3.specs import Prediction
|
| 34 |
+
from depth_anything_3.utils.export import export
|
| 35 |
+
from depth_anything_3.utils.geometry import affine_inverse
|
| 36 |
+
from depth_anything_3.utils.io.input_processor import InputProcessor
|
| 37 |
+
from depth_anything_3.utils.io.output_processor import OutputProcessor
|
| 38 |
+
from depth_anything_3.utils.logger import logger
|
| 39 |
+
from depth_anything_3.utils.pose_align import align_poses_umeyama
|
| 40 |
+
|
| 41 |
+
torch.backends.cudnn.benchmark = False
|
| 42 |
+
# logger.info("CUDNN Benchmark Disabled")
|
| 43 |
+
|
| 44 |
+
SAFETENSORS_NAME = "model.safetensors"
|
| 45 |
+
CONFIG_NAME = "config.json"
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class DepthAnything3(nn.Module, PyTorchModelHubMixin):
|
| 49 |
+
"""
|
| 50 |
+
Depth Anything 3 main API class.
|
| 51 |
+
|
| 52 |
+
This class provides a high-level interface for depth estimation using Depth Anything 3.
|
| 53 |
+
It supports both single and nested model architectures with metric scaling capabilities.
|
| 54 |
+
|
| 55 |
+
Features:
|
| 56 |
+
- Hugging Face Hub integration via PyTorchModelHubMixin
|
| 57 |
+
- Support for multiple model presets (vitb, vitg, nested variants)
|
| 58 |
+
- Automatic mixed precision inference
|
| 59 |
+
- Export capabilities for various formats (GLB, PLY, NPZ, etc.)
|
| 60 |
+
- Camera pose estimation and metric depth scaling
|
| 61 |
+
|
| 62 |
+
Usage:
|
| 63 |
+
# Load from Hugging Face Hub
|
| 64 |
+
model = DepthAnything3.from_pretrained("huggingface/model-name")
|
| 65 |
+
|
| 66 |
+
# Or create with specific preset
|
| 67 |
+
model = DepthAnything3(preset="vitg")
|
| 68 |
+
|
| 69 |
+
# Run inference
|
| 70 |
+
prediction = model.inference(images, export_dir="output", export_format="glb")
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
_commit_hash: str | None = None # Set by mixin when loading from Hub
|
| 74 |
+
|
| 75 |
+
def __init__(self, model_name: str = "da3-large", **kwargs):
|
| 76 |
+
"""
|
| 77 |
+
Initialize DepthAnything3 with specified preset.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
model_name: The name of the model preset to use.
|
| 81 |
+
Examples: 'da3-giant', 'da3-large', 'da3metric-large', 'da3nested-giant-large'.
|
| 82 |
+
**kwargs: Additional keyword arguments (currently unused).
|
| 83 |
+
"""
|
| 84 |
+
super().__init__()
|
| 85 |
+
self.model_name = model_name
|
| 86 |
+
|
| 87 |
+
# Build the underlying network
|
| 88 |
+
self.config = load_config(MODEL_REGISTRY[self.model_name])
|
| 89 |
+
self.model = create_object(self.config)
|
| 90 |
+
self.model.eval()
|
| 91 |
+
|
| 92 |
+
# Initialize processors
|
| 93 |
+
self.input_processor = InputProcessor()
|
| 94 |
+
self.output_processor = OutputProcessor()
|
| 95 |
+
|
| 96 |
+
# Device management (set by user)
|
| 97 |
+
self.device = None
|
| 98 |
+
|
| 99 |
+
@torch.inference_mode()
|
| 100 |
+
def forward(
|
| 101 |
+
self,
|
| 102 |
+
image: torch.Tensor,
|
| 103 |
+
extrinsics: torch.Tensor | None = None,
|
| 104 |
+
intrinsics: torch.Tensor | None = None,
|
| 105 |
+
export_feat_layers: list[int] | None = None,
|
| 106 |
+
infer_gs: bool = False,
|
| 107 |
+
) -> dict[str, torch.Tensor]:
|
| 108 |
+
"""
|
| 109 |
+
Forward pass through the model.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
image: Input batch with shape ``(B, N, 3, H, W)`` on the model device.
|
| 113 |
+
extrinsics: Optional camera extrinsics with shape ``(B, N, 4, 4)``.
|
| 114 |
+
intrinsics: Optional camera intrinsics with shape ``(B, N, 3, 3)``.
|
| 115 |
+
export_feat_layers: Layer indices to return intermediate features for.
|
| 116 |
+
|
| 117 |
+
Returns:
|
| 118 |
+
Dictionary containing model predictions
|
| 119 |
+
"""
|
| 120 |
+
# Determine optimal autocast dtype
|
| 121 |
+
autocast_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
|
| 122 |
+
with torch.no_grad():
|
| 123 |
+
with torch.autocast(device_type=image.device.type, dtype=autocast_dtype):
|
| 124 |
+
return self.model(image, extrinsics, intrinsics, export_feat_layers, infer_gs)
|
| 125 |
+
|
| 126 |
+
def inference(
|
| 127 |
+
self,
|
| 128 |
+
image: list[np.ndarray | Image.Image | str],
|
| 129 |
+
extrinsics: np.ndarray | None = None,
|
| 130 |
+
intrinsics: np.ndarray | None = None,
|
| 131 |
+
align_to_input_ext_scale: bool = True,
|
| 132 |
+
infer_gs: bool = False,
|
| 133 |
+
render_exts: np.ndarray | None = None,
|
| 134 |
+
render_ixts: np.ndarray | None = None,
|
| 135 |
+
render_hw: tuple[int, int] | None = None,
|
| 136 |
+
process_res: int = 504,
|
| 137 |
+
process_res_method: str = "upper_bound_resize",
|
| 138 |
+
export_dir: str | None = None,
|
| 139 |
+
export_format: str = "mini_npz",
|
| 140 |
+
export_feat_layers: Sequence[int] | None = None,
|
| 141 |
+
# GLB export parameters
|
| 142 |
+
conf_thresh_percentile: float = 40.0,
|
| 143 |
+
num_max_points: int = 1_000_000,
|
| 144 |
+
show_cameras: bool = True,
|
| 145 |
+
# Feat_vis export parameters
|
| 146 |
+
feat_vis_fps: int = 15,
|
| 147 |
+
# Other export parameters, e.g., gs_ply, gs_video
|
| 148 |
+
export_kwargs: Optional[dict] = None,
|
| 149 |
+
) -> Prediction:
|
| 150 |
+
"""
|
| 151 |
+
Run inference on input images.
|
| 152 |
+
|
| 153 |
+
Args:
|
| 154 |
+
image: List of input images (numpy arrays, PIL Images, or file paths)
|
| 155 |
+
extrinsics: Camera extrinsics (N, 4, 4)
|
| 156 |
+
intrinsics: Camera intrinsics (N, 3, 3)
|
| 157 |
+
align_to_input_ext_scale: whether to align the input pose scale to the prediction
|
| 158 |
+
infer_gs: Enable the 3D Gaussian branch (needed for `gs_ply`/`gs_video` exports)
|
| 159 |
+
render_exts: Optional render extrinsics for Gaussian video export
|
| 160 |
+
render_ixts: Optional render intrinsics for Gaussian video export
|
| 161 |
+
render_hw: Optional render resolution for Gaussian video export
|
| 162 |
+
process_res: Processing resolution
|
| 163 |
+
process_res_method: Resize method for processing
|
| 164 |
+
export_dir: Directory to export results
|
| 165 |
+
export_format: Export format (mini_npz, npz, glb, ply, gs, gs_video)
|
| 166 |
+
export_feat_layers: Layer indices to export intermediate features from
|
| 167 |
+
conf_thresh_percentile: [GLB] Lower percentile for adaptive confidence threshold (default: 40.0) # noqa: E501
|
| 168 |
+
num_max_points: [GLB] Maximum number of points in the point cloud (default: 1,000,000)
|
| 169 |
+
show_cameras: [GLB] Show camera wireframes in the exported scene (default: True)
|
| 170 |
+
feat_vis_fps: [FEAT_VIS] Frame rate for output video (default: 15)
|
| 171 |
+
export_kwargs: additional arguments to export functions.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
Prediction object containing depth maps and camera parameters
|
| 175 |
+
"""
|
| 176 |
+
if "gs" in export_format:
|
| 177 |
+
assert infer_gs, "must set `infer_gs=True` to perform gs-related export."
|
| 178 |
+
|
| 179 |
+
# Preprocess images
|
| 180 |
+
imgs_cpu, extrinsics, intrinsics = self._preprocess_inputs(
|
| 181 |
+
image, extrinsics, intrinsics, process_res, process_res_method
|
| 182 |
+
)
|
| 183 |
+
|
| 184 |
+
# Prepare tensors for model
|
| 185 |
+
imgs, ex_t, in_t = self._prepare_model_inputs(imgs_cpu, extrinsics, intrinsics)
|
| 186 |
+
|
| 187 |
+
# Normalize extrinsics
|
| 188 |
+
ex_t_norm = self._normalize_extrinsics(ex_t.clone() if ex_t is not None else None)
|
| 189 |
+
|
| 190 |
+
# Run model forward pass
|
| 191 |
+
export_feat_layers = list(export_feat_layers) if export_feat_layers is not None else []
|
| 192 |
+
|
| 193 |
+
raw_output = self._run_model_forward(imgs, ex_t_norm, in_t, export_feat_layers, infer_gs)
|
| 194 |
+
|
| 195 |
+
# Convert raw output to prediction
|
| 196 |
+
prediction = self._convert_to_prediction(raw_output)
|
| 197 |
+
|
| 198 |
+
# Align prediction to extrinsincs
|
| 199 |
+
prediction = self._align_to_input_extrinsics_intrinsics(
|
| 200 |
+
extrinsics, intrinsics, prediction, align_to_input_ext_scale
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
# Add processed images for visualization
|
| 204 |
+
prediction = self._add_processed_images(prediction, imgs_cpu)
|
| 205 |
+
|
| 206 |
+
# Export if requested
|
| 207 |
+
if export_dir is not None:
|
| 208 |
+
export_kwargs = {} if export_kwargs is None else dict(export_kwargs)
|
| 209 |
+
if "gs" in export_format:
|
| 210 |
+
if infer_gs and "gs_video" not in export_format:
|
| 211 |
+
export_format = f"{export_format}-gs_video"
|
| 212 |
+
if "gs_video" in export_format:
|
| 213 |
+
if "gs_video" not in export_kwargs:
|
| 214 |
+
export_kwargs["gs_video"] = {}
|
| 215 |
+
export_kwargs["gs_video"].update(
|
| 216 |
+
{
|
| 217 |
+
"extrinsics": render_exts,
|
| 218 |
+
"intrinsics": render_ixts,
|
| 219 |
+
"out_image_hw": render_hw,
|
| 220 |
+
}
|
| 221 |
+
)
|
| 222 |
+
# Add GLB export parameters
|
| 223 |
+
if "glb" in export_format:
|
| 224 |
+
if "glb" not in export_kwargs:
|
| 225 |
+
export_kwargs["glb"] = {}
|
| 226 |
+
export_kwargs["glb"].update(
|
| 227 |
+
{
|
| 228 |
+
"conf_thresh_percentile": conf_thresh_percentile,
|
| 229 |
+
"num_max_points": num_max_points,
|
| 230 |
+
"show_cameras": show_cameras,
|
| 231 |
+
}
|
| 232 |
+
)
|
| 233 |
+
# Add Feat_vis export parameters
|
| 234 |
+
if "feat_vis" in export_format:
|
| 235 |
+
if "feat_vis" not in export_kwargs:
|
| 236 |
+
export_kwargs["feat_vis"] = {}
|
| 237 |
+
export_kwargs["feat_vis"].update(
|
| 238 |
+
{
|
| 239 |
+
"fps": feat_vis_fps,
|
| 240 |
+
}
|
| 241 |
+
)
|
| 242 |
+
self._export_results(prediction, export_format, export_dir, **export_kwargs)
|
| 243 |
+
|
| 244 |
+
return prediction
|
| 245 |
+
|
| 246 |
+
def _preprocess_inputs(
|
| 247 |
+
self,
|
| 248 |
+
image: list[np.ndarray | Image.Image | str],
|
| 249 |
+
extrinsics: np.ndarray | None = None,
|
| 250 |
+
intrinsics: np.ndarray | None = None,
|
| 251 |
+
process_res: int = 504,
|
| 252 |
+
process_res_method: str = "upper_bound_resize",
|
| 253 |
+
) -> torch.Tensor:
|
| 254 |
+
"""Preprocess input images using input processor."""
|
| 255 |
+
start_time = time.time()
|
| 256 |
+
imgs_cpu, extrinsics, intrinsics = self.input_processor(
|
| 257 |
+
image,
|
| 258 |
+
extrinsics.copy() if extrinsics is not None else None,
|
| 259 |
+
intrinsics.copy() if intrinsics is not None else None,
|
| 260 |
+
process_res,
|
| 261 |
+
process_res_method,
|
| 262 |
+
)
|
| 263 |
+
end_time = time.time()
|
| 264 |
+
logger.info(
|
| 265 |
+
"Processed Images Done taking",
|
| 266 |
+
end_time - start_time,
|
| 267 |
+
"seconds. Shape: ",
|
| 268 |
+
imgs_cpu.shape,
|
| 269 |
+
)
|
| 270 |
+
return imgs_cpu, extrinsics, intrinsics
|
| 271 |
+
|
| 272 |
+
def _prepare_model_inputs(
|
| 273 |
+
self,
|
| 274 |
+
imgs_cpu: torch.Tensor,
|
| 275 |
+
extrinsics: torch.tensor | None,
|
| 276 |
+
intrinsics: torch.tensor | None,
|
| 277 |
+
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
|
| 278 |
+
"""Prepare tensors for model input."""
|
| 279 |
+
device = self._get_model_device()
|
| 280 |
+
|
| 281 |
+
# Move images to model device
|
| 282 |
+
imgs = imgs_cpu.to(device, non_blocking=True)[None].float()
|
| 283 |
+
|
| 284 |
+
# Convert camera parameters to tensors
|
| 285 |
+
ex_t = (
|
| 286 |
+
extrinsics.to(device, non_blocking=True)[None].float()
|
| 287 |
+
if extrinsics is not None
|
| 288 |
+
else None
|
| 289 |
+
)
|
| 290 |
+
in_t = (
|
| 291 |
+
intrinsics.to(device, non_blocking=True)[None].float()
|
| 292 |
+
if intrinsics is not None
|
| 293 |
+
else None
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
return imgs, ex_t, in_t
|
| 297 |
+
|
| 298 |
+
def _normalize_extrinsics(self, ex_t: torch.Tensor) -> torch.Tensor:
|
| 299 |
+
"""Normalize extrinsics"""
|
| 300 |
+
if ex_t is None:
|
| 301 |
+
return None
|
| 302 |
+
transform = affine_inverse(ex_t[:, :1])
|
| 303 |
+
ex_t_norm = ex_t @ transform
|
| 304 |
+
c2ws = affine_inverse(ex_t_norm)
|
| 305 |
+
translations = c2ws[..., :3, 3]
|
| 306 |
+
dists = translations.norm(dim=-1)
|
| 307 |
+
median_dist = torch.median(dists)
|
| 308 |
+
median_dist = torch.clamp(median_dist, min=1e-1)
|
| 309 |
+
ex_t_norm[..., :3, 3] = ex_t_norm[..., :3, 3] / median_dist
|
| 310 |
+
return ex_t_norm
|
| 311 |
+
|
| 312 |
+
def _align_to_input_extrinsics_intrinsics(
|
| 313 |
+
self,
|
| 314 |
+
extrinsics: torch.Tensor,
|
| 315 |
+
intrinsics: torch.Tensor,
|
| 316 |
+
prediction: Prediction,
|
| 317 |
+
align_to_input_ext_scale: bool = True,
|
| 318 |
+
ransac_view_thresh: int = 10,
|
| 319 |
+
) -> Prediction:
|
| 320 |
+
"""Align depth map to input extrinsics"""
|
| 321 |
+
if extrinsics is None or prediction.extrinsics is None:
|
| 322 |
+
return prediction
|
| 323 |
+
if intrinsics is not None:
|
| 324 |
+
prediction.intrinsics = intrinsics.numpy()
|
| 325 |
+
_, _, scale, aligned_extrinsics = align_poses_umeyama(
|
| 326 |
+
prediction.extrinsics,
|
| 327 |
+
extrinsics.numpy(),
|
| 328 |
+
ransac=len(extrinsics) >= ransac_view_thresh,
|
| 329 |
+
return_aligned=True,
|
| 330 |
+
random_state=42,
|
| 331 |
+
)
|
| 332 |
+
if align_to_input_ext_scale:
|
| 333 |
+
prediction.extrinsics = extrinsics[..., :3, :].numpy()
|
| 334 |
+
prediction.depth /= scale
|
| 335 |
+
else:
|
| 336 |
+
prediction.extrinsics = aligned_extrinsics
|
| 337 |
+
return prediction
|
| 338 |
+
|
| 339 |
+
def _run_model_forward(
|
| 340 |
+
self,
|
| 341 |
+
imgs: torch.Tensor,
|
| 342 |
+
ex_t: torch.Tensor | None,
|
| 343 |
+
in_t: torch.Tensor | None,
|
| 344 |
+
export_feat_layers: Sequence[int] | None = None,
|
| 345 |
+
infer_gs: bool = False,
|
| 346 |
+
) -> dict[str, torch.Tensor]:
|
| 347 |
+
"""Run model forward pass."""
|
| 348 |
+
device = imgs.device
|
| 349 |
+
need_sync = device.type == "cuda"
|
| 350 |
+
if need_sync:
|
| 351 |
+
torch.cuda.synchronize(device)
|
| 352 |
+
start_time = time.time()
|
| 353 |
+
feat_layers = list(export_feat_layers) if export_feat_layers is not None else None
|
| 354 |
+
output = self.forward(imgs, ex_t, in_t, feat_layers, infer_gs)
|
| 355 |
+
if need_sync:
|
| 356 |
+
torch.cuda.synchronize(device)
|
| 357 |
+
end_time = time.time()
|
| 358 |
+
logger.info(f"Model Forward Pass Done. Time: {end_time - start_time} seconds")
|
| 359 |
+
return output
|
| 360 |
+
|
| 361 |
+
def _convert_to_prediction(self, raw_output: dict[str, torch.Tensor]) -> Prediction:
|
| 362 |
+
"""Convert raw model output to Prediction object."""
|
| 363 |
+
start_time = time.time()
|
| 364 |
+
output = self.output_processor(raw_output)
|
| 365 |
+
end_time = time.time()
|
| 366 |
+
logger.info(f"Conversion to Prediction Done. Time: {end_time - start_time} seconds")
|
| 367 |
+
return output
|
| 368 |
+
|
| 369 |
+
def _add_processed_images(self, prediction: Prediction, imgs_cpu: torch.Tensor) -> Prediction:
|
| 370 |
+
"""Add processed images to prediction for visualization."""
|
| 371 |
+
# Convert from (N, 3, H, W) to (N, H, W, 3) and denormalize
|
| 372 |
+
processed_imgs = imgs_cpu.permute(0, 2, 3, 1).cpu().numpy() # (N, H, W, 3)
|
| 373 |
+
|
| 374 |
+
# Denormalize from ImageNet normalization
|
| 375 |
+
mean = np.array([0.485, 0.456, 0.406])
|
| 376 |
+
std = np.array([0.229, 0.224, 0.225])
|
| 377 |
+
processed_imgs = processed_imgs * std + mean
|
| 378 |
+
processed_imgs = np.clip(processed_imgs, 0, 1)
|
| 379 |
+
processed_imgs = (processed_imgs * 255).astype(np.uint8)
|
| 380 |
+
|
| 381 |
+
prediction.processed_images = processed_imgs
|
| 382 |
+
return prediction
|
| 383 |
+
|
| 384 |
+
def _export_results(
|
| 385 |
+
self, prediction: Prediction, export_format: str, export_dir: str, **kwargs
|
| 386 |
+
) -> None:
|
| 387 |
+
"""Export results to specified format and directory."""
|
| 388 |
+
start_time = time.time()
|
| 389 |
+
export(prediction, export_format, export_dir, **kwargs)
|
| 390 |
+
end_time = time.time()
|
| 391 |
+
logger.info(f"Export Results Done. Time: {end_time - start_time} seconds")
|
| 392 |
+
|
| 393 |
+
def _get_model_device(self) -> torch.device:
|
| 394 |
+
"""
|
| 395 |
+
Get the device where the model is located.
|
| 396 |
+
|
| 397 |
+
Returns:
|
| 398 |
+
Device where the model parameters are located
|
| 399 |
+
|
| 400 |
+
Raises:
|
| 401 |
+
ValueError: If no tensors are found in the model
|
| 402 |
+
"""
|
| 403 |
+
if self.device is not None:
|
| 404 |
+
return self.device
|
| 405 |
+
|
| 406 |
+
# Find device from parameters
|
| 407 |
+
for param in self.parameters():
|
| 408 |
+
self.device = param.device
|
| 409 |
+
return param.device
|
| 410 |
+
|
| 411 |
+
# Find device from buffers
|
| 412 |
+
for buffer in self.buffers():
|
| 413 |
+
self.device = buffer.device
|
| 414 |
+
return buffer.device
|
| 415 |
+
|
| 416 |
+
raise ValueError("No tensor found in model")
|
src/depth_anything_3/app/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
# Package marker for app modules.
|
src/depth_anything_3/app/css_and_html.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: E501
|
| 2 |
+
|
| 3 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
CSS and HTML content for the Depth Anything 3 Gradio application.
|
| 19 |
+
This module contains all the CSS styles and HTML content blocks
|
| 20 |
+
used in the Gradio interface.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
# CSS Styles for the Gradio interface
|
| 24 |
+
GRADIO_CSS = """
|
| 25 |
+
/* Add Font Awesome CDN with all styles including brands and colors */
|
| 26 |
+
@import url('https://cdnjs.cloudflare.com/ajax/libs/font-awesome/6.4.0/css/all.min.css');
|
| 27 |
+
|
| 28 |
+
/* Add custom styles for colored icons */
|
| 29 |
+
.fa-color-blue {
|
| 30 |
+
color: #3b82f6;
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
.fa-color-purple {
|
| 34 |
+
color: #8b5cf6;
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
.fa-color-cyan {
|
| 38 |
+
color: #06b6d4;
|
| 39 |
+
}
|
| 40 |
+
|
| 41 |
+
.fa-color-green {
|
| 42 |
+
color: #10b981;
|
| 43 |
+
}
|
| 44 |
+
|
| 45 |
+
.fa-color-yellow {
|
| 46 |
+
color: #f59e0b;
|
| 47 |
+
}
|
| 48 |
+
|
| 49 |
+
.fa-color-red {
|
| 50 |
+
color: #ef4444;
|
| 51 |
+
}
|
| 52 |
+
|
| 53 |
+
.link-btn {
|
| 54 |
+
display: inline-flex;
|
| 55 |
+
align-items: center;
|
| 56 |
+
gap: 8px;
|
| 57 |
+
text-decoration: none;
|
| 58 |
+
padding: 12px 24px;
|
| 59 |
+
border-radius: 50px;
|
| 60 |
+
font-weight: 500;
|
| 61 |
+
transition: all 0.3s ease;
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
/* Dark mode tech theme */
|
| 65 |
+
@media (prefers-color-scheme: dark) {
|
| 66 |
+
html, body {
|
| 67 |
+
background: #1e293b;
|
| 68 |
+
color: #ffffff;
|
| 69 |
+
}
|
| 70 |
+
|
| 71 |
+
.gradio-container {
|
| 72 |
+
background: #1e293b;
|
| 73 |
+
color: #ffffff;
|
| 74 |
+
}
|
| 75 |
+
|
| 76 |
+
.link-btn {
|
| 77 |
+
background: rgba(255, 255, 255, 0.2);
|
| 78 |
+
color: white;
|
| 79 |
+
backdrop-filter: blur(10px);
|
| 80 |
+
border: 1px solid rgba(255, 255, 255, 0.3);
|
| 81 |
+
}
|
| 82 |
+
|
| 83 |
+
.link-btn:hover {
|
| 84 |
+
background: rgba(255, 255, 255, 0.3);
|
| 85 |
+
transform: translateY(-2px);
|
| 86 |
+
box-shadow: 0 8px 25px rgba(0, 0, 0, 0.2);
|
| 87 |
+
}
|
| 88 |
+
|
| 89 |
+
.tech-bg {
|
| 90 |
+
background: linear-gradient(135deg, #0f172a, #1e293b); /* Darker colors */
|
| 91 |
+
position: relative;
|
| 92 |
+
overflow: hidden;
|
| 93 |
+
}
|
| 94 |
+
|
| 95 |
+
.tech-bg::before {
|
| 96 |
+
content: '';
|
| 97 |
+
position: absolute;
|
| 98 |
+
top: 0;
|
| 99 |
+
left: 0;
|
| 100 |
+
right: 0;
|
| 101 |
+
bottom: 0;
|
| 102 |
+
background:
|
| 103 |
+
radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
|
| 104 |
+
radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.15) 0%, transparent 50%), /* Reduced opacity */
|
| 105 |
+
radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.1) 0%, transparent 50%); /* Reduced opacity */
|
| 106 |
+
animation: techPulse 8s ease-in-out infinite;
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
.gradio-container .panel,
|
| 110 |
+
.gradio-container .block,
|
| 111 |
+
.gradio-container .form {
|
| 112 |
+
background: rgba(0, 0, 0, 0.3);
|
| 113 |
+
border: 1px solid rgba(59, 130, 246, 0.2);
|
| 114 |
+
border-radius: 10px;
|
| 115 |
+
}
|
| 116 |
+
|
| 117 |
+
.gradio-container * {
|
| 118 |
+
color: #ffffff;
|
| 119 |
+
}
|
| 120 |
+
|
| 121 |
+
.gradio-container label {
|
| 122 |
+
color: #e0e0e0;
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
.gradio-container .markdown {
|
| 126 |
+
color: #e0e0e0;
|
| 127 |
+
}
|
| 128 |
+
}
|
| 129 |
+
|
| 130 |
+
/* Light mode tech theme */
|
| 131 |
+
@media (prefers-color-scheme: light) {
|
| 132 |
+
html, body {
|
| 133 |
+
background: #ffffff;
|
| 134 |
+
color: #1e293b;
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
.gradio-container {
|
| 138 |
+
background: #ffffff;
|
| 139 |
+
color: #1e293b;
|
| 140 |
+
}
|
| 141 |
+
|
| 142 |
+
.tech-bg {
|
| 143 |
+
background: linear-gradient(135deg, #ffffff, #f1f5f9);
|
| 144 |
+
position: relative;
|
| 145 |
+
overflow: hidden;
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
.link-btn {
|
| 149 |
+
background: rgba(59, 130, 246, 0.15);
|
| 150 |
+
color: var(--body-text-color);
|
| 151 |
+
border: 1px solid rgba(59, 130, 246, 0.3);
|
| 152 |
+
}
|
| 153 |
+
|
| 154 |
+
.link-btn:hover {
|
| 155 |
+
background: rgba(59, 130, 246, 0.25);
|
| 156 |
+
transform: translateY(-2px);
|
| 157 |
+
box-shadow: 0 8px 25px rgba(59, 130, 246, 0.2);
|
| 158 |
+
}
|
| 159 |
+
|
| 160 |
+
.tech-bg::before {
|
| 161 |
+
content: '';
|
| 162 |
+
position: absolute;
|
| 163 |
+
top: 0;
|
| 164 |
+
left: 0;
|
| 165 |
+
right: 0;
|
| 166 |
+
bottom: 0;
|
| 167 |
+
background:
|
| 168 |
+
radial-gradient(circle at 20% 80%, rgba(59, 130, 246, 0.1) 0%, transparent 50%),
|
| 169 |
+
radial-gradient(circle at 80% 20%, rgba(139, 92, 246, 0.1) 0%, transparent 50%),
|
| 170 |
+
radial-gradient(circle at 40% 40%, rgba(18, 194, 233, 0.08) 0%, transparent 50%);
|
| 171 |
+
animation: techPulse 8s ease-in-out infinite;
|
| 172 |
+
}
|
| 173 |
+
|
| 174 |
+
.gradio-container .panel,
|
| 175 |
+
.gradio-container .block,
|
| 176 |
+
.gradio-container .form {
|
| 177 |
+
background: rgba(255, 255, 255, 0.8);
|
| 178 |
+
border: 1px solid rgba(59, 130, 246, 0.3);
|
| 179 |
+
border-radius: 10px;
|
| 180 |
+
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1);
|
| 181 |
+
}
|
| 182 |
+
|
| 183 |
+
.gradio-container * {
|
| 184 |
+
color: #1e293b;
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
.gradio-container label {
|
| 188 |
+
color: #334155;
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
.gradio-container .markdown {
|
| 192 |
+
color: #334155;
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
@keyframes techPulse {
|
| 200 |
+
0%, 100% { opacity: 0.5; }
|
| 201 |
+
50% { opacity: 0.8; }
|
| 202 |
+
}
|
| 203 |
+
|
| 204 |
+
/* Custom log with tech gradient */
|
| 205 |
+
.custom-log * {
|
| 206 |
+
font-style: italic;
|
| 207 |
+
font-size: 22px !important;
|
| 208 |
+
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
|
| 209 |
+
background-size: 400% 400%;
|
| 210 |
+
-webkit-background-clip: text;
|
| 211 |
+
background-clip: text;
|
| 212 |
+
font-weight: bold !important;
|
| 213 |
+
color: transparent !important;
|
| 214 |
+
text-align: center !important;
|
| 215 |
+
animation: techGradient 3s ease infinite;
|
| 216 |
+
}
|
| 217 |
+
|
| 218 |
+
@keyframes techGradient {
|
| 219 |
+
0% { background-position: 0% 50%; }
|
| 220 |
+
50% { background-position: 100% 50%; }
|
| 221 |
+
100% { background-position: 0% 50%; }
|
| 222 |
+
}
|
| 223 |
+
|
| 224 |
+
@keyframes metricPulse {
|
| 225 |
+
0%, 100% { background-position: 0% 50%; }
|
| 226 |
+
50% { background-position: 100% 50%; }
|
| 227 |
+
}
|
| 228 |
+
|
| 229 |
+
@keyframes pointcloudPulse {
|
| 230 |
+
0%, 100% { background-position: 0% 50%; }
|
| 231 |
+
50% { background-position: 100% 50%; }
|
| 232 |
+
}
|
| 233 |
+
|
| 234 |
+
@keyframes camerasPulse {
|
| 235 |
+
0%, 100% { background-position: 0% 50%; }
|
| 236 |
+
50% { background-position: 100% 50%; }
|
| 237 |
+
}
|
| 238 |
+
|
| 239 |
+
@keyframes gaussiansPulse {
|
| 240 |
+
0%, 100% { background-position: 0% 50%; }
|
| 241 |
+
50% { background-position: 100% 50%; }
|
| 242 |
+
}
|
| 243 |
+
|
| 244 |
+
/* Special colors for key terms - Global styles */
|
| 245 |
+
.metric-text {
|
| 246 |
+
background: linear-gradient(45deg, #ff6b6b, #ff8e53, #ff6b6b);
|
| 247 |
+
background-size: 200% 200%;
|
| 248 |
+
-webkit-background-clip: text;
|
| 249 |
+
background-clip: text;
|
| 250 |
+
color: transparent !important;
|
| 251 |
+
animation: metricPulse 2s ease-in-out infinite;
|
| 252 |
+
font-weight: 700;
|
| 253 |
+
text-shadow: 0 0 10px rgba(255, 107, 107, 0.5);
|
| 254 |
+
}
|
| 255 |
+
|
| 256 |
+
.pointcloud-text {
|
| 257 |
+
background: linear-gradient(45deg, #4ecdc4, #44a08d, #4ecdc4);
|
| 258 |
+
background-size: 200% 200%;
|
| 259 |
+
-webkit-background-clip: text;
|
| 260 |
+
background-clip: text;
|
| 261 |
+
color: transparent !important;
|
| 262 |
+
animation: pointcloudPulse 2.5s ease-in-out infinite;
|
| 263 |
+
font-weight: 700;
|
| 264 |
+
text-shadow: 0 0 10px rgba(78, 205, 196, 0.5);
|
| 265 |
+
}
|
| 266 |
+
|
| 267 |
+
.cameras-text {
|
| 268 |
+
background: linear-gradient(45deg, #667eea, #764ba2, #667eea);
|
| 269 |
+
background-size: 200% 200%;
|
| 270 |
+
-webkit-background-clip: text;
|
| 271 |
+
background-clip: text;
|
| 272 |
+
color: transparent !important;
|
| 273 |
+
animation: camerasPulse 3s ease-in-out infinite;
|
| 274 |
+
font-weight: 700;
|
| 275 |
+
text-shadow: 0 0 10px rgba(102, 126, 234, 0.5);
|
| 276 |
+
}
|
| 277 |
+
|
| 278 |
+
.gaussians-text {
|
| 279 |
+
background: linear-gradient(45deg, #f093fb, #f5576c, #f093fb);
|
| 280 |
+
background-size: 200% 200%;
|
| 281 |
+
-webkit-background-clip: text;
|
| 282 |
+
background-clip: text;
|
| 283 |
+
color: transparent !important;
|
| 284 |
+
animation: gaussiansPulse 2.2s ease-in-out infinite;
|
| 285 |
+
font-weight: 700;
|
| 286 |
+
text-shadow: 0 0 10px rgba(240, 147, 251, 0.5);
|
| 287 |
+
}
|
| 288 |
+
|
| 289 |
+
.example-log * {
|
| 290 |
+
font-style: italic;
|
| 291 |
+
font-size: 16px !important;
|
| 292 |
+
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
|
| 293 |
+
-webkit-background-clip: text;
|
| 294 |
+
background-clip: text;
|
| 295 |
+
color: transparent !important;
|
| 296 |
+
}
|
| 297 |
+
|
| 298 |
+
#my_radio .wrap {
|
| 299 |
+
display: flex;
|
| 300 |
+
flex-wrap: nowrap;
|
| 301 |
+
justify-content: center;
|
| 302 |
+
align-items: center;
|
| 303 |
+
}
|
| 304 |
+
|
| 305 |
+
#my_radio .wrap label {
|
| 306 |
+
display: flex;
|
| 307 |
+
width: 50%;
|
| 308 |
+
justify-content: center;
|
| 309 |
+
align-items: center;
|
| 310 |
+
margin: 0;
|
| 311 |
+
padding: 10px 0;
|
| 312 |
+
box-sizing: border-box;
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
/* Align navigation buttons with dropdown bottom */
|
| 316 |
+
.navigation-row {
|
| 317 |
+
display: flex !important;
|
| 318 |
+
align-items: flex-end !important;
|
| 319 |
+
gap: 8px !important;
|
| 320 |
+
}
|
| 321 |
+
|
| 322 |
+
.navigation-row > div:nth-child(1),
|
| 323 |
+
.navigation-row > div:nth-child(3) {
|
| 324 |
+
align-self: flex-end !important;
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
.navigation-row > div:nth-child(2) {
|
| 328 |
+
flex: 1 !important;
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
/* Make thumbnails clickable with pointer cursor */
|
| 332 |
+
.clickable-thumbnail img {
|
| 333 |
+
cursor: pointer !important;
|
| 334 |
+
}
|
| 335 |
+
|
| 336 |
+
.clickable-thumbnail:hover img {
|
| 337 |
+
cursor: pointer !important;
|
| 338 |
+
opacity: 0.8;
|
| 339 |
+
transition: opacity 0.3s ease;
|
| 340 |
+
}
|
| 341 |
+
|
| 342 |
+
/* Make thumbnail containers narrower horizontally */
|
| 343 |
+
.clickable-thumbnail {
|
| 344 |
+
padding: 5px 2px !important;
|
| 345 |
+
margin: 0 2px !important;
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
.clickable-thumbnail .image-container {
|
| 349 |
+
margin: 0 !important;
|
| 350 |
+
padding: 0 !important;
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
.scene-info {
|
| 354 |
+
text-align: center !important;
|
| 355 |
+
padding: 5px 2px !important;
|
| 356 |
+
margin: 0 !important;
|
| 357 |
+
}
|
| 358 |
+
"""
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def get_header_html(logo_base64=None):
|
| 362 |
+
"""
|
| 363 |
+
Generate the main header HTML with logo and title.
|
| 364 |
+
|
| 365 |
+
Args:
|
| 366 |
+
logo_base64 (str, optional): Base64 encoded logo image
|
| 367 |
+
|
| 368 |
+
Returns:
|
| 369 |
+
str: HTML string for the header
|
| 370 |
+
"""
|
| 371 |
+
return """
|
| 372 |
+
<div class="tech-bg" style="text-align: center; margin-bottom: 5px; padding: 40px 20px; border-radius: 15px; position: relative; overflow: hidden;">
|
| 373 |
+
<div style="position: relative; z-index: 2;">
|
| 374 |
+
<h1 style="margin: 0; font-size: 3.5em; font-weight: 700;
|
| 375 |
+
background: linear-gradient(135deg, #3b82f6, #8b5cf6);
|
| 376 |
+
background-size: 400% 400%;
|
| 377 |
+
-webkit-background-clip: text;
|
| 378 |
+
background-clip: text;
|
| 379 |
+
color: transparent;
|
| 380 |
+
animation: techGradient 3s ease infinite;
|
| 381 |
+
text-shadow: 0 0 30px rgba(59, 130, 246, 0.5);
|
| 382 |
+
letter-spacing: 2px;">
|
| 383 |
+
Depth Anything 3
|
| 384 |
+
</h1>
|
| 385 |
+
<p style="margin: 15px 0 0 0; font-size: 2.16em; font-weight: 300;" class="header-subtitle">
|
| 386 |
+
Recovering the Visual Space from Any Views
|
| 387 |
+
</p>
|
| 388 |
+
<div style="margin-top: 20px;">
|
| 389 |
+
<!-- Revert buttons to original inline styles -->
|
| 390 |
+
<a href="https://depth-anything-3.github.io" target="_blank" class="link-btn">
|
| 391 |
+
<i class="fas fa-globe" style="margin-right: 8px;"></i> Project Page
|
| 392 |
+
</a>
|
| 393 |
+
<a href="https://arxiv.org/abs/2406.09414" target="_blank" class="link-btn">
|
| 394 |
+
<i class="fas fa-file-pdf" style="margin-right: 8px;"></i> Paper
|
| 395 |
+
</a>
|
| 396 |
+
<a href="https://github.com/ByteDance-Seed/Depth-Anything-3" target="_blank" class="link-btn">
|
| 397 |
+
<i class="fab fa-github" style="margin-right: 8px;"></i> Code
|
| 398 |
+
</a>
|
| 399 |
+
</div>
|
| 400 |
+
</div>
|
| 401 |
+
</div>
|
| 402 |
+
|
| 403 |
+
<style>
|
| 404 |
+
/* Ensure tech-bg class is properly applied in dark mode */
|
| 405 |
+
@media (prefers-color-scheme: dark) {
|
| 406 |
+
.header-subtitle {
|
| 407 |
+
color: #cbd5e1;
|
| 408 |
+
}
|
| 409 |
+
/* Increase priority to ensure background color is properly applied */
|
| 410 |
+
.tech-bg {
|
| 411 |
+
background: linear-gradient(135deg, #0f172a, #1e293b) !important;
|
| 412 |
+
}
|
| 413 |
+
}
|
| 414 |
+
|
| 415 |
+
@media (prefers-color-scheme: light) {
|
| 416 |
+
.header-subtitle {
|
| 417 |
+
color: #475569;
|
| 418 |
+
}
|
| 419 |
+
/* Also add explicit background color for light mode */
|
| 420 |
+
.tech-bg {
|
| 421 |
+
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%) !important;
|
| 422 |
+
}
|
| 423 |
+
}
|
| 424 |
+
</style>
|
| 425 |
+
"""
|
| 426 |
+
|
| 427 |
+
|
| 428 |
+
def get_description_html():
|
| 429 |
+
"""
|
| 430 |
+
Generate the main description and getting started HTML.
|
| 431 |
+
|
| 432 |
+
Returns:
|
| 433 |
+
str: HTML string for the description
|
| 434 |
+
"""
|
| 435 |
+
return """
|
| 436 |
+
<div class="description-container" style="padding: 25px; border-radius: 15px; margin: 0 0 20px 0;">
|
| 437 |
+
<h2 class="description-title" style="margin-top: 0; font-size: 1.6em; text-align: center;">
|
| 438 |
+
<i class="fas fa-bullseye fa-color-red" style="margin-right: 8px;"></i> What This Demo Does
|
| 439 |
+
</h2>
|
| 440 |
+
<div class="description-content" style="padding: 20px; border-radius: 10px; margin: 15px 0; text-align: center;">
|
| 441 |
+
<p class="description-main" style="line-height: 1.6; margin: 0; font-size: 1.45em;">
|
| 442 |
+
<strong>Upload images or videos</strong> → <strong>Get <span class="metric-text">Metric</span> <span class="pointcloud-text">Point Clouds</span>, <span class="cameras-text">Cameras</span> and <span class="gaussians-text">Novel Views</span></strong> → <strong>Explore in 3D</strong>
|
| 443 |
+
</p>
|
| 444 |
+
</div>
|
| 445 |
+
|
| 446 |
+
<div style="text-align: center; margin-top: 15px;">
|
| 447 |
+
<p class="description-tip" style="font-style: italic; margin: 0;">
|
| 448 |
+
<i class="fas fa-lightbulb fa-color-yellow" style="margin-right: 8px;"></i> <strong>Tip:</strong> Landscape-oriented images or videos are preferred for best 3D recovering.
|
| 449 |
+
</p>
|
| 450 |
+
</div>
|
| 451 |
+
</div>
|
| 452 |
+
|
| 453 |
+
<style>
|
| 454 |
+
@media (prefers-color-scheme: dark) {
|
| 455 |
+
.description-container {
|
| 456 |
+
background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
|
| 457 |
+
border: 1px solid rgba(59, 130, 246, 0.2);
|
| 458 |
+
}
|
| 459 |
+
.description-title { color: #3b82f6; }
|
| 460 |
+
.description-content { background: rgba(0, 0, 0, 0.3); }
|
| 461 |
+
.description-main { color: #e0e0e0; }
|
| 462 |
+
.description-text { color: #cbd5e1; }
|
| 463 |
+
.description-tip { color: #cbd5e1; }
|
| 464 |
+
}
|
| 465 |
+
|
| 466 |
+
@media (prefers-color-scheme: light) {
|
| 467 |
+
.description-container {
|
| 468 |
+
background: linear-gradient(135deg, rgba(59, 130, 246, 0.05) 0%, rgba(139, 92, 246, 0.05) 100%);
|
| 469 |
+
border: 1px solid rgba(59, 130, 246, 0.3);
|
| 470 |
+
}
|
| 471 |
+
.description-title { color: #3b82f6; }
|
| 472 |
+
.description-content { background: transparent; }
|
| 473 |
+
.description-main { color: #1e293b; }
|
| 474 |
+
.description-text { color: #475569; }
|
| 475 |
+
.description-tip { color: #475569; }
|
| 476 |
+
}
|
| 477 |
+
</style>
|
| 478 |
+
"""
|
| 479 |
+
|
| 480 |
+
|
| 481 |
+
def get_acknowledgements_html():
|
| 482 |
+
"""
|
| 483 |
+
Generate the acknowledgements section HTML.
|
| 484 |
+
|
| 485 |
+
Returns:
|
| 486 |
+
str: HTML string for the acknowledgements
|
| 487 |
+
"""
|
| 488 |
+
return """
|
| 489 |
+
<div style="background: linear-gradient(135deg, rgba(59, 130, 246, 0.1) 0%, rgba(139, 92, 246, 0.1) 100%);
|
| 490 |
+
padding: 25px; border-radius: 15px; margin: 20px 0; border: 1px solid rgba(59, 130, 246, 0.2);">
|
| 491 |
+
<h3 style="color: #3b82f6; margin-top: 0; text-align: center; font-size: 1.4em;">
|
| 492 |
+
<i class="fas fa-trophy fa-color-yellow" style="margin-right: 8px;"></i> Research Credits & Acknowledgments
|
| 493 |
+
</h3>
|
| 494 |
+
|
| 495 |
+
<div style="display: grid; grid-template-columns: 1fr 1fr; gap: 20px; margin: 15px 0;">
|
| 496 |
+
<!-- Original Research Section (Left) -->
|
| 497 |
+
<div style="text-align: center;">
|
| 498 |
+
<h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-flask fa-color-green" style="margin-right: 8px;"></i> Original Research</h4>
|
| 499 |
+
<p style="color: #e0e0e0; margin: 5px 0;">
|
| 500 |
+
<a href="https://depth-anything-3.github.io" target="_blank"
|
| 501 |
+
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
|
| 502 |
+
Depth Anything 3
|
| 503 |
+
</a>
|
| 504 |
+
</p>
|
| 505 |
+
</div>
|
| 506 |
+
|
| 507 |
+
<!-- Previous Versions Section (Right) -->
|
| 508 |
+
<div style="text-align: center;">
|
| 509 |
+
<h4 style="color: #8b5cf6; margin: 10px 0;"><i class="fas fa-history fa-color-blue" style="margin-right: 8px;"></i> Previous Versions</h4>
|
| 510 |
+
<div style="display: flex; flex-direction: row; gap: 15px; justify-content: center; align-items: center;">
|
| 511 |
+
<p style="color: #e0e0e0; margin: 0;">
|
| 512 |
+
<a href="https://huggingface.co/spaces/LiheYoung/Depth-Anything" target="_blank"
|
| 513 |
+
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
|
| 514 |
+
Depth-Anything
|
| 515 |
+
</a>
|
| 516 |
+
</p>
|
| 517 |
+
<span style="color: #e0e0e0;">•</span>
|
| 518 |
+
<p style="color: #e0e0e0; margin: 0;">
|
| 519 |
+
<a href="https://huggingface.co/spaces/depth-anything/Depth-Anything-V2" target="_blank"
|
| 520 |
+
style="color: #3b82f6; text-decoration: none; font-weight: 600;">
|
| 521 |
+
Depth-Anything-V2
|
| 522 |
+
</a>
|
| 523 |
+
</p>
|
| 524 |
+
</div>
|
| 525 |
+
</div>
|
| 526 |
+
</div>
|
| 527 |
+
|
| 528 |
+
<!-- HF Demo Adapted from - Centered at the bottom of the whole block -->
|
| 529 |
+
<div style="margin-top: 20px; padding-top: 15px; border-top: 1px solid rgba(59, 130, 246, 0.3); text-align: center;">
|
| 530 |
+
<p style="color: #a0a0a0; font-size: 0.9em; margin: 0;">
|
| 531 |
+
<i class="fas fa-code-branch fa-color-gray" style="margin-right: 5px;"></i> HF demo adapted from <a href="https://huggingface.co/spaces/facebook/map-anything" target="_blank" style="color: inherit; text-decoration: none;">Map Anything</a>
|
| 532 |
+
</p>
|
| 533 |
+
</div>
|
| 534 |
+
</div>
|
| 535 |
+
"""
|
| 536 |
+
|
| 537 |
+
|
| 538 |
+
def get_gradio_theme():
|
| 539 |
+
"""
|
| 540 |
+
Get the configured Gradio theme with adaptive tech colors.
|
| 541 |
+
|
| 542 |
+
Returns:
|
| 543 |
+
gr.themes.Base: Configured Gradio theme
|
| 544 |
+
"""
|
| 545 |
+
import gradio as gr
|
| 546 |
+
|
| 547 |
+
return gr.themes.Base(
|
| 548 |
+
primary_hue=gr.themes.Color(
|
| 549 |
+
c50="#eff6ff",
|
| 550 |
+
c100="#dbeafe",
|
| 551 |
+
c200="#bfdbfe",
|
| 552 |
+
c300="#93c5fd",
|
| 553 |
+
c400="#60a5fa",
|
| 554 |
+
c500="#3b82f6",
|
| 555 |
+
c600="#2563eb",
|
| 556 |
+
c700="#1d4ed8",
|
| 557 |
+
c800="#1e40af",
|
| 558 |
+
c900="#1e3a8a",
|
| 559 |
+
c950="#172554",
|
| 560 |
+
),
|
| 561 |
+
secondary_hue=gr.themes.Color(
|
| 562 |
+
c50="#f5f3ff",
|
| 563 |
+
c100="#ede9fe",
|
| 564 |
+
c200="#ddd6fe",
|
| 565 |
+
c300="#c4b5fd",
|
| 566 |
+
c400="#a78bfa",
|
| 567 |
+
c500="#8b5cf6",
|
| 568 |
+
c600="#7c3aed",
|
| 569 |
+
c700="#6d28d9",
|
| 570 |
+
c800="#5b21b6",
|
| 571 |
+
c900="#4c1d95",
|
| 572 |
+
c950="#2e1065",
|
| 573 |
+
),
|
| 574 |
+
neutral_hue=gr.themes.Color(
|
| 575 |
+
c50="#f8fafc",
|
| 576 |
+
c100="#f1f5f9",
|
| 577 |
+
c200="#e2e8f0",
|
| 578 |
+
c300="#cbd5e1",
|
| 579 |
+
c400="#94a3b8",
|
| 580 |
+
c500="#64748b",
|
| 581 |
+
c600="#475569",
|
| 582 |
+
c700="#334155",
|
| 583 |
+
c800="#1e293b",
|
| 584 |
+
c900="#0f172a",
|
| 585 |
+
c950="#020617",
|
| 586 |
+
),
|
| 587 |
+
)
|
| 588 |
+
|
| 589 |
+
|
| 590 |
+
# Measure tab instructions HTML
|
| 591 |
+
MEASURE_INSTRUCTIONS_HTML = """
|
| 592 |
+
### Click points on the image to compute distance.
|
| 593 |
+
> <i class="fas fa-triangle-exclamation fa-color-red" style="margin-right: 5px;"></i> Metric scale estimation is difficult on aerial/drone images.
|
| 594 |
+
"""
|
src/depth_anything_3/app/gradio_app.py
ADDED
|
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Refactored Gradio App for Depth Anything 3.
|
| 17 |
+
|
| 18 |
+
This is the main application file that orchestrates all components.
|
| 19 |
+
The original functionality has been split into modular components for better maintainability.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import argparse
|
| 23 |
+
import os
|
| 24 |
+
from typing import Any, Dict, List
|
| 25 |
+
import gradio as gr
|
| 26 |
+
|
| 27 |
+
from depth_anything_3.app.css_and_html import GRADIO_CSS, get_gradio_theme
|
| 28 |
+
from depth_anything_3.app.modules.event_handlers import EventHandlers
|
| 29 |
+
from depth_anything_3.app.modules.ui_components import UIComponents
|
| 30 |
+
|
| 31 |
+
# Set environment variables
|
| 32 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DepthAnything3App:
|
| 36 |
+
"""
|
| 37 |
+
Main application class for Depth Anything 3 Gradio app.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self, model_dir: str = None, workspace_dir: str = None, gallery_dir: str = None):
|
| 41 |
+
"""
|
| 42 |
+
Initialize the application.
|
| 43 |
+
|
| 44 |
+
Args:
|
| 45 |
+
model_dir: Path to the model directory
|
| 46 |
+
workspace_dir: Path to the workspace directory
|
| 47 |
+
gallery_dir: Path to the gallery directory
|
| 48 |
+
"""
|
| 49 |
+
self.model_dir = model_dir
|
| 50 |
+
self.workspace_dir = workspace_dir
|
| 51 |
+
self.gallery_dir = gallery_dir
|
| 52 |
+
|
| 53 |
+
# Set environment variables for directories
|
| 54 |
+
if self.model_dir:
|
| 55 |
+
os.environ["DA3_MODEL_DIR"] = self.model_dir
|
| 56 |
+
if self.workspace_dir:
|
| 57 |
+
os.environ["DA3_WORKSPACE_DIR"] = self.workspace_dir
|
| 58 |
+
if self.gallery_dir:
|
| 59 |
+
os.environ["DA3_GALLERY_DIR"] = self.gallery_dir
|
| 60 |
+
|
| 61 |
+
self.event_handlers = EventHandlers()
|
| 62 |
+
self.ui_components = UIComponents()
|
| 63 |
+
|
| 64 |
+
def cache_examples(
|
| 65 |
+
self,
|
| 66 |
+
show_cam: bool = True,
|
| 67 |
+
filter_black_bg: bool = False,
|
| 68 |
+
filter_white_bg: bool = False,
|
| 69 |
+
save_percentage: float = 20.0,
|
| 70 |
+
num_max_points: int = 1000,
|
| 71 |
+
cache_gs_tag: str = "",
|
| 72 |
+
gs_trj_mode: str = "smooth",
|
| 73 |
+
gs_video_quality: str = "low",
|
| 74 |
+
) -> None:
|
| 75 |
+
"""
|
| 76 |
+
Pre-cache all example scenes at startup.
|
| 77 |
+
|
| 78 |
+
Args:
|
| 79 |
+
show_cam: Whether to show camera in visualization
|
| 80 |
+
filter_black_bg: Whether to filter black background
|
| 81 |
+
filter_white_bg: Whether to filter white background
|
| 82 |
+
save_percentage: Filter percentage for point cloud
|
| 83 |
+
num_max_points: Maximum number of points
|
| 84 |
+
cache_gs_tag: Tag to match scene names for high-res+3DGS caching (e.g., "dl3dv")
|
| 85 |
+
gs_trj_mode: Trajectory mode for 3DGS
|
| 86 |
+
gs_video_quality: Video quality for 3DGS
|
| 87 |
+
"""
|
| 88 |
+
from depth_anything_3.app.modules.utils import get_scene_info
|
| 89 |
+
|
| 90 |
+
examples_dir = os.path.join(self.workspace_dir, "examples")
|
| 91 |
+
if not os.path.exists(examples_dir):
|
| 92 |
+
print(f"Examples directory not found: {examples_dir}")
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
scenes = get_scene_info(examples_dir)
|
| 96 |
+
if not scenes:
|
| 97 |
+
print("No example scenes found to cache.")
|
| 98 |
+
return
|
| 99 |
+
|
| 100 |
+
print(f"\n{'='*60}")
|
| 101 |
+
print(f"Caching {len(scenes)} example scenes...")
|
| 102 |
+
print(f"{'='*60}\n")
|
| 103 |
+
|
| 104 |
+
for i, scene in enumerate(scenes, 1):
|
| 105 |
+
scene_name = scene["name"]
|
| 106 |
+
|
| 107 |
+
# Check if scene name matches the gs tag for high-res+3DGS caching
|
| 108 |
+
use_high_res_gs = cache_gs_tag and cache_gs_tag.lower() in scene_name.lower()
|
| 109 |
+
|
| 110 |
+
if use_high_res_gs:
|
| 111 |
+
print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (HIGH-RES + 3DGS)")
|
| 112 |
+
print(f" - Number of images: {scene['num_images']}")
|
| 113 |
+
print(f" - Matched tag: '{cache_gs_tag}' - using high_res + 3DGS")
|
| 114 |
+
else:
|
| 115 |
+
print(f"[{i}/{len(scenes)}] Caching scene: {scene_name} (LOW-RES)")
|
| 116 |
+
print(f" - Number of images: {scene['num_images']}")
|
| 117 |
+
|
| 118 |
+
try:
|
| 119 |
+
# Load example scene
|
| 120 |
+
_, target_dir, _, _, _, _, _, _, _ = self.event_handlers.load_example_scene(
|
| 121 |
+
scene_name
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
if target_dir and target_dir != "None":
|
| 125 |
+
# Run reconstruction with appropriate settings
|
| 126 |
+
print(" - Running reconstruction...")
|
| 127 |
+
result = self.event_handlers.gradio_demo(
|
| 128 |
+
target_dir=target_dir,
|
| 129 |
+
show_cam=show_cam,
|
| 130 |
+
filter_black_bg=filter_black_bg,
|
| 131 |
+
filter_white_bg=filter_white_bg,
|
| 132 |
+
process_res_method="high_res" if use_high_res_gs else "low_res",
|
| 133 |
+
selected_first_frame="",
|
| 134 |
+
save_percentage=save_percentage,
|
| 135 |
+
num_max_points=num_max_points,
|
| 136 |
+
infer_gs=use_high_res_gs,
|
| 137 |
+
gs_trj_mode=gs_trj_mode,
|
| 138 |
+
gs_video_quality=gs_video_quality,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Check if successful
|
| 142 |
+
if result[0] is not None: # reconstruction_output
|
| 143 |
+
print(f" ✓ Scene '{scene_name}' cached successfully")
|
| 144 |
+
else:
|
| 145 |
+
print(f" ✗ Scene '{scene_name}' caching failed: {result[1]}")
|
| 146 |
+
else:
|
| 147 |
+
print(f" ✗ Scene '{scene_name}' loading failed")
|
| 148 |
+
|
| 149 |
+
except Exception as e:
|
| 150 |
+
print(f" ✗ Error caching scene '{scene_name}': {str(e)}")
|
| 151 |
+
|
| 152 |
+
print()
|
| 153 |
+
|
| 154 |
+
print("=" * 60)
|
| 155 |
+
print("Example scene caching completed!")
|
| 156 |
+
print("=" * 60 + "\n")
|
| 157 |
+
|
| 158 |
+
def create_app(self) -> gr.Blocks:
|
| 159 |
+
"""
|
| 160 |
+
Create and configure the Gradio application.
|
| 161 |
+
|
| 162 |
+
Returns:
|
| 163 |
+
Configured Gradio Blocks interface
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
# Initialize theme
|
| 167 |
+
def get_theme():
|
| 168 |
+
return get_gradio_theme()
|
| 169 |
+
|
| 170 |
+
with gr.Blocks(theme=get_theme(), css=GRADIO_CSS) as demo:
|
| 171 |
+
# State variables for the tabbed interface
|
| 172 |
+
is_example = gr.Textbox(label="is_example", visible=False, value="None")
|
| 173 |
+
processed_data_state = gr.State(value=None)
|
| 174 |
+
measure_points_state = gr.State(value=[])
|
| 175 |
+
selected_first_frame_state = gr.State(value="")
|
| 176 |
+
selected_image_index_state = gr.State(value=0) # Track selected image index
|
| 177 |
+
# current_view_index = gr.State(value=0) # noqa: F841 Track current view index
|
| 178 |
+
|
| 179 |
+
# Header and description
|
| 180 |
+
self.ui_components.create_header_section()
|
| 181 |
+
self.ui_components.create_description_section()
|
| 182 |
+
|
| 183 |
+
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
| 184 |
+
|
| 185 |
+
# Main content area
|
| 186 |
+
with gr.Row():
|
| 187 |
+
with gr.Column(scale=2):
|
| 188 |
+
# Upload section
|
| 189 |
+
(
|
| 190 |
+
input_video,
|
| 191 |
+
s_time_interval,
|
| 192 |
+
input_images,
|
| 193 |
+
image_gallery,
|
| 194 |
+
select_first_frame_btn,
|
| 195 |
+
) = self.ui_components.create_upload_section()
|
| 196 |
+
|
| 197 |
+
with gr.Column(scale=4):
|
| 198 |
+
with gr.Column():
|
| 199 |
+
# gr.Markdown("**Metric 3D Reconstruction (Point Cloud and Camera Poses)**")
|
| 200 |
+
# Reconstruction control section (buttons) - moved below tabs
|
| 201 |
+
|
| 202 |
+
log_output = gr.Markdown(
|
| 203 |
+
"Please upload a video or images, then click Reconstruct.",
|
| 204 |
+
elem_classes=["custom-log"],
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Tabbed interface
|
| 208 |
+
with gr.Tabs():
|
| 209 |
+
with gr.Tab("Point Cloud & Cameras"):
|
| 210 |
+
reconstruction_output = (
|
| 211 |
+
self.ui_components.create_3d_viewer_section()
|
| 212 |
+
)
|
| 213 |
+
|
| 214 |
+
with gr.Tab("Metric Depth"):
|
| 215 |
+
(
|
| 216 |
+
prev_measure_btn,
|
| 217 |
+
measure_view_selector,
|
| 218 |
+
next_measure_btn,
|
| 219 |
+
measure_image,
|
| 220 |
+
measure_depth_image,
|
| 221 |
+
measure_text,
|
| 222 |
+
) = self.ui_components.create_measure_section()
|
| 223 |
+
|
| 224 |
+
with gr.Tab("3DGS Rendered Novel Views"):
|
| 225 |
+
gs_video, gs_info = self.ui_components.create_nvs_video()
|
| 226 |
+
|
| 227 |
+
# Inference control section (before inference)
|
| 228 |
+
(process_res_method_dropdown, infer_gs) = (
|
| 229 |
+
self.ui_components.create_inference_control_section()
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
# Display control section - includes 3DGS options, buttons, and Visualization Options # noqa: E501
|
| 233 |
+
(
|
| 234 |
+
show_cam,
|
| 235 |
+
filter_black_bg,
|
| 236 |
+
filter_white_bg,
|
| 237 |
+
save_percentage,
|
| 238 |
+
num_max_points,
|
| 239 |
+
gs_trj_mode,
|
| 240 |
+
gs_video_quality,
|
| 241 |
+
submit_btn,
|
| 242 |
+
clear_btn,
|
| 243 |
+
) = self.ui_components.create_display_control_section()
|
| 244 |
+
|
| 245 |
+
# bind visibility of gs_trj_mode to infer_gs
|
| 246 |
+
infer_gs.change(
|
| 247 |
+
fn=lambda checked: (
|
| 248 |
+
gr.update(visible=checked),
|
| 249 |
+
gr.update(visible=checked),
|
| 250 |
+
gr.update(visible=checked),
|
| 251 |
+
gr.update(visible=(not checked)),
|
| 252 |
+
),
|
| 253 |
+
inputs=infer_gs,
|
| 254 |
+
outputs=[gs_trj_mode, gs_video_quality, gs_video, gs_info],
|
| 255 |
+
)
|
| 256 |
+
|
| 257 |
+
# Example scenes section
|
| 258 |
+
gr.Markdown("## Example Scenes")
|
| 259 |
+
|
| 260 |
+
scenes = self.ui_components.create_example_scenes_section()
|
| 261 |
+
scene_components = self.ui_components.create_example_scene_grid(scenes)
|
| 262 |
+
|
| 263 |
+
# Set up event handlers
|
| 264 |
+
self._setup_event_handlers(
|
| 265 |
+
demo,
|
| 266 |
+
is_example,
|
| 267 |
+
processed_data_state,
|
| 268 |
+
measure_points_state,
|
| 269 |
+
target_dir_output,
|
| 270 |
+
input_video,
|
| 271 |
+
input_images,
|
| 272 |
+
s_time_interval,
|
| 273 |
+
image_gallery,
|
| 274 |
+
reconstruction_output,
|
| 275 |
+
log_output,
|
| 276 |
+
show_cam,
|
| 277 |
+
filter_black_bg,
|
| 278 |
+
filter_white_bg,
|
| 279 |
+
process_res_method_dropdown,
|
| 280 |
+
save_percentage,
|
| 281 |
+
submit_btn,
|
| 282 |
+
clear_btn,
|
| 283 |
+
num_max_points,
|
| 284 |
+
infer_gs,
|
| 285 |
+
select_first_frame_btn,
|
| 286 |
+
selected_first_frame_state,
|
| 287 |
+
selected_image_index_state,
|
| 288 |
+
measure_view_selector,
|
| 289 |
+
measure_image,
|
| 290 |
+
measure_depth_image,
|
| 291 |
+
measure_text,
|
| 292 |
+
prev_measure_btn,
|
| 293 |
+
next_measure_btn,
|
| 294 |
+
scenes,
|
| 295 |
+
scene_components,
|
| 296 |
+
gs_video,
|
| 297 |
+
gs_info,
|
| 298 |
+
gs_trj_mode,
|
| 299 |
+
gs_video_quality,
|
| 300 |
+
)
|
| 301 |
+
|
| 302 |
+
# Acknowledgements
|
| 303 |
+
self.ui_components.create_acknowledgements_section()
|
| 304 |
+
|
| 305 |
+
return demo
|
| 306 |
+
|
| 307 |
+
def _setup_event_handlers(
|
| 308 |
+
self,
|
| 309 |
+
demo: gr.Blocks,
|
| 310 |
+
is_example: gr.Textbox,
|
| 311 |
+
processed_data_state: gr.State,
|
| 312 |
+
measure_points_state: gr.State,
|
| 313 |
+
target_dir_output: gr.Textbox,
|
| 314 |
+
input_video: gr.Video,
|
| 315 |
+
input_images: gr.File,
|
| 316 |
+
s_time_interval: gr.Slider,
|
| 317 |
+
image_gallery: gr.Gallery,
|
| 318 |
+
reconstruction_output: gr.Model3D,
|
| 319 |
+
log_output: gr.Markdown,
|
| 320 |
+
show_cam: gr.Checkbox,
|
| 321 |
+
filter_black_bg: gr.Checkbox,
|
| 322 |
+
filter_white_bg: gr.Checkbox,
|
| 323 |
+
process_res_method_dropdown: gr.Dropdown,
|
| 324 |
+
save_percentage: gr.Slider,
|
| 325 |
+
submit_btn: gr.Button,
|
| 326 |
+
clear_btn: gr.ClearButton,
|
| 327 |
+
num_max_points: gr.Slider,
|
| 328 |
+
infer_gs: gr.Checkbox,
|
| 329 |
+
select_first_frame_btn: gr.Button,
|
| 330 |
+
selected_first_frame_state: gr.State,
|
| 331 |
+
selected_image_index_state: gr.State,
|
| 332 |
+
measure_view_selector: gr.Dropdown,
|
| 333 |
+
measure_image: gr.Image,
|
| 334 |
+
measure_depth_image: gr.Image,
|
| 335 |
+
measure_text: gr.Markdown,
|
| 336 |
+
prev_measure_btn: gr.Button,
|
| 337 |
+
next_measure_btn: gr.Button,
|
| 338 |
+
scenes: List[Dict[str, Any]],
|
| 339 |
+
scene_components: List[gr.Image],
|
| 340 |
+
gs_video: gr.Video,
|
| 341 |
+
gs_info: gr.Markdown,
|
| 342 |
+
gs_trj_mode: gr.Dropdown,
|
| 343 |
+
gs_video_quality: gr.Dropdown,
|
| 344 |
+
) -> None:
|
| 345 |
+
"""
|
| 346 |
+
Set up all event handlers for the application.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
demo: Gradio Blocks interface
|
| 350 |
+
All other arguments: Gradio components to connect
|
| 351 |
+
"""
|
| 352 |
+
# Configure clear button
|
| 353 |
+
clear_btn.add(
|
| 354 |
+
[
|
| 355 |
+
input_video,
|
| 356 |
+
input_images,
|
| 357 |
+
reconstruction_output,
|
| 358 |
+
log_output,
|
| 359 |
+
target_dir_output,
|
| 360 |
+
image_gallery,
|
| 361 |
+
gs_video,
|
| 362 |
+
]
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Main reconstruction button
|
| 366 |
+
submit_btn.click(
|
| 367 |
+
fn=self.event_handlers.clear_fields, inputs=[], outputs=[reconstruction_output]
|
| 368 |
+
).then(fn=self.event_handlers.update_log, inputs=[], outputs=[log_output]).then(
|
| 369 |
+
fn=self.event_handlers.gradio_demo,
|
| 370 |
+
inputs=[
|
| 371 |
+
target_dir_output,
|
| 372 |
+
show_cam,
|
| 373 |
+
filter_black_bg,
|
| 374 |
+
filter_white_bg,
|
| 375 |
+
process_res_method_dropdown,
|
| 376 |
+
selected_first_frame_state,
|
| 377 |
+
save_percentage,
|
| 378 |
+
# pass num_max_points
|
| 379 |
+
num_max_points,
|
| 380 |
+
infer_gs,
|
| 381 |
+
gs_trj_mode,
|
| 382 |
+
gs_video_quality,
|
| 383 |
+
],
|
| 384 |
+
outputs=[
|
| 385 |
+
reconstruction_output,
|
| 386 |
+
log_output,
|
| 387 |
+
processed_data_state,
|
| 388 |
+
measure_image,
|
| 389 |
+
measure_depth_image,
|
| 390 |
+
measure_text,
|
| 391 |
+
measure_view_selector,
|
| 392 |
+
gs_video,
|
| 393 |
+
gs_video, # gs_video visibility
|
| 394 |
+
gs_info, # gs_info visibility
|
| 395 |
+
],
|
| 396 |
+
).then(
|
| 397 |
+
fn=lambda: "False",
|
| 398 |
+
inputs=[],
|
| 399 |
+
outputs=[is_example], # set is_example to "False"
|
| 400 |
+
)
|
| 401 |
+
|
| 402 |
+
# Real-time visualization updates
|
| 403 |
+
self._setup_visualization_handlers(
|
| 404 |
+
show_cam,
|
| 405 |
+
filter_black_bg,
|
| 406 |
+
filter_white_bg,
|
| 407 |
+
process_res_method_dropdown,
|
| 408 |
+
target_dir_output,
|
| 409 |
+
is_example,
|
| 410 |
+
reconstruction_output,
|
| 411 |
+
log_output,
|
| 412 |
+
)
|
| 413 |
+
|
| 414 |
+
# File upload handlers
|
| 415 |
+
input_video.change(
|
| 416 |
+
fn=self.event_handlers.handle_uploads,
|
| 417 |
+
inputs=[input_video, input_images, s_time_interval],
|
| 418 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 419 |
+
)
|
| 420 |
+
input_images.change(
|
| 421 |
+
fn=self.event_handlers.handle_uploads,
|
| 422 |
+
inputs=[input_video, input_images, s_time_interval],
|
| 423 |
+
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Image gallery click handler (for selecting first frame)
|
| 427 |
+
def handle_image_selection(evt: gr.SelectData):
|
| 428 |
+
if evt is None or evt.index is None:
|
| 429 |
+
return "No image selected", 0
|
| 430 |
+
selected_index = evt.index
|
| 431 |
+
return f"Selected image {selected_index} as potential first frame", selected_index
|
| 432 |
+
|
| 433 |
+
image_gallery.select(
|
| 434 |
+
fn=handle_image_selection,
|
| 435 |
+
outputs=[log_output, selected_image_index_state],
|
| 436 |
+
)
|
| 437 |
+
|
| 438 |
+
# Select first frame handler
|
| 439 |
+
select_first_frame_btn.click(
|
| 440 |
+
fn=self.event_handlers.select_first_frame,
|
| 441 |
+
inputs=[image_gallery, selected_image_index_state],
|
| 442 |
+
outputs=[image_gallery, log_output, selected_first_frame_state],
|
| 443 |
+
)
|
| 444 |
+
|
| 445 |
+
# Navigation handlers
|
| 446 |
+
self._setup_navigation_handlers(
|
| 447 |
+
prev_measure_btn,
|
| 448 |
+
next_measure_btn,
|
| 449 |
+
measure_view_selector,
|
| 450 |
+
measure_image,
|
| 451 |
+
measure_depth_image,
|
| 452 |
+
measure_points_state,
|
| 453 |
+
processed_data_state,
|
| 454 |
+
)
|
| 455 |
+
|
| 456 |
+
# Measurement handler
|
| 457 |
+
measure_image.select(
|
| 458 |
+
fn=self.event_handlers.measure,
|
| 459 |
+
inputs=[processed_data_state, measure_points_state, measure_view_selector],
|
| 460 |
+
outputs=[measure_image, measure_depth_image, measure_points_state, measure_text],
|
| 461 |
+
)
|
| 462 |
+
|
| 463 |
+
# Example scene handlers
|
| 464 |
+
self._setup_example_scene_handlers(
|
| 465 |
+
scenes,
|
| 466 |
+
scene_components,
|
| 467 |
+
reconstruction_output,
|
| 468 |
+
target_dir_output,
|
| 469 |
+
image_gallery,
|
| 470 |
+
log_output,
|
| 471 |
+
is_example,
|
| 472 |
+
processed_data_state,
|
| 473 |
+
measure_view_selector,
|
| 474 |
+
measure_image,
|
| 475 |
+
measure_depth_image,
|
| 476 |
+
gs_video,
|
| 477 |
+
gs_info,
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
def _setup_visualization_handlers(
|
| 481 |
+
self,
|
| 482 |
+
show_cam: gr.Checkbox,
|
| 483 |
+
filter_black_bg: gr.Checkbox,
|
| 484 |
+
filter_white_bg: gr.Checkbox,
|
| 485 |
+
process_res_method_dropdown: gr.Dropdown,
|
| 486 |
+
target_dir_output: gr.Textbox,
|
| 487 |
+
is_example: gr.Textbox,
|
| 488 |
+
reconstruction_output: gr.Model3D,
|
| 489 |
+
log_output: gr.Markdown,
|
| 490 |
+
) -> None:
|
| 491 |
+
"""Set up visualization update handlers."""
|
| 492 |
+
# Common inputs for visualization updates
|
| 493 |
+
viz_inputs = [
|
| 494 |
+
target_dir_output,
|
| 495 |
+
show_cam,
|
| 496 |
+
is_example,
|
| 497 |
+
filter_black_bg,
|
| 498 |
+
filter_white_bg,
|
| 499 |
+
process_res_method_dropdown,
|
| 500 |
+
]
|
| 501 |
+
|
| 502 |
+
# Set up change handlers for all visualization controls
|
| 503 |
+
for component in [show_cam, filter_black_bg, filter_white_bg]:
|
| 504 |
+
component.change(
|
| 505 |
+
fn=self.event_handlers.update_visualization,
|
| 506 |
+
inputs=viz_inputs,
|
| 507 |
+
outputs=[reconstruction_output, log_output],
|
| 508 |
+
)
|
| 509 |
+
|
| 510 |
+
def _setup_navigation_handlers(
|
| 511 |
+
self,
|
| 512 |
+
prev_measure_btn: gr.Button,
|
| 513 |
+
next_measure_btn: gr.Button,
|
| 514 |
+
measure_view_selector: gr.Dropdown,
|
| 515 |
+
measure_image: gr.Image,
|
| 516 |
+
measure_depth_image: gr.Image,
|
| 517 |
+
measure_points_state: gr.State,
|
| 518 |
+
processed_data_state: gr.State,
|
| 519 |
+
) -> None:
|
| 520 |
+
"""Set up navigation handlers for measure tab."""
|
| 521 |
+
# Measure tab navigation
|
| 522 |
+
prev_measure_btn.click(
|
| 523 |
+
fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
|
| 524 |
+
processed_data, current_selector, -1
|
| 525 |
+
),
|
| 526 |
+
inputs=[processed_data_state, measure_view_selector],
|
| 527 |
+
outputs=[
|
| 528 |
+
measure_view_selector,
|
| 529 |
+
measure_image,
|
| 530 |
+
measure_depth_image,
|
| 531 |
+
measure_points_state,
|
| 532 |
+
],
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
next_measure_btn.click(
|
| 536 |
+
fn=lambda processed_data, current_selector: self.event_handlers.navigate_measure_view(
|
| 537 |
+
processed_data, current_selector, 1
|
| 538 |
+
),
|
| 539 |
+
inputs=[processed_data_state, measure_view_selector],
|
| 540 |
+
outputs=[
|
| 541 |
+
measure_view_selector,
|
| 542 |
+
measure_image,
|
| 543 |
+
measure_depth_image,
|
| 544 |
+
measure_points_state,
|
| 545 |
+
],
|
| 546 |
+
)
|
| 547 |
+
|
| 548 |
+
measure_view_selector.change(
|
| 549 |
+
fn=lambda processed_data, selector_value: (
|
| 550 |
+
self.event_handlers.update_measure_view(
|
| 551 |
+
processed_data, int(selector_value.split()[1]) - 1
|
| 552 |
+
)
|
| 553 |
+
if selector_value
|
| 554 |
+
else (None, None, [])
|
| 555 |
+
),
|
| 556 |
+
inputs=[processed_data_state, measure_view_selector],
|
| 557 |
+
outputs=[measure_image, measure_depth_image, measure_points_state],
|
| 558 |
+
)
|
| 559 |
+
|
| 560 |
+
def _setup_example_scene_handlers(
|
| 561 |
+
self,
|
| 562 |
+
scenes: List[Dict[str, Any]],
|
| 563 |
+
scene_components: List[gr.Image],
|
| 564 |
+
reconstruction_output: gr.Model3D,
|
| 565 |
+
target_dir_output: gr.Textbox,
|
| 566 |
+
image_gallery: gr.Gallery,
|
| 567 |
+
log_output: gr.Markdown,
|
| 568 |
+
is_example: gr.Textbox,
|
| 569 |
+
processed_data_state: gr.State,
|
| 570 |
+
measure_view_selector: gr.Dropdown,
|
| 571 |
+
measure_image: gr.Image,
|
| 572 |
+
measure_depth_image: gr.Image,
|
| 573 |
+
gs_video: gr.Video,
|
| 574 |
+
gs_info: gr.Markdown,
|
| 575 |
+
) -> None:
|
| 576 |
+
"""Set up example scene handlers."""
|
| 577 |
+
|
| 578 |
+
def load_and_update_measure(name):
|
| 579 |
+
result = self.event_handlers.load_example_scene(name)
|
| 580 |
+
# result = (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
|
| 581 |
+
|
| 582 |
+
# Update measure view if processed_data is available
|
| 583 |
+
measure_img = None
|
| 584 |
+
measure_depth = None
|
| 585 |
+
if result[4] is not None: # processed_data exists
|
| 586 |
+
measure_img, measure_depth, _ = (
|
| 587 |
+
self.event_handlers.visualization_handler.update_measure_view(result[4], 0)
|
| 588 |
+
)
|
| 589 |
+
|
| 590 |
+
return result + ("True", measure_img, measure_depth)
|
| 591 |
+
|
| 592 |
+
for i, scene in enumerate(scenes):
|
| 593 |
+
if i < len(scene_components):
|
| 594 |
+
scene_components[i].select(
|
| 595 |
+
fn=lambda name=scene["name"]: load_and_update_measure(name),
|
| 596 |
+
outputs=[
|
| 597 |
+
reconstruction_output,
|
| 598 |
+
target_dir_output,
|
| 599 |
+
image_gallery,
|
| 600 |
+
log_output,
|
| 601 |
+
processed_data_state,
|
| 602 |
+
measure_view_selector,
|
| 603 |
+
gs_video,
|
| 604 |
+
gs_video, # gs_video_visibility
|
| 605 |
+
gs_info, # gs_info_visibility
|
| 606 |
+
is_example,
|
| 607 |
+
measure_image,
|
| 608 |
+
measure_depth_image,
|
| 609 |
+
],
|
| 610 |
+
)
|
| 611 |
+
|
| 612 |
+
def launch(self, host: str = "127.0.0.1", port: int = 7860, **kwargs) -> None:
|
| 613 |
+
"""
|
| 614 |
+
Launch the application.
|
| 615 |
+
|
| 616 |
+
Args:
|
| 617 |
+
host: Host address to bind to
|
| 618 |
+
port: Port number to bind to
|
| 619 |
+
**kwargs: Additional arguments for demo.launch()
|
| 620 |
+
"""
|
| 621 |
+
demo = self.create_app()
|
| 622 |
+
demo.queue(max_size=20).launch(
|
| 623 |
+
show_error=True, ssr_mode=False, server_name=host, server_port=port, **kwargs
|
| 624 |
+
)
|
| 625 |
+
|
| 626 |
+
|
| 627 |
+
def main():
|
| 628 |
+
"""Main function to run the application."""
|
| 629 |
+
parser = argparse.ArgumentParser(
|
| 630 |
+
description="Depth Anything 3 Gradio Application",
|
| 631 |
+
formatter_class=argparse.RawDescriptionHelpFormatter,
|
| 632 |
+
epilog="""
|
| 633 |
+
Examples:
|
| 634 |
+
# Basic usage
|
| 635 |
+
python gradio_app.py --help
|
| 636 |
+
python gradio_app.py --host 0.0.0.0 --port 8080
|
| 637 |
+
python gradio_app.py --model-dir /path/to/model --workspace-dir /path/to/workspace
|
| 638 |
+
|
| 639 |
+
# Cache examples at startup (all low-res)
|
| 640 |
+
python gradio_app.py --cache-examples
|
| 641 |
+
|
| 642 |
+
# Cache with selective high-res+3DGS for scenes matching tag
|
| 643 |
+
python gradio_app.py --cache-examples --cache-gs-tag dl3dv
|
| 644 |
+
# This will use high-res + 3DGS for scenes containing "dl3dv" in their name,
|
| 645 |
+
# and low-res only for other scenes
|
| 646 |
+
""",
|
| 647 |
+
)
|
| 648 |
+
|
| 649 |
+
# Server configuration
|
| 650 |
+
parser.add_argument(
|
| 651 |
+
"--host", default="127.0.0.1", help="Host address to bind to (default: 127.0.0.1)"
|
| 652 |
+
)
|
| 653 |
+
parser.add_argument(
|
| 654 |
+
"--port", type=int, default=7860, help="Port number to bind to (default: 7860)"
|
| 655 |
+
)
|
| 656 |
+
|
| 657 |
+
# Directory configuration
|
| 658 |
+
parser.add_argument(
|
| 659 |
+
"--model-dir",
|
| 660 |
+
default="depth-anything/DA3NESTED-GIANT-LARGE",
|
| 661 |
+
help="Path to the model directory (default: depth-anything/DA3NESTED-GIANT-LARGE)",
|
| 662 |
+
)
|
| 663 |
+
parser.add_argument(
|
| 664 |
+
"--workspace-dir",
|
| 665 |
+
default="workspace/gradio", # noqa: E501
|
| 666 |
+
help="Path to the workspace directory (default: workspace/gradio)", # noqa: E501
|
| 667 |
+
)
|
| 668 |
+
parser.add_argument(
|
| 669 |
+
"--gallery-dir",
|
| 670 |
+
default="workspace/gallery",
|
| 671 |
+
help="Path to the gallery directory (default: workspace/gallery)", # noqa: E501
|
| 672 |
+
)
|
| 673 |
+
|
| 674 |
+
# Additional Gradio options
|
| 675 |
+
parser.add_argument("--share", action="store_true", help="Create a public link for the app")
|
| 676 |
+
parser.add_argument("--debug", action="store_true", help="Enable debug mode")
|
| 677 |
+
|
| 678 |
+
# Example caching options
|
| 679 |
+
parser.add_argument(
|
| 680 |
+
"--cache-examples",
|
| 681 |
+
action="store_true",
|
| 682 |
+
help="Pre-cache all example scenes at startup for faster loading",
|
| 683 |
+
)
|
| 684 |
+
parser.add_argument(
|
| 685 |
+
"--cache-gs-tag",
|
| 686 |
+
type=str,
|
| 687 |
+
default="",
|
| 688 |
+
help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.", # noqa: E501
|
| 689 |
+
)
|
| 690 |
+
|
| 691 |
+
args = parser.parse_args()
|
| 692 |
+
|
| 693 |
+
# Create directories if they don't exist
|
| 694 |
+
os.makedirs(args.workspace_dir, exist_ok=True)
|
| 695 |
+
os.makedirs(args.gallery_dir, exist_ok=True)
|
| 696 |
+
|
| 697 |
+
# Initialize and launch the application
|
| 698 |
+
app = DepthAnything3App(
|
| 699 |
+
model_dir=args.model_dir, workspace_dir=args.workspace_dir, gallery_dir=args.gallery_dir
|
| 700 |
+
)
|
| 701 |
+
|
| 702 |
+
# Prepare launch arguments
|
| 703 |
+
launch_kwargs = {"share": args.share, "debug": args.debug}
|
| 704 |
+
|
| 705 |
+
print("Starting Depth Anything 3 Gradio App...")
|
| 706 |
+
print(f"Host: {args.host}")
|
| 707 |
+
print(f"Port: {args.port}")
|
| 708 |
+
print(f"Model Directory: {args.model_dir}")
|
| 709 |
+
print(f"Workspace Directory: {args.workspace_dir}")
|
| 710 |
+
print(f"Gallery Directory: {args.gallery_dir}")
|
| 711 |
+
print(f"Share: {args.share}")
|
| 712 |
+
print(f"Debug: {args.debug}")
|
| 713 |
+
print(f"Cache Examples: {args.cache_examples}")
|
| 714 |
+
if args.cache_examples:
|
| 715 |
+
if args.cache_gs_tag:
|
| 716 |
+
print(
|
| 717 |
+
f"Cache GS Tag: '{args.cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)" # noqa: E501
|
| 718 |
+
) # noqa: E501
|
| 719 |
+
else:
|
| 720 |
+
print("Cache GS Tag: None (all scenes will use low-res only)")
|
| 721 |
+
|
| 722 |
+
# Pre-cache examples if requested
|
| 723 |
+
if args.cache_examples:
|
| 724 |
+
print("\n" + "=" * 60)
|
| 725 |
+
print("Pre-caching mode enabled")
|
| 726 |
+
if args.cache_gs_tag:
|
| 727 |
+
print(f"Scenes containing '{args.cache_gs_tag}' will use HIGH-RES + 3DGS")
|
| 728 |
+
print("Other scenes will use LOW-RES only")
|
| 729 |
+
else:
|
| 730 |
+
print("All scenes will use LOW-RES only")
|
| 731 |
+
print("=" * 60)
|
| 732 |
+
app.cache_examples(
|
| 733 |
+
show_cam=True,
|
| 734 |
+
filter_black_bg=False,
|
| 735 |
+
filter_white_bg=False,
|
| 736 |
+
save_percentage=5.0,
|
| 737 |
+
num_max_points=1000,
|
| 738 |
+
cache_gs_tag=args.cache_gs_tag,
|
| 739 |
+
gs_trj_mode="smooth",
|
| 740 |
+
gs_video_quality="low",
|
| 741 |
+
)
|
| 742 |
+
|
| 743 |
+
app.launch(host=args.host, port=args.port, **launch_kwargs)
|
| 744 |
+
|
| 745 |
+
|
| 746 |
+
if __name__ == "__main__":
|
| 747 |
+
main()
|
src/depth_anything_3/app/modules/__init__.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Modules package for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This package contains all the modular components for the Gradio application.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
from depth_anything_3.app.modules.event_handlers import EventHandlers
|
| 22 |
+
from depth_anything_3.app.modules.file_handlers import FileHandler
|
| 23 |
+
from depth_anything_3.app.modules.model_inference import ModelInference
|
| 24 |
+
from depth_anything_3.app.modules.ui_components import UIComponents
|
| 25 |
+
from depth_anything_3.app.modules.utils import (
|
| 26 |
+
cleanup_memory,
|
| 27 |
+
create_depth_visualization,
|
| 28 |
+
get_logo_base64,
|
| 29 |
+
get_scene_info,
|
| 30 |
+
save_to_gallery_func,
|
| 31 |
+
)
|
| 32 |
+
from depth_anything_3.app.modules.visualization import VisualizationHandler
|
| 33 |
+
|
| 34 |
+
__all__ = [
|
| 35 |
+
"ModelInference",
|
| 36 |
+
"FileHandler",
|
| 37 |
+
"VisualizationHandler",
|
| 38 |
+
"EventHandlers",
|
| 39 |
+
"UIComponents",
|
| 40 |
+
"create_depth_visualization",
|
| 41 |
+
"save_to_gallery_func",
|
| 42 |
+
"get_scene_info",
|
| 43 |
+
"cleanup_memory",
|
| 44 |
+
"get_logo_base64",
|
| 45 |
+
]
|
src/depth_anything_3/app/modules/event_handlers.py
ADDED
|
@@ -0,0 +1,629 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Event handling module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles all event callbacks and user interactions.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import time
|
| 23 |
+
from glob import glob
|
| 24 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 25 |
+
import gradio as gr
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from depth_anything_3.app.modules.file_handlers import FileHandler
|
| 30 |
+
from depth_anything_3.app.modules.model_inference import ModelInference
|
| 31 |
+
from depth_anything_3.app.modules.utils import cleanup_memory
|
| 32 |
+
from depth_anything_3.app.modules.visualization import VisualizationHandler
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class EventHandlers:
|
| 36 |
+
"""
|
| 37 |
+
Handles all event callbacks and user interactions for the Gradio app.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def __init__(self):
|
| 41 |
+
"""Initialize the event handlers."""
|
| 42 |
+
self.model_inference = ModelInference()
|
| 43 |
+
self.file_handler = FileHandler()
|
| 44 |
+
self.visualization_handler = VisualizationHandler()
|
| 45 |
+
|
| 46 |
+
def clear_fields(self) -> None:
|
| 47 |
+
"""
|
| 48 |
+
Clears the 3D viewer, the stored target_dir, and empties the gallery.
|
| 49 |
+
"""
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
def update_log(self) -> str:
|
| 53 |
+
"""
|
| 54 |
+
Display a quick log message while waiting.
|
| 55 |
+
"""
|
| 56 |
+
return "Loading and Reconstructing..."
|
| 57 |
+
|
| 58 |
+
def save_current_visualization(
|
| 59 |
+
self,
|
| 60 |
+
target_dir: str,
|
| 61 |
+
save_percentage: float,
|
| 62 |
+
show_cam: bool,
|
| 63 |
+
filter_black_bg: bool,
|
| 64 |
+
filter_white_bg: bool,
|
| 65 |
+
processed_data: Optional[Dict],
|
| 66 |
+
scene_name: str = "",
|
| 67 |
+
) -> str:
|
| 68 |
+
"""
|
| 69 |
+
Save current visualization results to gallery with specified save percentage.
|
| 70 |
+
|
| 71 |
+
Args:
|
| 72 |
+
target_dir: Directory containing results
|
| 73 |
+
save_percentage: Percentage of points to save (0-100)
|
| 74 |
+
show_cam: Whether to show cameras
|
| 75 |
+
filter_black_bg: Whether to filter black background
|
| 76 |
+
filter_white_bg: Whether to filter white background
|
| 77 |
+
processed_data: Processed data from reconstruction
|
| 78 |
+
|
| 79 |
+
Returns:
|
| 80 |
+
Status message
|
| 81 |
+
"""
|
| 82 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
| 83 |
+
return "No reconstruction available. Please run 'Reconstruct' first."
|
| 84 |
+
|
| 85 |
+
if processed_data is None:
|
| 86 |
+
return "No processed data available. Please run 'Reconstruct' first."
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
# Add debug information
|
| 90 |
+
print("[DEBUG] save_current_visualization called with:")
|
| 91 |
+
print(f" target_dir: {target_dir}")
|
| 92 |
+
print(f" save_percentage: {save_percentage}")
|
| 93 |
+
print(f" show_cam: {show_cam}")
|
| 94 |
+
print(f" filter_black_bg: {filter_black_bg}")
|
| 95 |
+
print(f" filter_white_bg: {filter_white_bg}")
|
| 96 |
+
print(f" processed_data: {processed_data is not None}")
|
| 97 |
+
|
| 98 |
+
# Import the gallery save function
|
| 99 |
+
# Create gallery name with user input or auto-generated
|
| 100 |
+
import datetime
|
| 101 |
+
|
| 102 |
+
from .utils import save_to_gallery_func
|
| 103 |
+
|
| 104 |
+
timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 105 |
+
if scene_name and scene_name.strip():
|
| 106 |
+
gallery_name = f"{scene_name.strip()}_{timestamp}_pct{save_percentage:.0f}"
|
| 107 |
+
else:
|
| 108 |
+
gallery_name = f"save_{timestamp}_pct{save_percentage:.0f}"
|
| 109 |
+
|
| 110 |
+
print(f"[DEBUG] Saving to gallery with name: {gallery_name}")
|
| 111 |
+
|
| 112 |
+
# Save entire process folder to gallery
|
| 113 |
+
success, message = save_to_gallery_func(
|
| 114 |
+
target_dir=target_dir, processed_data=processed_data, gallery_name=gallery_name
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
if success:
|
| 118 |
+
print(f"[DEBUG] Gallery save completed successfully: {message}")
|
| 119 |
+
return (
|
| 120 |
+
"Successfully saved to gallery!\n"
|
| 121 |
+
f"Gallery name: {gallery_name}\n"
|
| 122 |
+
f"Save percentage: {save_percentage}%\n"
|
| 123 |
+
f"Show cameras: {show_cam}\n"
|
| 124 |
+
f"Filter black bg: {filter_black_bg}\n"
|
| 125 |
+
f"Filter white bg: {filter_white_bg}\n\n"
|
| 126 |
+
f"{message}"
|
| 127 |
+
)
|
| 128 |
+
else:
|
| 129 |
+
print(f"[DEBUG] Gallery save failed: {message}")
|
| 130 |
+
return f"Failed to save to gallery: {message}"
|
| 131 |
+
|
| 132 |
+
except Exception as e:
|
| 133 |
+
return f"Error saving visualization: {str(e)}"
|
| 134 |
+
|
| 135 |
+
def gradio_demo(
|
| 136 |
+
self,
|
| 137 |
+
target_dir: str,
|
| 138 |
+
show_cam: bool = True,
|
| 139 |
+
filter_black_bg: bool = False,
|
| 140 |
+
filter_white_bg: bool = False,
|
| 141 |
+
process_res_method: str = "upper_bound_resize",
|
| 142 |
+
selected_first_frame: str = "",
|
| 143 |
+
save_percentage: float = 30.0,
|
| 144 |
+
num_max_points: int = 1_000_000,
|
| 145 |
+
infer_gs: bool = False,
|
| 146 |
+
gs_trj_mode: str = "extend",
|
| 147 |
+
gs_video_quality: str = "high",
|
| 148 |
+
) -> Tuple[
|
| 149 |
+
Optional[str],
|
| 150 |
+
str,
|
| 151 |
+
Optional[Dict],
|
| 152 |
+
Optional[np.ndarray],
|
| 153 |
+
Optional[np.ndarray],
|
| 154 |
+
str,
|
| 155 |
+
gr.Dropdown,
|
| 156 |
+
Optional[str], # gs video path
|
| 157 |
+
gr.update, # gs video visibility update
|
| 158 |
+
gr.update, # gs info visibility update
|
| 159 |
+
]:
|
| 160 |
+
"""
|
| 161 |
+
Perform reconstruction using the already-created target_dir/images.
|
| 162 |
+
|
| 163 |
+
Args:
|
| 164 |
+
target_dir: Directory containing images
|
| 165 |
+
show_cam: Whether to show camera
|
| 166 |
+
filter_black_bg: Whether to filter black background
|
| 167 |
+
filter_white_bg: Whether to filter white background
|
| 168 |
+
process_res_method: Method for resizing input images
|
| 169 |
+
selected_first_frame: Selected first frame filename
|
| 170 |
+
infer_gs: Whether to infer 3D Gaussian Splatting
|
| 171 |
+
|
| 172 |
+
Returns:
|
| 173 |
+
Tuple of reconstruction results
|
| 174 |
+
"""
|
| 175 |
+
if not os.path.isdir(target_dir) or target_dir == "None":
|
| 176 |
+
return (
|
| 177 |
+
None,
|
| 178 |
+
"No valid target directory found. Please upload first.",
|
| 179 |
+
None,
|
| 180 |
+
None,
|
| 181 |
+
None,
|
| 182 |
+
"",
|
| 183 |
+
None,
|
| 184 |
+
None,
|
| 185 |
+
gr.update(visible=False), # gs_video
|
| 186 |
+
gr.update(visible=True), # gs_info
|
| 187 |
+
)
|
| 188 |
+
|
| 189 |
+
start_time = time.time()
|
| 190 |
+
cleanup_memory()
|
| 191 |
+
|
| 192 |
+
# Get image files for logging
|
| 193 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 194 |
+
all_files = (
|
| 195 |
+
sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
print("Running DepthAnything3 model...")
|
| 199 |
+
print(f"Selected first frame: {selected_first_frame}")
|
| 200 |
+
|
| 201 |
+
# Validate selected_first_frame against current image list
|
| 202 |
+
if selected_first_frame and target_dir_images:
|
| 203 |
+
current_files = (
|
| 204 |
+
sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
|
| 205 |
+
)
|
| 206 |
+
if selected_first_frame not in current_files:
|
| 207 |
+
print(
|
| 208 |
+
f"Selected first frame '{selected_first_frame}' not found in "
|
| 209 |
+
"current images. Using default order."
|
| 210 |
+
)
|
| 211 |
+
selected_first_frame = "" # Reset to use default order
|
| 212 |
+
|
| 213 |
+
with torch.no_grad():
|
| 214 |
+
prediction, processed_data = self.model_inference.run_inference(
|
| 215 |
+
target_dir,
|
| 216 |
+
process_res_method=process_res_method,
|
| 217 |
+
show_camera=show_cam,
|
| 218 |
+
selected_first_frame=selected_first_frame,
|
| 219 |
+
save_percentage=save_percentage,
|
| 220 |
+
num_max_points=int(num_max_points * 1000), # Convert K to actual count
|
| 221 |
+
infer_gs=infer_gs,
|
| 222 |
+
gs_trj_mode=gs_trj_mode,
|
| 223 |
+
gs_video_quality=gs_video_quality,
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
# The GLB file is already generated by the API
|
| 227 |
+
glbfile = os.path.join(target_dir, "scene.glb")
|
| 228 |
+
|
| 229 |
+
# Handle 3DGS video based on infer_gs flag
|
| 230 |
+
gsvideo_path = None
|
| 231 |
+
gs_video_visible = False
|
| 232 |
+
gs_info_visible = True
|
| 233 |
+
|
| 234 |
+
if infer_gs:
|
| 235 |
+
try:
|
| 236 |
+
gsvideo_path = sorted(glob(os.path.join(target_dir, "gs_video", "*.mp4")))[-1]
|
| 237 |
+
gs_video_visible = True
|
| 238 |
+
gs_info_visible = False
|
| 239 |
+
except IndexError:
|
| 240 |
+
gsvideo_path = None
|
| 241 |
+
print("3DGS video not found, but infer_gs was enabled")
|
| 242 |
+
|
| 243 |
+
# Cleanup
|
| 244 |
+
cleanup_memory()
|
| 245 |
+
|
| 246 |
+
end_time = time.time()
|
| 247 |
+
print(f"Total time: {end_time - start_time:.2f} seconds")
|
| 248 |
+
log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."
|
| 249 |
+
|
| 250 |
+
# Populate visualization tabs with processed data
|
| 251 |
+
depth_vis, measure_img, measure_depth_vis, measure_pts = (
|
| 252 |
+
self.visualization_handler.populate_visualization_tabs(processed_data)
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
# Update view selectors based on available views
|
| 256 |
+
depth_selector, measure_selector = self.visualization_handler.update_view_selectors(
|
| 257 |
+
processed_data
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
return (
|
| 261 |
+
glbfile,
|
| 262 |
+
log_msg,
|
| 263 |
+
processed_data,
|
| 264 |
+
measure_img, # measure_image
|
| 265 |
+
measure_depth_vis, # measure_depth_image
|
| 266 |
+
"", # measure_text (empty initially)
|
| 267 |
+
measure_selector, # measure_view_selector
|
| 268 |
+
gsvideo_path,
|
| 269 |
+
gr.update(visible=gs_video_visible), # gs_video visibility
|
| 270 |
+
gr.update(visible=gs_info_visible), # gs_info visibility
|
| 271 |
+
)
|
| 272 |
+
|
| 273 |
+
def update_visualization(
|
| 274 |
+
self,
|
| 275 |
+
target_dir: str,
|
| 276 |
+
show_cam: bool,
|
| 277 |
+
is_example: str,
|
| 278 |
+
filter_black_bg: bool = False,
|
| 279 |
+
filter_white_bg: bool = False,
|
| 280 |
+
process_res_method: str = "upper_bound_resize",
|
| 281 |
+
) -> Tuple[gr.update, str]:
|
| 282 |
+
"""
|
| 283 |
+
Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
|
| 284 |
+
and return it for the 3D viewer.
|
| 285 |
+
|
| 286 |
+
Args:
|
| 287 |
+
target_dir: Directory containing results
|
| 288 |
+
show_cam: Whether to show camera
|
| 289 |
+
is_example: Whether this is an example scene
|
| 290 |
+
filter_black_bg: Whether to filter black background
|
| 291 |
+
filter_white_bg: Whether to filter white background
|
| 292 |
+
process_res_method: Method for resizing input images
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
Tuple of (glb_file, log_message)
|
| 296 |
+
"""
|
| 297 |
+
if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
|
| 298 |
+
return (
|
| 299 |
+
gr.update(),
|
| 300 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
# Check if GLB exists (could be cached example or reconstructed scene)
|
| 304 |
+
glbfile = os.path.join(target_dir, "scene.glb")
|
| 305 |
+
if os.path.exists(glbfile):
|
| 306 |
+
return (
|
| 307 |
+
glbfile,
|
| 308 |
+
(
|
| 309 |
+
"Visualization loaded from cache."
|
| 310 |
+
if is_example == "True"
|
| 311 |
+
else "Visualization updated."
|
| 312 |
+
),
|
| 313 |
+
)
|
| 314 |
+
|
| 315 |
+
# If no GLB but it's an example that hasn't been reconstructed yet
|
| 316 |
+
if is_example == "True":
|
| 317 |
+
return (
|
| 318 |
+
gr.update(),
|
| 319 |
+
"No reconstruction available. Please click the Reconstruct button first.",
|
| 320 |
+
)
|
| 321 |
+
|
| 322 |
+
# For non-examples, check predictions.npz
|
| 323 |
+
predictions_path = os.path.join(target_dir, "predictions.npz")
|
| 324 |
+
if not os.path.exists(predictions_path):
|
| 325 |
+
error_message = (
|
| 326 |
+
f"No reconstruction available at {predictions_path}. "
|
| 327 |
+
"Please run 'Reconstruct' first."
|
| 328 |
+
)
|
| 329 |
+
return gr.update(), error_message
|
| 330 |
+
|
| 331 |
+
loaded = np.load(predictions_path, allow_pickle=True)
|
| 332 |
+
predictions = {key: loaded[key] for key in loaded.keys()} # noqa: F841
|
| 333 |
+
|
| 334 |
+
return (
|
| 335 |
+
glbfile,
|
| 336 |
+
"Visualization updated.",
|
| 337 |
+
)
|
| 338 |
+
|
| 339 |
+
def handle_uploads(
|
| 340 |
+
self,
|
| 341 |
+
input_video: Optional[str],
|
| 342 |
+
input_images: Optional[List],
|
| 343 |
+
s_time_interval: float = 10.0,
|
| 344 |
+
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
|
| 345 |
+
"""
|
| 346 |
+
Handle file uploads and update gallery.
|
| 347 |
+
|
| 348 |
+
Args:
|
| 349 |
+
input_video: Path to input video file
|
| 350 |
+
input_images: List of input image files
|
| 351 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 352 |
+
|
| 353 |
+
Returns:
|
| 354 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
|
| 355 |
+
"""
|
| 356 |
+
return self.file_handler.update_gallery_on_upload(
|
| 357 |
+
input_video, input_images, s_time_interval
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
def load_example_scene(self, scene_name: str, examples_dir: str = None) -> Tuple[
|
| 361 |
+
Optional[str],
|
| 362 |
+
Optional[str],
|
| 363 |
+
Optional[List],
|
| 364 |
+
str,
|
| 365 |
+
Optional[Dict],
|
| 366 |
+
gr.Dropdown,
|
| 367 |
+
Optional[str],
|
| 368 |
+
gr.update,
|
| 369 |
+
gr.update,
|
| 370 |
+
]:
|
| 371 |
+
"""
|
| 372 |
+
Load a scene from examples directory.
|
| 373 |
+
|
| 374 |
+
Args:
|
| 375 |
+
scene_name: Name of the scene to load
|
| 376 |
+
examples_dir: Path to examples directory (if None, uses workspace_dir/examples)
|
| 377 |
+
|
| 378 |
+
Returns:
|
| 379 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message, processed_data, measure_view_selector, gs_video, gs_video_vis, gs_info_vis) # noqa: E501
|
| 380 |
+
"""
|
| 381 |
+
if examples_dir is None:
|
| 382 |
+
# Get workspace directory from environment variable
|
| 383 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 384 |
+
examples_dir = os.path.join(workspace_dir, "examples")
|
| 385 |
+
|
| 386 |
+
reconstruction_output, target_dir, image_paths, log_message = (
|
| 387 |
+
self.file_handler.load_example_scene(scene_name, examples_dir)
|
| 388 |
+
)
|
| 389 |
+
|
| 390 |
+
# Try to load cached processed data if available
|
| 391 |
+
processed_data = None
|
| 392 |
+
measure_view_selector = gr.Dropdown(choices=["View 1"], value="View 1")
|
| 393 |
+
gs_video_path = None
|
| 394 |
+
gs_video_visible = False
|
| 395 |
+
gs_info_visible = True
|
| 396 |
+
|
| 397 |
+
if target_dir and target_dir != "None":
|
| 398 |
+
predictions_path = os.path.join(target_dir, "predictions.npz")
|
| 399 |
+
if os.path.exists(predictions_path):
|
| 400 |
+
try:
|
| 401 |
+
# Load predictions from cache
|
| 402 |
+
loaded = np.load(predictions_path, allow_pickle=True)
|
| 403 |
+
predictions = {key: loaded[key] for key in loaded.keys()}
|
| 404 |
+
|
| 405 |
+
# Reconstruct processed_data structure
|
| 406 |
+
num_images = len(predictions.get("images", []))
|
| 407 |
+
processed_data = {}
|
| 408 |
+
|
| 409 |
+
for i in range(num_images):
|
| 410 |
+
processed_data[i] = {
|
| 411 |
+
"image": predictions["images"][i] if "images" in predictions else None,
|
| 412 |
+
"depth": predictions["depths"][i] if "depths" in predictions else None,
|
| 413 |
+
"depth_image": os.path.join(
|
| 414 |
+
target_dir, "depth_vis", f"{i:04d}.jpg" # Fixed: use .jpg not .png
|
| 415 |
+
),
|
| 416 |
+
"intrinsics": (
|
| 417 |
+
predictions["intrinsics"][i]
|
| 418 |
+
if "intrinsics" in predictions
|
| 419 |
+
and i < len(predictions["intrinsics"])
|
| 420 |
+
else None
|
| 421 |
+
),
|
| 422 |
+
"mask": None,
|
| 423 |
+
}
|
| 424 |
+
|
| 425 |
+
# Update measure view selector
|
| 426 |
+
choices = [f"View {i + 1}" for i in range(num_images)]
|
| 427 |
+
measure_view_selector = gr.Dropdown(choices=choices, value=choices[0])
|
| 428 |
+
|
| 429 |
+
except Exception as e:
|
| 430 |
+
print(f"Error loading cached data: {e}")
|
| 431 |
+
|
| 432 |
+
# Check for cached 3DGS video
|
| 433 |
+
gs_video_dir = os.path.join(target_dir, "gs_video")
|
| 434 |
+
if os.path.exists(gs_video_dir):
|
| 435 |
+
try:
|
| 436 |
+
from glob import glob
|
| 437 |
+
|
| 438 |
+
gs_videos = sorted(glob(os.path.join(gs_video_dir, "*.mp4")))
|
| 439 |
+
if gs_videos:
|
| 440 |
+
gs_video_path = gs_videos[-1]
|
| 441 |
+
gs_video_visible = True
|
| 442 |
+
gs_info_visible = False
|
| 443 |
+
print(f"Loaded cached 3DGS video: {gs_video_path}")
|
| 444 |
+
except Exception as e:
|
| 445 |
+
print(f"Error loading cached 3DGS video: {e}")
|
| 446 |
+
|
| 447 |
+
return (
|
| 448 |
+
reconstruction_output,
|
| 449 |
+
target_dir,
|
| 450 |
+
image_paths,
|
| 451 |
+
log_message,
|
| 452 |
+
processed_data,
|
| 453 |
+
measure_view_selector,
|
| 454 |
+
gs_video_path,
|
| 455 |
+
gr.update(visible=gs_video_visible),
|
| 456 |
+
gr.update(visible=gs_info_visible),
|
| 457 |
+
)
|
| 458 |
+
|
| 459 |
+
def navigate_depth_view(
|
| 460 |
+
self,
|
| 461 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 462 |
+
current_selector: str,
|
| 463 |
+
direction: int,
|
| 464 |
+
) -> Tuple[str, Optional[str]]:
|
| 465 |
+
"""
|
| 466 |
+
Navigate depth view.
|
| 467 |
+
|
| 468 |
+
Args:
|
| 469 |
+
processed_data: Processed data dictionary
|
| 470 |
+
current_selector: Current selector value
|
| 471 |
+
direction: Direction to navigate
|
| 472 |
+
|
| 473 |
+
Returns:
|
| 474 |
+
Tuple of (new_selector_value, depth_vis)
|
| 475 |
+
"""
|
| 476 |
+
return self.visualization_handler.navigate_depth_view(
|
| 477 |
+
processed_data, current_selector, direction
|
| 478 |
+
)
|
| 479 |
+
|
| 480 |
+
def update_depth_view(
|
| 481 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 482 |
+
) -> Optional[str]:
|
| 483 |
+
"""
|
| 484 |
+
Update depth view for a specific view index.
|
| 485 |
+
|
| 486 |
+
Args:
|
| 487 |
+
processed_data: Processed data dictionary
|
| 488 |
+
view_index: Index of the view to update
|
| 489 |
+
|
| 490 |
+
Returns:
|
| 491 |
+
Path to depth visualization image or None
|
| 492 |
+
"""
|
| 493 |
+
return self.visualization_handler.update_depth_view(processed_data, view_index)
|
| 494 |
+
|
| 495 |
+
def navigate_measure_view(
|
| 496 |
+
self,
|
| 497 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 498 |
+
current_selector: str,
|
| 499 |
+
direction: int,
|
| 500 |
+
) -> Tuple[str, Optional[np.ndarray], Optional[np.ndarray], List]:
|
| 501 |
+
"""
|
| 502 |
+
Navigate measure view.
|
| 503 |
+
|
| 504 |
+
Args:
|
| 505 |
+
processed_data: Processed data dictionary
|
| 506 |
+
current_selector: Current selector value
|
| 507 |
+
direction: Direction to navigate
|
| 508 |
+
|
| 509 |
+
Returns:
|
| 510 |
+
Tuple of (new_selector_value, measure_image, depth_right_half, measure_points)
|
| 511 |
+
"""
|
| 512 |
+
return self.visualization_handler.navigate_measure_view(
|
| 513 |
+
processed_data, current_selector, direction
|
| 514 |
+
)
|
| 515 |
+
|
| 516 |
+
def update_measure_view(
|
| 517 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 518 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
|
| 519 |
+
"""
|
| 520 |
+
Update measure view for a specific view index.
|
| 521 |
+
|
| 522 |
+
Args:
|
| 523 |
+
processed_data: Processed data dictionary
|
| 524 |
+
view_index: Index of the view to update
|
| 525 |
+
|
| 526 |
+
Returns:
|
| 527 |
+
Tuple of (measure_image, depth_right_half, measure_points)
|
| 528 |
+
"""
|
| 529 |
+
return self.visualization_handler.update_measure_view(processed_data, view_index)
|
| 530 |
+
|
| 531 |
+
def measure(
|
| 532 |
+
self,
|
| 533 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 534 |
+
measure_points: List,
|
| 535 |
+
current_view_selector: str,
|
| 536 |
+
event: gr.SelectData,
|
| 537 |
+
) -> List:
|
| 538 |
+
"""
|
| 539 |
+
Handle measurement on images.
|
| 540 |
+
|
| 541 |
+
Args:
|
| 542 |
+
processed_data: Processed data dictionary
|
| 543 |
+
measure_points: List of current measure points
|
| 544 |
+
current_view_selector: Current view selector value
|
| 545 |
+
event: Gradio select event
|
| 546 |
+
|
| 547 |
+
Returns:
|
| 548 |
+
List of [image, depth_right_half, measure_points, text]
|
| 549 |
+
"""
|
| 550 |
+
return self.visualization_handler.measure(
|
| 551 |
+
processed_data, measure_points, current_view_selector, event
|
| 552 |
+
)
|
| 553 |
+
|
| 554 |
+
def select_first_frame(
|
| 555 |
+
self, image_gallery: List, selected_index: int = 0
|
| 556 |
+
) -> Tuple[List, str, str]:
|
| 557 |
+
"""
|
| 558 |
+
Select the first frame from the image gallery.
|
| 559 |
+
|
| 560 |
+
Args:
|
| 561 |
+
image_gallery: List of images in the gallery
|
| 562 |
+
selected_index: Index of the selected image (default: 0)
|
| 563 |
+
|
| 564 |
+
Returns:
|
| 565 |
+
Tuple of (updated_image_gallery, log_message, selected_frame_path)
|
| 566 |
+
"""
|
| 567 |
+
try:
|
| 568 |
+
if not image_gallery or len(image_gallery) == 0:
|
| 569 |
+
return image_gallery, "No images available to select as first frame.", ""
|
| 570 |
+
|
| 571 |
+
# Handle None or invalid selected_index
|
| 572 |
+
if (
|
| 573 |
+
selected_index is None
|
| 574 |
+
or selected_index < 0
|
| 575 |
+
or selected_index >= len(image_gallery)
|
| 576 |
+
):
|
| 577 |
+
selected_index = 0
|
| 578 |
+
print(f"Invalid selected_index: {selected_index}, using default: 0")
|
| 579 |
+
|
| 580 |
+
# Get the selected image based on index
|
| 581 |
+
selected_image = image_gallery[selected_index]
|
| 582 |
+
print(f"Selected image index: {selected_index}")
|
| 583 |
+
print(f"Total images: {len(image_gallery)}")
|
| 584 |
+
|
| 585 |
+
# Extract the file path from the selected image
|
| 586 |
+
selected_frame_path = ""
|
| 587 |
+
print(f"Selected image type: {type(selected_image)}")
|
| 588 |
+
print(f"Selected image: {selected_image}")
|
| 589 |
+
|
| 590 |
+
if isinstance(selected_image, tuple):
|
| 591 |
+
# Gradio Gallery returns tuple (path, None)
|
| 592 |
+
selected_frame_path = selected_image[0]
|
| 593 |
+
elif isinstance(selected_image, str):
|
| 594 |
+
selected_frame_path = selected_image
|
| 595 |
+
elif hasattr(selected_image, "name"):
|
| 596 |
+
selected_frame_path = selected_image.name
|
| 597 |
+
elif isinstance(selected_image, dict):
|
| 598 |
+
if "name" in selected_image:
|
| 599 |
+
selected_frame_path = selected_image["name"]
|
| 600 |
+
elif "path" in selected_image:
|
| 601 |
+
selected_frame_path = selected_image["path"]
|
| 602 |
+
elif "src" in selected_image:
|
| 603 |
+
selected_frame_path = selected_image["src"]
|
| 604 |
+
else:
|
| 605 |
+
# Try to convert to string
|
| 606 |
+
selected_frame_path = str(selected_image)
|
| 607 |
+
|
| 608 |
+
print(f"Extracted path: {selected_frame_path}")
|
| 609 |
+
|
| 610 |
+
# Extract filename from the path for matching
|
| 611 |
+
import os
|
| 612 |
+
|
| 613 |
+
selected_filename = os.path.basename(selected_frame_path)
|
| 614 |
+
print(f"Selected filename: {selected_filename}")
|
| 615 |
+
|
| 616 |
+
# Move the selected image to the front
|
| 617 |
+
updated_gallery = [selected_image] + [
|
| 618 |
+
img for img in image_gallery if img != selected_image
|
| 619 |
+
]
|
| 620 |
+
|
| 621 |
+
log_message = (
|
| 622 |
+
f"Selected frame: {selected_filename}. "
|
| 623 |
+
f"Moved to first position. Total frames: {len(updated_gallery)}"
|
| 624 |
+
)
|
| 625 |
+
return updated_gallery, log_message, selected_filename
|
| 626 |
+
|
| 627 |
+
except Exception as e:
|
| 628 |
+
print(f"Error selecting first frame: {e}")
|
| 629 |
+
return image_gallery, f"Error selecting first frame: {e}", ""
|
src/depth_anything_3/app/modules/file_handlers.py
ADDED
|
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
File handling module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles file uploads, video processing, and file operations.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
import shutil
|
| 23 |
+
import time
|
| 24 |
+
from datetime import datetime
|
| 25 |
+
from typing import List, Optional, Tuple
|
| 26 |
+
import cv2
|
| 27 |
+
from PIL import Image
|
| 28 |
+
from pillow_heif import register_heif_opener
|
| 29 |
+
|
| 30 |
+
register_heif_opener()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class FileHandler:
|
| 34 |
+
"""
|
| 35 |
+
Handles file uploads and processing for the Gradio app.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(self):
|
| 39 |
+
"""Initialize the file handler."""
|
| 40 |
+
|
| 41 |
+
def handle_uploads(
|
| 42 |
+
self,
|
| 43 |
+
input_video: Optional[str],
|
| 44 |
+
input_images: Optional[List],
|
| 45 |
+
s_time_interval: float = 10.0,
|
| 46 |
+
) -> Tuple[str, List[str]]:
|
| 47 |
+
"""
|
| 48 |
+
Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
|
| 49 |
+
images or extracted frames from video into it.
|
| 50 |
+
|
| 51 |
+
Args:
|
| 52 |
+
input_video: Path to input video file
|
| 53 |
+
input_images: List of input image files
|
| 54 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
Tuple of (target_dir, image_paths)
|
| 58 |
+
"""
|
| 59 |
+
start_time = time.time()
|
| 60 |
+
|
| 61 |
+
# Get workspace directory from environment variable or use default
|
| 62 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 63 |
+
if not os.path.exists(workspace_dir):
|
| 64 |
+
os.makedirs(workspace_dir)
|
| 65 |
+
|
| 66 |
+
# Create input_images subdirectory
|
| 67 |
+
input_images_dir = os.path.join(workspace_dir, "input_images")
|
| 68 |
+
if not os.path.exists(input_images_dir):
|
| 69 |
+
os.makedirs(input_images_dir)
|
| 70 |
+
|
| 71 |
+
# Create a unique folder name within input_images
|
| 72 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
|
| 73 |
+
target_dir = os.path.join(input_images_dir, f"session_{timestamp}")
|
| 74 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 75 |
+
|
| 76 |
+
# Clean up if somehow that folder already exists
|
| 77 |
+
if os.path.exists(target_dir):
|
| 78 |
+
shutil.rmtree(target_dir)
|
| 79 |
+
os.makedirs(target_dir)
|
| 80 |
+
os.makedirs(target_dir_images)
|
| 81 |
+
|
| 82 |
+
image_paths = []
|
| 83 |
+
|
| 84 |
+
# Handle images
|
| 85 |
+
if input_images is not None:
|
| 86 |
+
image_paths.extend(self._process_images(input_images, target_dir_images))
|
| 87 |
+
|
| 88 |
+
# Handle video
|
| 89 |
+
if input_video is not None:
|
| 90 |
+
image_paths.extend(
|
| 91 |
+
self._process_video(input_video, target_dir_images, s_time_interval)
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
# Sort final images for gallery
|
| 95 |
+
image_paths = sorted(image_paths)
|
| 96 |
+
|
| 97 |
+
end_time = time.time()
|
| 98 |
+
print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
|
| 99 |
+
return target_dir, image_paths
|
| 100 |
+
|
| 101 |
+
def _process_images(self, input_images: List, target_dir_images: str) -> List[str]:
|
| 102 |
+
"""
|
| 103 |
+
Process uploaded images.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
input_images: List of input image files
|
| 107 |
+
target_dir_images: Target directory for images
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
List of processed image paths
|
| 111 |
+
"""
|
| 112 |
+
image_paths = []
|
| 113 |
+
|
| 114 |
+
for file_data in input_images:
|
| 115 |
+
if isinstance(file_data, dict) and "name" in file_data:
|
| 116 |
+
file_path = file_data["name"]
|
| 117 |
+
else:
|
| 118 |
+
file_path = file_data
|
| 119 |
+
|
| 120 |
+
# Check if the file is a HEIC image
|
| 121 |
+
file_ext = os.path.splitext(file_path)[1].lower()
|
| 122 |
+
if file_ext in [".heic", ".heif"]:
|
| 123 |
+
# Convert HEIC to JPEG for better gallery compatibility
|
| 124 |
+
try:
|
| 125 |
+
with Image.open(file_path) as img:
|
| 126 |
+
# Convert to RGB if necessary (HEIC can have different color modes)
|
| 127 |
+
if img.mode not in ("RGB", "L"):
|
| 128 |
+
img = img.convert("RGB")
|
| 129 |
+
|
| 130 |
+
# Create JPEG filename
|
| 131 |
+
base_name = os.path.splitext(os.path.basename(file_path))[0]
|
| 132 |
+
dst_path = os.path.join(target_dir_images, f"{base_name}.jpg")
|
| 133 |
+
|
| 134 |
+
# Save as JPEG with high quality
|
| 135 |
+
img.save(dst_path, "JPEG", quality=95)
|
| 136 |
+
image_paths.append(dst_path)
|
| 137 |
+
print(
|
| 138 |
+
f"Converted HEIC to JPEG: {os.path.basename(file_path)} -> "
|
| 139 |
+
f"{os.path.basename(dst_path)}"
|
| 140 |
+
)
|
| 141 |
+
except Exception as e:
|
| 142 |
+
print(f"Error converting HEIC file {file_path}: {e}")
|
| 143 |
+
# Fall back to copying as is
|
| 144 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 145 |
+
shutil.copy(file_path, dst_path)
|
| 146 |
+
image_paths.append(dst_path)
|
| 147 |
+
else:
|
| 148 |
+
# Regular image files - copy as is
|
| 149 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 150 |
+
shutil.copy(file_path, dst_path)
|
| 151 |
+
image_paths.append(dst_path)
|
| 152 |
+
|
| 153 |
+
return image_paths
|
| 154 |
+
|
| 155 |
+
def _process_video(
|
| 156 |
+
self, input_video: str, target_dir_images: str, s_time_interval: float
|
| 157 |
+
) -> List[str]:
|
| 158 |
+
"""
|
| 159 |
+
Process video file and extract frames.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
input_video: Path to input video file
|
| 163 |
+
target_dir_images: Target directory for extracted frames
|
| 164 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
List of extracted frame paths
|
| 168 |
+
"""
|
| 169 |
+
image_paths = []
|
| 170 |
+
|
| 171 |
+
if isinstance(input_video, dict) and "name" in input_video:
|
| 172 |
+
video_path = input_video["name"]
|
| 173 |
+
else:
|
| 174 |
+
video_path = input_video
|
| 175 |
+
|
| 176 |
+
vs = cv2.VideoCapture(video_path)
|
| 177 |
+
fps = vs.get(cv2.CAP_PROP_FPS)
|
| 178 |
+
frame_interval = max(1, int(fps / s_time_interval)) # Convert FPS to frame interval
|
| 179 |
+
|
| 180 |
+
count = 0
|
| 181 |
+
video_frame_num = 0
|
| 182 |
+
while True:
|
| 183 |
+
gotit, frame = vs.read()
|
| 184 |
+
if not gotit:
|
| 185 |
+
break
|
| 186 |
+
count += 1
|
| 187 |
+
if count % frame_interval == 0:
|
| 188 |
+
image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
|
| 189 |
+
cv2.imwrite(image_path, frame)
|
| 190 |
+
image_paths.append(image_path)
|
| 191 |
+
video_frame_num += 1
|
| 192 |
+
|
| 193 |
+
return image_paths
|
| 194 |
+
|
| 195 |
+
def update_gallery_on_upload(
|
| 196 |
+
self,
|
| 197 |
+
input_video: Optional[str],
|
| 198 |
+
input_images: Optional[List],
|
| 199 |
+
s_time_interval: float = 10.0,
|
| 200 |
+
) -> Tuple[Optional[str], Optional[str], Optional[List], Optional[str]]:
|
| 201 |
+
"""
|
| 202 |
+
Handle file uploads and update gallery.
|
| 203 |
+
|
| 204 |
+
Args:
|
| 205 |
+
input_video: Path to input video file
|
| 206 |
+
input_images: List of input image files
|
| 207 |
+
s_time_interval: Sampling FPS (frames per second) for frame extraction
|
| 208 |
+
|
| 209 |
+
Returns:
|
| 210 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
|
| 211 |
+
"""
|
| 212 |
+
if not input_video and not input_images:
|
| 213 |
+
return None, None, None, None
|
| 214 |
+
|
| 215 |
+
target_dir, image_paths = self.handle_uploads(input_video, input_images, s_time_interval)
|
| 216 |
+
return (
|
| 217 |
+
None,
|
| 218 |
+
target_dir,
|
| 219 |
+
image_paths,
|
| 220 |
+
"Upload complete. Click 'Reconstruct' to begin 3D processing.",
|
| 221 |
+
)
|
| 222 |
+
|
| 223 |
+
def load_example_scene(
|
| 224 |
+
self, scene_name: str, examples_dir: str = "examples"
|
| 225 |
+
) -> Tuple[Optional[str], Optional[str], Optional[List], str]:
|
| 226 |
+
"""
|
| 227 |
+
Load a scene from examples directory.
|
| 228 |
+
|
| 229 |
+
Args:
|
| 230 |
+
scene_name: Name of the scene to load
|
| 231 |
+
examples_dir: Path to examples directory
|
| 232 |
+
|
| 233 |
+
Returns:
|
| 234 |
+
Tuple of (reconstruction_output, target_dir, image_paths, log_message)
|
| 235 |
+
"""
|
| 236 |
+
from depth_anything_3.app.modules.utils import get_scene_info
|
| 237 |
+
|
| 238 |
+
scenes = get_scene_info(examples_dir)
|
| 239 |
+
|
| 240 |
+
# Find the selected scene
|
| 241 |
+
selected_scene = None
|
| 242 |
+
for scene in scenes:
|
| 243 |
+
if scene["name"] == scene_name:
|
| 244 |
+
selected_scene = scene
|
| 245 |
+
break
|
| 246 |
+
|
| 247 |
+
if selected_scene is None:
|
| 248 |
+
return None, None, None, "Scene not found"
|
| 249 |
+
|
| 250 |
+
# Use fixed directory name for examples (not timestamp-based)
|
| 251 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 252 |
+
input_images_dir = os.path.join(workspace_dir, "input_images")
|
| 253 |
+
if not os.path.exists(input_images_dir):
|
| 254 |
+
os.makedirs(input_images_dir)
|
| 255 |
+
|
| 256 |
+
# Create a fixed folder name based on scene name
|
| 257 |
+
target_dir = os.path.join(input_images_dir, f"example_{scene_name}")
|
| 258 |
+
target_dir_images = os.path.join(target_dir, "images")
|
| 259 |
+
|
| 260 |
+
# Check if already cached (GLB file exists)
|
| 261 |
+
glb_path = os.path.join(target_dir, "scene.glb")
|
| 262 |
+
is_cached = os.path.exists(glb_path)
|
| 263 |
+
|
| 264 |
+
# Create directory if it doesn't exist
|
| 265 |
+
if not os.path.exists(target_dir):
|
| 266 |
+
os.makedirs(target_dir)
|
| 267 |
+
os.makedirs(target_dir_images)
|
| 268 |
+
|
| 269 |
+
# Copy images if directory is new or empty
|
| 270 |
+
if not os.path.exists(target_dir_images) or len(os.listdir(target_dir_images)) == 0:
|
| 271 |
+
os.makedirs(target_dir_images, exist_ok=True)
|
| 272 |
+
image_paths = []
|
| 273 |
+
for file_path in selected_scene["image_files"]:
|
| 274 |
+
dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
|
| 275 |
+
shutil.copy(file_path, dst_path)
|
| 276 |
+
image_paths.append(dst_path)
|
| 277 |
+
else:
|
| 278 |
+
# Use existing images
|
| 279 |
+
image_paths = sorted(
|
| 280 |
+
[
|
| 281 |
+
os.path.join(target_dir_images, f)
|
| 282 |
+
for f in os.listdir(target_dir_images)
|
| 283 |
+
if f.lower().endswith((".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".tif"))
|
| 284 |
+
]
|
| 285 |
+
)
|
| 286 |
+
|
| 287 |
+
# Return cached GLB if available
|
| 288 |
+
if is_cached:
|
| 289 |
+
return (
|
| 290 |
+
glb_path, # Return cached reconstruction
|
| 291 |
+
target_dir, # Set target directory
|
| 292 |
+
image_paths, # Set gallery
|
| 293 |
+
f"Loaded cached scene '{scene_name}' with {selected_scene['num_images']} images.",
|
| 294 |
+
)
|
| 295 |
+
else:
|
| 296 |
+
return (
|
| 297 |
+
None, # No cached reconstruction
|
| 298 |
+
target_dir, # Set target directory
|
| 299 |
+
image_paths, # Set gallery
|
| 300 |
+
(
|
| 301 |
+
f"Loaded scene '{scene_name}' with {selected_scene['num_images']} images. "
|
| 302 |
+
"Click 'Reconstruct' to begin 3D processing."
|
| 303 |
+
),
|
| 304 |
+
)
|
src/depth_anything_3/app/modules/model_inference.py
ADDED
|
@@ -0,0 +1,286 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Model inference module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles all model-related operations including inference,
|
| 19 |
+
data processing, and result preparation.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import gc
|
| 23 |
+
import glob
|
| 24 |
+
import os
|
| 25 |
+
from typing import Any, Dict, Optional, Tuple
|
| 26 |
+
import numpy as np
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
from depth_anything_3.api import DepthAnything3
|
| 30 |
+
from depth_anything_3.utils.export.glb import export_to_glb
|
| 31 |
+
from depth_anything_3.utils.export.gs import export_to_gs_video
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class ModelInference:
|
| 35 |
+
"""
|
| 36 |
+
Handles model inference and data processing for Depth Anything 3.
|
| 37 |
+
"""
|
| 38 |
+
|
| 39 |
+
def __init__(self):
|
| 40 |
+
"""Initialize the model inference handler."""
|
| 41 |
+
self.model = None
|
| 42 |
+
|
| 43 |
+
def initialize_model(self, device: str = "cuda") -> None:
|
| 44 |
+
"""
|
| 45 |
+
Initialize the DepthAnything3 model.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
device: Device to load the model on
|
| 49 |
+
"""
|
| 50 |
+
if self.model is None:
|
| 51 |
+
# Get model directory from environment variable or use default
|
| 52 |
+
model_dir = os.environ.get(
|
| 53 |
+
"DA3_MODEL_DIR", "/dev/shm/da3_models/DA3HF-VITG-METRIC_VITL"
|
| 54 |
+
)
|
| 55 |
+
self.model = DepthAnything3.from_pretrained(model_dir)
|
| 56 |
+
self.model = self.model.to(device)
|
| 57 |
+
else:
|
| 58 |
+
self.model = self.model.to(device)
|
| 59 |
+
|
| 60 |
+
self.model.eval()
|
| 61 |
+
|
| 62 |
+
def run_inference(
|
| 63 |
+
self,
|
| 64 |
+
target_dir: str,
|
| 65 |
+
filter_black_bg: bool = False,
|
| 66 |
+
filter_white_bg: bool = False,
|
| 67 |
+
process_res_method: str = "upper_bound_resize",
|
| 68 |
+
show_camera: bool = True,
|
| 69 |
+
selected_first_frame: Optional[str] = None,
|
| 70 |
+
save_percentage: float = 30.0,
|
| 71 |
+
num_max_points: int = 1_000_000,
|
| 72 |
+
infer_gs: bool = False,
|
| 73 |
+
gs_trj_mode: str = "extend",
|
| 74 |
+
gs_video_quality: str = "high",
|
| 75 |
+
) -> Tuple[Any, Dict[int, Dict[str, Any]]]:
|
| 76 |
+
"""
|
| 77 |
+
Run DepthAnything3 model inference on images.
|
| 78 |
+
|
| 79 |
+
Args:
|
| 80 |
+
target_dir: Directory containing images
|
| 81 |
+
apply_mask: Whether to apply mask for ambiguous depth classes
|
| 82 |
+
mask_edges: Whether to mask edges
|
| 83 |
+
filter_black_bg: Whether to filter black background
|
| 84 |
+
filter_white_bg: Whether to filter white background
|
| 85 |
+
process_res_method: Method for resizing input images
|
| 86 |
+
show_camera: Whether to show camera in 3D view
|
| 87 |
+
selected_first_frame: Selected first frame filename
|
| 88 |
+
save_percentage: Percentage of points to save (0-100)
|
| 89 |
+
infer_gs: Whether to infer 3D Gaussian Splatting
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Tuple of (prediction, processed_data)
|
| 93 |
+
"""
|
| 94 |
+
print(f"Processing images from {target_dir}")
|
| 95 |
+
|
| 96 |
+
# Device check
|
| 97 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 98 |
+
device = torch.device(device)
|
| 99 |
+
|
| 100 |
+
# Initialize model if needed
|
| 101 |
+
self.initialize_model(device)
|
| 102 |
+
|
| 103 |
+
# Get image paths
|
| 104 |
+
print("Loading images...")
|
| 105 |
+
image_folder_path = os.path.join(target_dir, "images")
|
| 106 |
+
all_image_paths = sorted(glob.glob(os.path.join(image_folder_path, "*")))
|
| 107 |
+
|
| 108 |
+
# Filter for image files
|
| 109 |
+
image_extensions = [".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".tif"]
|
| 110 |
+
all_image_paths = [
|
| 111 |
+
path
|
| 112 |
+
for path in all_image_paths
|
| 113 |
+
if any(path.lower().endswith(ext) for ext in image_extensions)
|
| 114 |
+
]
|
| 115 |
+
|
| 116 |
+
print(f"Found {len(all_image_paths)} images")
|
| 117 |
+
print(f"All image paths: {all_image_paths}")
|
| 118 |
+
|
| 119 |
+
# Apply first frame selection logic
|
| 120 |
+
if selected_first_frame:
|
| 121 |
+
# Find the image with matching filename
|
| 122 |
+
selected_path = None
|
| 123 |
+
for path in all_image_paths:
|
| 124 |
+
if os.path.basename(path) == selected_first_frame:
|
| 125 |
+
selected_path = path
|
| 126 |
+
break
|
| 127 |
+
|
| 128 |
+
if selected_path:
|
| 129 |
+
# Move selected frame to the front
|
| 130 |
+
image_paths = [selected_path] + [
|
| 131 |
+
path for path in all_image_paths if path != selected_path
|
| 132 |
+
]
|
| 133 |
+
print(f"User selected first frame: {selected_first_frame} -> {selected_path}")
|
| 134 |
+
print(f"Reordered image paths: {image_paths}")
|
| 135 |
+
else:
|
| 136 |
+
# Use default order if no match found
|
| 137 |
+
image_paths = all_image_paths
|
| 138 |
+
print(
|
| 139 |
+
f"Selected frame '{selected_first_frame}' not found in image paths. "
|
| 140 |
+
"Using default order."
|
| 141 |
+
)
|
| 142 |
+
first_frame_display = image_paths[0] if image_paths else "No images"
|
| 143 |
+
print(f"Using default order (first frame): {first_frame_display}")
|
| 144 |
+
else:
|
| 145 |
+
# Use default order (sorted)
|
| 146 |
+
image_paths = all_image_paths
|
| 147 |
+
first_frame_display = image_paths[0] if image_paths else "No images"
|
| 148 |
+
print(f"Using default order (first frame): {first_frame_display}")
|
| 149 |
+
|
| 150 |
+
if len(image_paths) == 0:
|
| 151 |
+
raise ValueError("No images found. Check your upload.")
|
| 152 |
+
|
| 153 |
+
# Map UI options to actual method names
|
| 154 |
+
method_mapping = {"high_res": "lower_bound_resize", "low_res": "upper_bound_resize"}
|
| 155 |
+
actual_method = method_mapping.get(process_res_method, "upper_bound_crop")
|
| 156 |
+
|
| 157 |
+
# Run model inference
|
| 158 |
+
print(f"Running inference with method: {actual_method}")
|
| 159 |
+
with torch.no_grad():
|
| 160 |
+
prediction = self.model.inference(
|
| 161 |
+
image_paths, export_dir=None, process_res_method=actual_method, infer_gs=infer_gs
|
| 162 |
+
)
|
| 163 |
+
# num_max_points: int = 1_000_000,
|
| 164 |
+
export_to_glb(
|
| 165 |
+
prediction,
|
| 166 |
+
filter_black_bg=filter_black_bg,
|
| 167 |
+
filter_white_bg=filter_white_bg,
|
| 168 |
+
export_dir=target_dir,
|
| 169 |
+
show_cameras=show_camera,
|
| 170 |
+
conf_thresh_percentile=save_percentage,
|
| 171 |
+
num_max_points=int(num_max_points),
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
# export to gs video if needed
|
| 175 |
+
if infer_gs:
|
| 176 |
+
mode_mapping = {"extend": "extend", "smooth": "interpolate_smooth"}
|
| 177 |
+
print(f"GS mode: {gs_trj_mode}; Backend mode: {mode_mapping[gs_trj_mode]}")
|
| 178 |
+
export_to_gs_video(
|
| 179 |
+
prediction,
|
| 180 |
+
export_dir=target_dir,
|
| 181 |
+
chunk_size=4,
|
| 182 |
+
trj_mode=mode_mapping.get(gs_trj_mode, "extend"),
|
| 183 |
+
enable_tqdm=True,
|
| 184 |
+
vis_depth="hcat",
|
| 185 |
+
video_quality=gs_video_quality,
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Save predictions.npz for caching metric depth data
|
| 189 |
+
self._save_predictions_cache(target_dir, prediction)
|
| 190 |
+
|
| 191 |
+
# Process results
|
| 192 |
+
processed_data = self._process_results(target_dir, prediction, image_paths)
|
| 193 |
+
|
| 194 |
+
# Clean up
|
| 195 |
+
torch.cuda.empty_cache()
|
| 196 |
+
|
| 197 |
+
return prediction, processed_data
|
| 198 |
+
|
| 199 |
+
def _save_predictions_cache(self, target_dir: str, prediction: Any) -> None:
|
| 200 |
+
"""
|
| 201 |
+
Save predictions data to predictions.npz for caching.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
target_dir: Directory to save the cache
|
| 205 |
+
prediction: Model prediction object
|
| 206 |
+
"""
|
| 207 |
+
try:
|
| 208 |
+
output_file = os.path.join(target_dir, "predictions.npz")
|
| 209 |
+
|
| 210 |
+
# Build save dict with prediction data
|
| 211 |
+
save_dict = {}
|
| 212 |
+
|
| 213 |
+
# Save processed images if available
|
| 214 |
+
if prediction.processed_images is not None:
|
| 215 |
+
save_dict["images"] = prediction.processed_images
|
| 216 |
+
|
| 217 |
+
# Save depth data
|
| 218 |
+
if prediction.depth is not None:
|
| 219 |
+
save_dict["depths"] = np.round(prediction.depth, 6)
|
| 220 |
+
|
| 221 |
+
# Save confidence if available
|
| 222 |
+
if prediction.conf is not None:
|
| 223 |
+
save_dict["conf"] = np.round(prediction.conf, 2)
|
| 224 |
+
|
| 225 |
+
# Save camera parameters
|
| 226 |
+
if prediction.extrinsics is not None:
|
| 227 |
+
save_dict["extrinsics"] = prediction.extrinsics
|
| 228 |
+
if prediction.intrinsics is not None:
|
| 229 |
+
save_dict["intrinsics"] = prediction.intrinsics
|
| 230 |
+
|
| 231 |
+
# Save to file
|
| 232 |
+
np.savez_compressed(output_file, **save_dict)
|
| 233 |
+
print(f"Saved predictions cache to: {output_file}")
|
| 234 |
+
|
| 235 |
+
except Exception as e:
|
| 236 |
+
print(f"Warning: Failed to save predictions cache: {e}")
|
| 237 |
+
|
| 238 |
+
def _process_results(
|
| 239 |
+
self, target_dir: str, prediction: Any, image_paths: list
|
| 240 |
+
) -> Dict[int, Dict[str, Any]]:
|
| 241 |
+
"""
|
| 242 |
+
Process model results into structured data.
|
| 243 |
+
|
| 244 |
+
Args:
|
| 245 |
+
target_dir: Directory containing results
|
| 246 |
+
prediction: Model prediction object
|
| 247 |
+
image_paths: List of input image paths
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
Dictionary containing processed data for each view
|
| 251 |
+
"""
|
| 252 |
+
processed_data = {}
|
| 253 |
+
|
| 254 |
+
# Read generated depth visualization files
|
| 255 |
+
depth_vis_dir = os.path.join(target_dir, "depth_vis")
|
| 256 |
+
|
| 257 |
+
if os.path.exists(depth_vis_dir):
|
| 258 |
+
depth_files = sorted(glob.glob(os.path.join(depth_vis_dir, "*.jpg")))
|
| 259 |
+
for i, depth_file in enumerate(depth_files):
|
| 260 |
+
# Use processed images directly from API
|
| 261 |
+
processed_image = None
|
| 262 |
+
if prediction.processed_images is not None and i < len(
|
| 263 |
+
prediction.processed_images
|
| 264 |
+
):
|
| 265 |
+
processed_image = prediction.processed_images[i]
|
| 266 |
+
|
| 267 |
+
processed_data[i] = {
|
| 268 |
+
"depth_image": depth_file,
|
| 269 |
+
"image": processed_image,
|
| 270 |
+
"original_image_path": image_paths[i] if i < len(image_paths) else None,
|
| 271 |
+
"depth": prediction.depth[i] if i < len(prediction.depth) else None,
|
| 272 |
+
"intrinsics": (
|
| 273 |
+
prediction.intrinsics[i]
|
| 274 |
+
if prediction.intrinsics is not None and i < len(prediction.intrinsics)
|
| 275 |
+
else None
|
| 276 |
+
),
|
| 277 |
+
"mask": None, # No mask information available
|
| 278 |
+
}
|
| 279 |
+
|
| 280 |
+
return processed_data
|
| 281 |
+
|
| 282 |
+
def cleanup(self) -> None:
|
| 283 |
+
"""Clean up GPU memory."""
|
| 284 |
+
if torch.cuda.is_available():
|
| 285 |
+
torch.cuda.empty_cache()
|
| 286 |
+
gc.collect()
|
src/depth_anything_3/app/modules/ui_components.py
ADDED
|
@@ -0,0 +1,474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
UI components module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module contains UI component definitions and layout functions.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from typing import Any, Dict, List, Tuple
|
| 23 |
+
import gradio as gr
|
| 24 |
+
|
| 25 |
+
from depth_anything_3.app.modules.utils import get_logo_base64, get_scene_info
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class UIComponents:
|
| 29 |
+
"""
|
| 30 |
+
Handles UI component creation and layout for the Gradio app.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self):
|
| 34 |
+
"""Initialize the UI components handler."""
|
| 35 |
+
|
| 36 |
+
def create_upload_section(self) -> Tuple[gr.Video, gr.Slider, gr.File, gr.Gallery, gr.Button]:
|
| 37 |
+
"""
|
| 38 |
+
Create the upload section with video, images, and gallery components.
|
| 39 |
+
|
| 40 |
+
Returns:
|
| 41 |
+
A tuple of Gradio components: (input_video, s_time_interval, input_images,
|
| 42 |
+
image_gallery, select_first_frame_btn).
|
| 43 |
+
"""
|
| 44 |
+
input_video = gr.Video(label="Upload Video", interactive=True)
|
| 45 |
+
s_time_interval = gr.Slider(
|
| 46 |
+
minimum=0.1,
|
| 47 |
+
maximum=60,
|
| 48 |
+
value=10,
|
| 49 |
+
step=0.1,
|
| 50 |
+
label="Sampling FPS (Frames Per Second)",
|
| 51 |
+
interactive=True,
|
| 52 |
+
visible=True,
|
| 53 |
+
)
|
| 54 |
+
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
|
| 55 |
+
image_gallery = gr.Gallery(
|
| 56 |
+
label="Preview",
|
| 57 |
+
columns=4,
|
| 58 |
+
height="300px",
|
| 59 |
+
show_download_button=True,
|
| 60 |
+
object_fit="contain",
|
| 61 |
+
preview=True,
|
| 62 |
+
interactive=False,
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
# Select first frame button (moved below image gallery)
|
| 66 |
+
select_first_frame_btn = gr.Button("Select First Frame", scale=1)
|
| 67 |
+
|
| 68 |
+
return input_video, s_time_interval, input_images, image_gallery, select_first_frame_btn
|
| 69 |
+
|
| 70 |
+
def create_3d_viewer_section(self) -> gr.Model3D:
|
| 71 |
+
"""
|
| 72 |
+
Create the 3D viewer component.
|
| 73 |
+
|
| 74 |
+
Returns:
|
| 75 |
+
3D model viewer component
|
| 76 |
+
"""
|
| 77 |
+
return gr.Model3D(
|
| 78 |
+
height=520,
|
| 79 |
+
zoom_speed=0.5,
|
| 80 |
+
pan_speed=0.5,
|
| 81 |
+
clear_color=[0.0, 0.0, 0.0, 0.0],
|
| 82 |
+
key="persistent_3d_viewer",
|
| 83 |
+
elem_id="reconstruction_3d_viewer",
|
| 84 |
+
)
|
| 85 |
+
|
| 86 |
+
def create_nvs_video(self) -> Tuple[gr.Video, gr.Markdown]:
|
| 87 |
+
"""
|
| 88 |
+
Create the 3DGS rendered video display component and info message.
|
| 89 |
+
|
| 90 |
+
Returns:
|
| 91 |
+
Tuple of (video component, info message component)
|
| 92 |
+
"""
|
| 93 |
+
with gr.Column():
|
| 94 |
+
gs_info = gr.Markdown(
|
| 95 |
+
(
|
| 96 |
+
"‼️ **3D Gaussian Splatting rendering is currently DISABLED.** <br><br><br>"
|
| 97 |
+
"To render novel views from 3DGS, "
|
| 98 |
+
"enable **Infer 3D Gaussian Splatting** below. <br>"
|
| 99 |
+
"Next, in **Visualization Options**, "
|
| 100 |
+
"*optionally* configure the **rendering trajectory** (default: smooth) "
|
| 101 |
+
"and **video quality** (default: low), "
|
| 102 |
+
"then click **Reconstruct**."
|
| 103 |
+
),
|
| 104 |
+
visible=True,
|
| 105 |
+
height=520,
|
| 106 |
+
)
|
| 107 |
+
gs_video = gr.Video(
|
| 108 |
+
height=520,
|
| 109 |
+
label="3DGS Rendered NVS Video (depth shown for reference only)",
|
| 110 |
+
interactive=False,
|
| 111 |
+
visible=False,
|
| 112 |
+
)
|
| 113 |
+
return gs_video, gs_info
|
| 114 |
+
|
| 115 |
+
def create_depth_section(self) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image]:
|
| 116 |
+
"""
|
| 117 |
+
Create the depth visualization section.
|
| 118 |
+
|
| 119 |
+
Returns:
|
| 120 |
+
A tuple of (prev_depth_btn, depth_view_selector, next_depth_btn, depth_map)
|
| 121 |
+
"""
|
| 122 |
+
with gr.Row(elem_classes=["navigation-row"]):
|
| 123 |
+
prev_depth_btn = gr.Button("◀ Previous", size="sm", scale=1)
|
| 124 |
+
depth_view_selector = gr.Dropdown(
|
| 125 |
+
choices=["View 1"],
|
| 126 |
+
value="View 1",
|
| 127 |
+
label="Select View",
|
| 128 |
+
scale=2,
|
| 129 |
+
interactive=True,
|
| 130 |
+
allow_custom_value=True,
|
| 131 |
+
)
|
| 132 |
+
next_depth_btn = gr.Button("Next ▶", size="sm", scale=1)
|
| 133 |
+
depth_map = gr.Image(
|
| 134 |
+
type="numpy",
|
| 135 |
+
label="Colorized Depth Map",
|
| 136 |
+
format="png",
|
| 137 |
+
interactive=False,
|
| 138 |
+
)
|
| 139 |
+
|
| 140 |
+
return prev_depth_btn, depth_view_selector, next_depth_btn, depth_map
|
| 141 |
+
|
| 142 |
+
def create_measure_section(
|
| 143 |
+
self,
|
| 144 |
+
) -> Tuple[gr.Button, gr.Dropdown, gr.Button, gr.Image, gr.Image, gr.Markdown]:
|
| 145 |
+
"""
|
| 146 |
+
Create the measurement section.
|
| 147 |
+
|
| 148 |
+
Returns:
|
| 149 |
+
A tuple of (prev_measure_btn, measure_view_selector, next_measure_btn, measure_image,
|
| 150 |
+
measure_depth_image, measure_text)
|
| 151 |
+
"""
|
| 152 |
+
from depth_anything_3.app.css_and_html import MEASURE_INSTRUCTIONS_HTML
|
| 153 |
+
|
| 154 |
+
gr.Markdown(MEASURE_INSTRUCTIONS_HTML)
|
| 155 |
+
with gr.Row(elem_classes=["navigation-row"]):
|
| 156 |
+
prev_measure_btn = gr.Button("◀ Previous", size="sm", scale=1)
|
| 157 |
+
measure_view_selector = gr.Dropdown(
|
| 158 |
+
choices=["View 1"],
|
| 159 |
+
value="View 1",
|
| 160 |
+
label="Select View",
|
| 161 |
+
scale=2,
|
| 162 |
+
interactive=True,
|
| 163 |
+
allow_custom_value=True,
|
| 164 |
+
)
|
| 165 |
+
next_measure_btn = gr.Button("Next ▶", size="sm", scale=1)
|
| 166 |
+
with gr.Row():
|
| 167 |
+
measure_image = gr.Image(
|
| 168 |
+
type="numpy",
|
| 169 |
+
show_label=False,
|
| 170 |
+
format="webp",
|
| 171 |
+
interactive=False,
|
| 172 |
+
sources=[],
|
| 173 |
+
label="RGB Image",
|
| 174 |
+
scale=1,
|
| 175 |
+
height=275,
|
| 176 |
+
)
|
| 177 |
+
measure_depth_image = gr.Image(
|
| 178 |
+
type="numpy",
|
| 179 |
+
show_label=False,
|
| 180 |
+
format="webp",
|
| 181 |
+
interactive=False,
|
| 182 |
+
sources=[],
|
| 183 |
+
label="Depth Visualization (Right Half)",
|
| 184 |
+
scale=1,
|
| 185 |
+
height=275,
|
| 186 |
+
)
|
| 187 |
+
gr.Markdown(
|
| 188 |
+
"**Note:** Images have been adjusted to model processing size. "
|
| 189 |
+
"Click two points on the RGB image to measure distance."
|
| 190 |
+
)
|
| 191 |
+
measure_text = gr.Markdown("")
|
| 192 |
+
|
| 193 |
+
return (
|
| 194 |
+
prev_measure_btn,
|
| 195 |
+
measure_view_selector,
|
| 196 |
+
next_measure_btn,
|
| 197 |
+
measure_image,
|
| 198 |
+
measure_depth_image,
|
| 199 |
+
measure_text,
|
| 200 |
+
)
|
| 201 |
+
|
| 202 |
+
def create_inference_control_section(self) -> Tuple[gr.Dropdown, gr.Checkbox]:
|
| 203 |
+
"""
|
| 204 |
+
Create the inference control section (before inference).
|
| 205 |
+
|
| 206 |
+
Returns:
|
| 207 |
+
Tuple of (process_res_method_dropdown, infer_gs)
|
| 208 |
+
"""
|
| 209 |
+
with gr.Row():
|
| 210 |
+
process_res_method_dropdown = gr.Dropdown(
|
| 211 |
+
choices=["high_res", "low_res"],
|
| 212 |
+
value="low_res",
|
| 213 |
+
label="Image Processing Method",
|
| 214 |
+
info="low_res for much more images",
|
| 215 |
+
scale=1,
|
| 216 |
+
)
|
| 217 |
+
# Modify line 220, add color class
|
| 218 |
+
infer_gs = gr.Checkbox(
|
| 219 |
+
label="Infer 3D Gaussian Splatting",
|
| 220 |
+
value=False,
|
| 221 |
+
info=(
|
| 222 |
+
'Enable novel view rendering from 3DGS (<i class="fas fa-triangle-exclamation '
|
| 223 |
+
'fa-color-red"></i> requires extra processing time)'
|
| 224 |
+
),
|
| 225 |
+
scale=1,
|
| 226 |
+
)
|
| 227 |
+
|
| 228 |
+
return (process_res_method_dropdown, infer_gs)
|
| 229 |
+
|
| 230 |
+
def create_display_control_section(
|
| 231 |
+
self,
|
| 232 |
+
) -> Tuple[
|
| 233 |
+
gr.Checkbox,
|
| 234 |
+
gr.Checkbox,
|
| 235 |
+
gr.Checkbox,
|
| 236 |
+
gr.Slider,
|
| 237 |
+
gr.Slider,
|
| 238 |
+
gr.Dropdown,
|
| 239 |
+
gr.Dropdown,
|
| 240 |
+
gr.Button,
|
| 241 |
+
gr.ClearButton,
|
| 242 |
+
]:
|
| 243 |
+
"""
|
| 244 |
+
Create the display control section (options for visualization).
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
Tuple of display control components including buttons
|
| 248 |
+
"""
|
| 249 |
+
with gr.Column():
|
| 250 |
+
# 3DGS options at the top
|
| 251 |
+
with gr.Row():
|
| 252 |
+
gs_trj_mode = gr.Dropdown(
|
| 253 |
+
choices=["smooth", "extend"],
|
| 254 |
+
value="smooth",
|
| 255 |
+
label=("Rendering trajectory for 3DGS viewpoints (requires n_views ≥ 2)"),
|
| 256 |
+
info=("'smooth' for view interpolation; 'extend' for longer trajectory"),
|
| 257 |
+
visible=False, # initially hidden
|
| 258 |
+
)
|
| 259 |
+
gs_video_quality = gr.Dropdown(
|
| 260 |
+
choices=["low", "medium", "high"],
|
| 261 |
+
value="low",
|
| 262 |
+
label=("Video quality for 3DGS rendered outputs"),
|
| 263 |
+
info=("'low' for faster loading speed; 'high' for better visual quality"),
|
| 264 |
+
visible=False, # initially hidden
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Reconstruct and Clear buttons (before Visualization Options)
|
| 268 |
+
with gr.Row():
|
| 269 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 270 |
+
clear_btn = gr.ClearButton(scale=1)
|
| 271 |
+
|
| 272 |
+
gr.Markdown("### Visualization Options: (Click Reconstruct to update)")
|
| 273 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 274 |
+
filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
|
| 275 |
+
filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
|
| 276 |
+
save_percentage = gr.Slider(
|
| 277 |
+
minimum=0,
|
| 278 |
+
maximum=100,
|
| 279 |
+
value=10,
|
| 280 |
+
step=1,
|
| 281 |
+
label="Filter Percentage",
|
| 282 |
+
info="Confidence Threshold (%): Higher values filter more points.",
|
| 283 |
+
)
|
| 284 |
+
num_max_points = gr.Slider(
|
| 285 |
+
minimum=1000,
|
| 286 |
+
maximum=100000,
|
| 287 |
+
value=1000,
|
| 288 |
+
step=1000,
|
| 289 |
+
label="Max Points (K points)",
|
| 290 |
+
info="Maximum number of points to export to GLB (in thousands)",
|
| 291 |
+
)
|
| 292 |
+
|
| 293 |
+
return (
|
| 294 |
+
show_cam,
|
| 295 |
+
filter_black_bg,
|
| 296 |
+
filter_white_bg,
|
| 297 |
+
save_percentage,
|
| 298 |
+
num_max_points,
|
| 299 |
+
gs_trj_mode,
|
| 300 |
+
gs_video_quality,
|
| 301 |
+
submit_btn,
|
| 302 |
+
clear_btn,
|
| 303 |
+
)
|
| 304 |
+
|
| 305 |
+
def create_control_section(
|
| 306 |
+
self,
|
| 307 |
+
) -> Tuple[
|
| 308 |
+
gr.Button,
|
| 309 |
+
gr.ClearButton,
|
| 310 |
+
gr.Dropdown,
|
| 311 |
+
gr.Checkbox,
|
| 312 |
+
gr.Checkbox,
|
| 313 |
+
gr.Checkbox,
|
| 314 |
+
gr.Checkbox,
|
| 315 |
+
gr.Checkbox,
|
| 316 |
+
gr.Dropdown,
|
| 317 |
+
gr.Checkbox,
|
| 318 |
+
gr.Textbox,
|
| 319 |
+
]:
|
| 320 |
+
"""
|
| 321 |
+
Create the control section with buttons and options.
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
Tuple of control components
|
| 325 |
+
"""
|
| 326 |
+
with gr.Row():
|
| 327 |
+
submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
|
| 328 |
+
clear_btn = gr.ClearButton(
|
| 329 |
+
scale=1,
|
| 330 |
+
)
|
| 331 |
+
|
| 332 |
+
with gr.Row():
|
| 333 |
+
frame_filter = gr.Dropdown(
|
| 334 |
+
choices=["All"], value="All", label="Show Points from Frame"
|
| 335 |
+
)
|
| 336 |
+
with gr.Column():
|
| 337 |
+
gr.Markdown("### Visualization Option: (Click Reconstruct to update)")
|
| 338 |
+
show_cam = gr.Checkbox(label="Show Camera", value=True)
|
| 339 |
+
show_mesh = gr.Checkbox(label="Show Mesh", value=True)
|
| 340 |
+
filter_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
|
| 341 |
+
filter_white_bg = gr.Checkbox(label="Filter White Background", value=False)
|
| 342 |
+
gr.Markdown("### Reconstruction Options: (updated on next run)")
|
| 343 |
+
apply_mask_checkbox = gr.Checkbox(
|
| 344 |
+
label="Apply mask for predicted ambiguous depth classes & edges",
|
| 345 |
+
value=True,
|
| 346 |
+
)
|
| 347 |
+
process_res_method_dropdown = gr.Dropdown(
|
| 348 |
+
choices=[
|
| 349 |
+
"upper_bound_resize",
|
| 350 |
+
"upper_bound_crop",
|
| 351 |
+
"lower_bound_resize",
|
| 352 |
+
"lower_bound_crop",
|
| 353 |
+
],
|
| 354 |
+
value="upper_bound_resize",
|
| 355 |
+
label="Image Processing Method",
|
| 356 |
+
info="Method for resizing input images",
|
| 357 |
+
)
|
| 358 |
+
save_to_gallery_checkbox = gr.Checkbox(
|
| 359 |
+
label="Save to Gallery",
|
| 360 |
+
value=False,
|
| 361 |
+
info="Save current reconstruction results to gallery directory",
|
| 362 |
+
)
|
| 363 |
+
gallery_name_input = gr.Textbox(
|
| 364 |
+
label="Gallery Name",
|
| 365 |
+
placeholder="Enter a name for the gallery folder",
|
| 366 |
+
value="",
|
| 367 |
+
info="Leave empty for auto-generated name with timestamp",
|
| 368 |
+
)
|
| 369 |
+
|
| 370 |
+
return (
|
| 371 |
+
submit_btn,
|
| 372 |
+
clear_btn,
|
| 373 |
+
frame_filter,
|
| 374 |
+
show_cam,
|
| 375 |
+
show_mesh,
|
| 376 |
+
filter_black_bg,
|
| 377 |
+
filter_white_bg,
|
| 378 |
+
apply_mask_checkbox,
|
| 379 |
+
process_res_method_dropdown,
|
| 380 |
+
save_to_gallery_checkbox,
|
| 381 |
+
gallery_name_input,
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
def create_example_scenes_section(self) -> List[Dict[str, Any]]:
|
| 385 |
+
"""
|
| 386 |
+
Create the example scenes section.
|
| 387 |
+
|
| 388 |
+
Returns:
|
| 389 |
+
List of scene information dictionaries
|
| 390 |
+
"""
|
| 391 |
+
# Get workspace directory from environment variable
|
| 392 |
+
workspace_dir = os.environ.get("DA3_WORKSPACE_DIR", "gradio_workspace")
|
| 393 |
+
examples_dir = os.path.join(workspace_dir, "examples")
|
| 394 |
+
|
| 395 |
+
# Get scene information
|
| 396 |
+
scenes = get_scene_info(examples_dir)
|
| 397 |
+
|
| 398 |
+
return scenes
|
| 399 |
+
|
| 400 |
+
def create_example_scene_grid(self, scenes: List[Dict[str, Any]]) -> List[gr.Image]:
|
| 401 |
+
"""
|
| 402 |
+
Create the example scene grid.
|
| 403 |
+
|
| 404 |
+
Args:
|
| 405 |
+
scenes: List of scene information dictionaries
|
| 406 |
+
|
| 407 |
+
Returns:
|
| 408 |
+
List of scene image components
|
| 409 |
+
"""
|
| 410 |
+
scene_components = []
|
| 411 |
+
|
| 412 |
+
if scenes:
|
| 413 |
+
for i in range(0, len(scenes), 4): # Process 4 scenes per row
|
| 414 |
+
with gr.Row():
|
| 415 |
+
for j in range(4):
|
| 416 |
+
scene_idx = i + j
|
| 417 |
+
if scene_idx < len(scenes):
|
| 418 |
+
scene = scenes[scene_idx]
|
| 419 |
+
with gr.Column(scale=1, elem_classes=["clickable-thumbnail"]):
|
| 420 |
+
# Clickable thumbnail
|
| 421 |
+
scene_img = gr.Image(
|
| 422 |
+
value=scene["thumbnail"],
|
| 423 |
+
height=150,
|
| 424 |
+
interactive=False,
|
| 425 |
+
show_label=False,
|
| 426 |
+
elem_id=f"scene_thumb_{scene['name']}",
|
| 427 |
+
sources=[],
|
| 428 |
+
)
|
| 429 |
+
scene_components.append(scene_img)
|
| 430 |
+
|
| 431 |
+
# Scene name and image count as text below thumbnail
|
| 432 |
+
gr.Markdown(
|
| 433 |
+
f"**{scene['name']}** \n {scene['num_images']} images",
|
| 434 |
+
elem_classes=["scene-info"],
|
| 435 |
+
)
|
| 436 |
+
else:
|
| 437 |
+
# Empty column to maintain grid structure
|
| 438 |
+
with gr.Column(scale=1):
|
| 439 |
+
pass
|
| 440 |
+
|
| 441 |
+
return scene_components
|
| 442 |
+
|
| 443 |
+
def create_header_section(self) -> gr.HTML:
|
| 444 |
+
"""
|
| 445 |
+
Create the header section with logo and title.
|
| 446 |
+
|
| 447 |
+
Returns:
|
| 448 |
+
Header HTML component
|
| 449 |
+
"""
|
| 450 |
+
from depth_anything_3.app.css_and_html import get_header_html
|
| 451 |
+
|
| 452 |
+
return gr.HTML(get_header_html(get_logo_base64()))
|
| 453 |
+
|
| 454 |
+
def create_description_section(self) -> gr.HTML:
|
| 455 |
+
"""
|
| 456 |
+
Create the description section.
|
| 457 |
+
|
| 458 |
+
Returns:
|
| 459 |
+
Description HTML component
|
| 460 |
+
"""
|
| 461 |
+
from depth_anything_3.app.css_and_html import get_description_html
|
| 462 |
+
|
| 463 |
+
return gr.HTML(get_description_html())
|
| 464 |
+
|
| 465 |
+
def create_acknowledgements_section(self) -> gr.HTML:
|
| 466 |
+
"""
|
| 467 |
+
Create the acknowledgements section.
|
| 468 |
+
|
| 469 |
+
Returns:
|
| 470 |
+
Acknowledgements HTML component
|
| 471 |
+
"""
|
| 472 |
+
from depth_anything_3.app.css_and_html import get_acknowledgements_html
|
| 473 |
+
|
| 474 |
+
return gr.HTML(get_acknowledgements_html())
|
src/depth_anything_3/app/modules/utils.py
ADDED
|
@@ -0,0 +1,211 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Utility functions for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module contains helper functions for data processing, visualization,
|
| 19 |
+
and file operations.
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import gc
|
| 23 |
+
import json
|
| 24 |
+
import os
|
| 25 |
+
import shutil
|
| 26 |
+
from datetime import datetime
|
| 27 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 28 |
+
import numpy as np
|
| 29 |
+
import torch
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def create_depth_visualization(depth: np.ndarray) -> Optional[np.ndarray]:
|
| 33 |
+
"""
|
| 34 |
+
Create a colored depth visualization.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
depth: Depth array
|
| 38 |
+
|
| 39 |
+
Returns:
|
| 40 |
+
Colored depth visualization or None
|
| 41 |
+
"""
|
| 42 |
+
if depth is None:
|
| 43 |
+
return None
|
| 44 |
+
|
| 45 |
+
# Normalize depth to 0-1 range
|
| 46 |
+
depth_min = depth[depth > 0].min() if (depth > 0).any() else 0
|
| 47 |
+
depth_max = depth.max()
|
| 48 |
+
|
| 49 |
+
if depth_max <= depth_min:
|
| 50 |
+
return None
|
| 51 |
+
|
| 52 |
+
# Normalize depth
|
| 53 |
+
depth_norm = (depth - depth_min) / (depth_max - depth_min)
|
| 54 |
+
depth_norm = np.clip(depth_norm, 0, 1)
|
| 55 |
+
|
| 56 |
+
# Apply colormap (using matplotlib's viridis colormap)
|
| 57 |
+
import matplotlib.cm as cm
|
| 58 |
+
|
| 59 |
+
# Convert to colored image
|
| 60 |
+
depth_colored = cm.viridis(depth_norm)[:, :, :3] # Remove alpha channel
|
| 61 |
+
depth_colored = (depth_colored * 255).astype(np.uint8)
|
| 62 |
+
|
| 63 |
+
return depth_colored
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def save_to_gallery_func(
|
| 67 |
+
target_dir: str, processed_data: Dict[int, Dict[str, Any]], gallery_name: Optional[str] = None
|
| 68 |
+
) -> Tuple[bool, str]:
|
| 69 |
+
"""
|
| 70 |
+
Save the current reconstruction results to the gallery directory.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
target_dir: Source directory containing reconstruction results
|
| 74 |
+
processed_data: Processed data dictionary
|
| 75 |
+
gallery_name: Name for the gallery folder
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Tuple of (success, message)
|
| 79 |
+
"""
|
| 80 |
+
try:
|
| 81 |
+
# Get gallery directory from environment variable or use default
|
| 82 |
+
gallery_dir = os.environ.get(
|
| 83 |
+
"DA3_GALLERY_DIR",
|
| 84 |
+
"workspace/gallery",
|
| 85 |
+
)
|
| 86 |
+
if not os.path.exists(gallery_dir):
|
| 87 |
+
os.makedirs(gallery_dir)
|
| 88 |
+
|
| 89 |
+
# Use provided name or create a unique name
|
| 90 |
+
if gallery_name is None or gallery_name.strip() == "":
|
| 91 |
+
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
| 92 |
+
gallery_name = f"reconstruction_{timestamp}"
|
| 93 |
+
|
| 94 |
+
gallery_path = os.path.join(gallery_dir, gallery_name)
|
| 95 |
+
|
| 96 |
+
# Check if directory already exists
|
| 97 |
+
if os.path.exists(gallery_path):
|
| 98 |
+
return False, f"Save failed: folder '{gallery_name}' already exists"
|
| 99 |
+
|
| 100 |
+
# Create the gallery directory
|
| 101 |
+
os.makedirs(gallery_path, exist_ok=True)
|
| 102 |
+
|
| 103 |
+
# Copy GLB file
|
| 104 |
+
glb_source = os.path.join(target_dir, "scene.glb")
|
| 105 |
+
glb_dest = os.path.join(gallery_path, "scene.glb")
|
| 106 |
+
if os.path.exists(glb_source):
|
| 107 |
+
shutil.copy2(glb_source, glb_dest)
|
| 108 |
+
|
| 109 |
+
# Copy depth visualization images
|
| 110 |
+
depth_vis_dir = os.path.join(target_dir, "depth_vis")
|
| 111 |
+
if os.path.exists(depth_vis_dir):
|
| 112 |
+
gallery_depth_vis = os.path.join(gallery_path, "depth_vis")
|
| 113 |
+
shutil.copytree(depth_vis_dir, gallery_depth_vis)
|
| 114 |
+
|
| 115 |
+
# Copy original images
|
| 116 |
+
images_source = os.path.join(target_dir, "images")
|
| 117 |
+
if os.path.exists(images_source):
|
| 118 |
+
gallery_images = os.path.join(gallery_path, "images")
|
| 119 |
+
shutil.copytree(images_source, gallery_images)
|
| 120 |
+
|
| 121 |
+
scene_preview_source = os.path.join(target_dir, "scene.jpg")
|
| 122 |
+
scene_preview_dest = os.path.join(gallery_path, "scene.jpg")
|
| 123 |
+
shutil.copy2(scene_preview_source, scene_preview_dest)
|
| 124 |
+
|
| 125 |
+
# Save metadata
|
| 126 |
+
metadata = {
|
| 127 |
+
"timestamp": datetime.now().strftime("%Y%m%d_%H%M%S"),
|
| 128 |
+
"num_images": len(processed_data) if processed_data else 0,
|
| 129 |
+
"gallery_name": gallery_name,
|
| 130 |
+
}
|
| 131 |
+
|
| 132 |
+
with open(os.path.join(gallery_path, "metadata.json"), "w") as f:
|
| 133 |
+
json.dump(metadata, f, indent=2)
|
| 134 |
+
|
| 135 |
+
print(f"Saved reconstruction to gallery: {gallery_path}")
|
| 136 |
+
return True, f"Save successful: saved to {gallery_path}"
|
| 137 |
+
|
| 138 |
+
except Exception as e:
|
| 139 |
+
print(f"Error saving to gallery: {e}")
|
| 140 |
+
return False, f"Save failed: {str(e)}"
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def get_scene_info(examples_dir: str) -> List[Dict[str, Any]]:
|
| 144 |
+
"""
|
| 145 |
+
Get information about scenes in the examples directory.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
examples_dir: Path to examples directory
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
List of scene information dictionaries
|
| 152 |
+
"""
|
| 153 |
+
import glob
|
| 154 |
+
|
| 155 |
+
scenes = []
|
| 156 |
+
if not os.path.exists(examples_dir):
|
| 157 |
+
return scenes
|
| 158 |
+
|
| 159 |
+
for scene_folder in sorted(os.listdir(examples_dir)):
|
| 160 |
+
scene_path = os.path.join(examples_dir, scene_folder)
|
| 161 |
+
if os.path.isdir(scene_path):
|
| 162 |
+
# Find all image files in the scene folder
|
| 163 |
+
image_extensions = ["*.jpg", "*.jpeg", "*.png", "*.bmp", "*.tiff", "*.tif"]
|
| 164 |
+
image_files = []
|
| 165 |
+
for ext in image_extensions:
|
| 166 |
+
image_files.extend(glob.glob(os.path.join(scene_path, ext)))
|
| 167 |
+
image_files.extend(glob.glob(os.path.join(scene_path, ext.upper())))
|
| 168 |
+
|
| 169 |
+
if image_files:
|
| 170 |
+
# Sort images and get the first one for thumbnail
|
| 171 |
+
image_files = sorted(image_files)
|
| 172 |
+
first_image = image_files[0]
|
| 173 |
+
num_images = len(image_files)
|
| 174 |
+
|
| 175 |
+
scenes.append(
|
| 176 |
+
{
|
| 177 |
+
"name": scene_folder,
|
| 178 |
+
"path": scene_path,
|
| 179 |
+
"thumbnail": first_image,
|
| 180 |
+
"num_images": num_images,
|
| 181 |
+
"image_files": image_files,
|
| 182 |
+
}
|
| 183 |
+
)
|
| 184 |
+
|
| 185 |
+
return scenes
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def cleanup_memory() -> None:
|
| 189 |
+
"""Clean up GPU memory and garbage collect."""
|
| 190 |
+
gc.collect()
|
| 191 |
+
if torch.cuda.is_available():
|
| 192 |
+
torch.cuda.empty_cache()
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
def get_logo_base64() -> Optional[str]:
|
| 196 |
+
"""
|
| 197 |
+
Convert WAI logo to base64 for embedding in HTML.
|
| 198 |
+
|
| 199 |
+
Returns:
|
| 200 |
+
Base64 encoded logo string or None
|
| 201 |
+
"""
|
| 202 |
+
import base64
|
| 203 |
+
|
| 204 |
+
logo_path = "examples/WAI-Logo/wai_logo.png"
|
| 205 |
+
try:
|
| 206 |
+
with open(logo_path, "rb") as img_file:
|
| 207 |
+
img_data = img_file.read()
|
| 208 |
+
base64_str = base64.b64encode(img_data).decode()
|
| 209 |
+
return f"data:image/png;base64,{base64_str}"
|
| 210 |
+
except FileNotFoundError:
|
| 211 |
+
return None
|
src/depth_anything_3/app/modules/visualization.py
ADDED
|
@@ -0,0 +1,434 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Visualization module for Depth Anything 3 Gradio app.
|
| 17 |
+
|
| 18 |
+
This module handles visualization updates, navigation, and measurement functionality.
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import os
|
| 22 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 23 |
+
import cv2
|
| 24 |
+
import gradio as gr
|
| 25 |
+
import numpy as np
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class VisualizationHandler:
|
| 29 |
+
"""
|
| 30 |
+
Handles visualization updates and navigation for the Gradio app.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
def __init__(self):
|
| 34 |
+
"""Initialize the visualization handler."""
|
| 35 |
+
|
| 36 |
+
def update_view_selectors(
|
| 37 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
|
| 38 |
+
) -> Tuple[gr.Dropdown, gr.Dropdown]:
|
| 39 |
+
"""
|
| 40 |
+
Update view selector dropdowns based on available views.
|
| 41 |
+
|
| 42 |
+
Args:
|
| 43 |
+
processed_data: Processed data dictionary
|
| 44 |
+
|
| 45 |
+
Returns:
|
| 46 |
+
Tuple of (depth_view_selector, measure_view_selector)
|
| 47 |
+
"""
|
| 48 |
+
if processed_data is None or len(processed_data) == 0:
|
| 49 |
+
choices = ["View 1"]
|
| 50 |
+
else:
|
| 51 |
+
num_views = len(processed_data)
|
| 52 |
+
choices = [f"View {i + 1}" for i in range(num_views)]
|
| 53 |
+
|
| 54 |
+
return (
|
| 55 |
+
gr.Dropdown(choices=choices, value=choices[0]), # depth_view_selector
|
| 56 |
+
gr.Dropdown(choices=choices, value=choices[0]), # measure_view_selector
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def get_view_data_by_index(
|
| 60 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 61 |
+
) -> Optional[Dict[str, Any]]:
|
| 62 |
+
"""
|
| 63 |
+
Get view data by index, handling bounds.
|
| 64 |
+
|
| 65 |
+
Args:
|
| 66 |
+
processed_data: Processed data dictionary
|
| 67 |
+
view_index: Index of the view to get
|
| 68 |
+
|
| 69 |
+
Returns:
|
| 70 |
+
View data dictionary or None
|
| 71 |
+
"""
|
| 72 |
+
if processed_data is None or len(processed_data) == 0:
|
| 73 |
+
return None
|
| 74 |
+
|
| 75 |
+
view_keys = list(processed_data.keys())
|
| 76 |
+
if view_index < 0 or view_index >= len(view_keys):
|
| 77 |
+
view_index = 0
|
| 78 |
+
|
| 79 |
+
return processed_data[view_keys[view_index]]
|
| 80 |
+
|
| 81 |
+
def update_depth_view(
|
| 82 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 83 |
+
) -> Optional[str]:
|
| 84 |
+
"""
|
| 85 |
+
Update depth view for a specific view index.
|
| 86 |
+
|
| 87 |
+
Args:
|
| 88 |
+
processed_data: Processed data dictionary
|
| 89 |
+
view_index: Index of the view to update
|
| 90 |
+
|
| 91 |
+
Returns:
|
| 92 |
+
Path to depth visualization image or None
|
| 93 |
+
"""
|
| 94 |
+
view_data = self.get_view_data_by_index(processed_data, view_index)
|
| 95 |
+
if view_data is None or view_data.get("depth_image") is None:
|
| 96 |
+
return None
|
| 97 |
+
|
| 98 |
+
# Return the depth visualization image directly
|
| 99 |
+
return view_data["depth_image"]
|
| 100 |
+
|
| 101 |
+
def navigate_depth_view(
|
| 102 |
+
self,
|
| 103 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 104 |
+
current_selector_value: str,
|
| 105 |
+
direction: int,
|
| 106 |
+
) -> Tuple[str, Optional[str]]:
|
| 107 |
+
"""
|
| 108 |
+
Navigate depth view (direction: -1 for previous, +1 for next).
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
processed_data: Processed data dictionary
|
| 112 |
+
current_selector_value: Current selector value
|
| 113 |
+
direction: Direction to navigate (-1 for previous, +1 for next)
|
| 114 |
+
|
| 115 |
+
Returns:
|
| 116 |
+
Tuple of (new_selector_value, depth_vis)
|
| 117 |
+
"""
|
| 118 |
+
if processed_data is None or len(processed_data) == 0:
|
| 119 |
+
return "View 1", None
|
| 120 |
+
|
| 121 |
+
# Parse current view number
|
| 122 |
+
try:
|
| 123 |
+
current_view = int(current_selector_value.split()[1]) - 1
|
| 124 |
+
except: # noqa
|
| 125 |
+
current_view = 0
|
| 126 |
+
|
| 127 |
+
num_views = len(processed_data)
|
| 128 |
+
new_view = (current_view + direction) % num_views
|
| 129 |
+
|
| 130 |
+
new_selector_value = f"View {new_view + 1}"
|
| 131 |
+
depth_vis = self.update_depth_view(processed_data, new_view)
|
| 132 |
+
|
| 133 |
+
return new_selector_value, depth_vis
|
| 134 |
+
|
| 135 |
+
def update_measure_view(
|
| 136 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]], view_index: int
|
| 137 |
+
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List]:
|
| 138 |
+
"""
|
| 139 |
+
Update measure view for a specific view index.
|
| 140 |
+
|
| 141 |
+
Args:
|
| 142 |
+
processed_data: Processed data dictionary
|
| 143 |
+
view_index: Index of the view to update
|
| 144 |
+
|
| 145 |
+
Returns:
|
| 146 |
+
Tuple of (measure_image, depth_right_half, measure_points)
|
| 147 |
+
"""
|
| 148 |
+
view_data = self.get_view_data_by_index(processed_data, view_index)
|
| 149 |
+
if view_data is None:
|
| 150 |
+
return None, None, [] # image, depth_right_half, measure_points
|
| 151 |
+
|
| 152 |
+
# Get the processed (resized) image
|
| 153 |
+
if "image" in view_data and view_data["image"] is not None:
|
| 154 |
+
image = view_data["image"].copy()
|
| 155 |
+
else:
|
| 156 |
+
return None, None, []
|
| 157 |
+
|
| 158 |
+
# Ensure image is in uint8 format
|
| 159 |
+
if image.dtype != np.uint8:
|
| 160 |
+
if image.max() <= 1.0:
|
| 161 |
+
image = (image * 255).astype(np.uint8)
|
| 162 |
+
else:
|
| 163 |
+
image = image.astype(np.uint8)
|
| 164 |
+
|
| 165 |
+
# Extract right half of the depth visualization (pure depth part)
|
| 166 |
+
depth_image_path = view_data.get("depth_image", None)
|
| 167 |
+
depth_right_half = None
|
| 168 |
+
|
| 169 |
+
if depth_image_path and os.path.exists(depth_image_path):
|
| 170 |
+
try:
|
| 171 |
+
# Load the combined depth visualization image
|
| 172 |
+
depth_combined = cv2.imread(depth_image_path)
|
| 173 |
+
depth_combined = cv2.cvtColor(depth_combined, cv2.COLOR_BGR2RGB)
|
| 174 |
+
if depth_combined is not None:
|
| 175 |
+
height, width = depth_combined.shape[:2]
|
| 176 |
+
# Extract right half (depth visualization part)
|
| 177 |
+
depth_right_half = depth_combined[:, width // 2 :]
|
| 178 |
+
except Exception as e:
|
| 179 |
+
print(f"Error extracting depth right half: {e}")
|
| 180 |
+
|
| 181 |
+
return image, depth_right_half, []
|
| 182 |
+
|
| 183 |
+
def navigate_measure_view(
|
| 184 |
+
self,
|
| 185 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 186 |
+
current_selector_value: str,
|
| 187 |
+
direction: int,
|
| 188 |
+
) -> Tuple[str, Optional[np.ndarray], Optional[str], List]:
|
| 189 |
+
"""
|
| 190 |
+
Navigate measure view (direction: -1 for previous, +1 for next).
|
| 191 |
+
|
| 192 |
+
Args:
|
| 193 |
+
processed_data: Processed data dictionary
|
| 194 |
+
current_selector_value: Current selector value
|
| 195 |
+
direction: Direction to navigate (-1 for previous, +1 for next)
|
| 196 |
+
|
| 197 |
+
Returns:
|
| 198 |
+
Tuple of (new_selector_value, measure_image, depth_image_path, measure_points)
|
| 199 |
+
"""
|
| 200 |
+
if processed_data is None or len(processed_data) == 0:
|
| 201 |
+
return "View 1", None, None, []
|
| 202 |
+
|
| 203 |
+
# Parse current view number
|
| 204 |
+
try:
|
| 205 |
+
current_view = int(current_selector_value.split()[1]) - 1
|
| 206 |
+
except: # noqa
|
| 207 |
+
current_view = 0
|
| 208 |
+
|
| 209 |
+
num_views = len(processed_data)
|
| 210 |
+
new_view = (current_view + direction) % num_views
|
| 211 |
+
|
| 212 |
+
new_selector_value = f"View {new_view + 1}"
|
| 213 |
+
measure_image, depth_right_half, measure_points = self.update_measure_view(
|
| 214 |
+
processed_data, new_view
|
| 215 |
+
)
|
| 216 |
+
|
| 217 |
+
return new_selector_value, measure_image, depth_right_half, measure_points
|
| 218 |
+
|
| 219 |
+
def populate_visualization_tabs(
|
| 220 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
|
| 221 |
+
) -> Tuple[Optional[str], Optional[np.ndarray], Optional[str], List]:
|
| 222 |
+
"""
|
| 223 |
+
Populate the depth and measure tabs with processed data.
|
| 224 |
+
|
| 225 |
+
Args:
|
| 226 |
+
processed_data: Processed data dictionary
|
| 227 |
+
|
| 228 |
+
Returns:
|
| 229 |
+
Tuple of (depth_vis, measure_img, depth_image_path, measure_points)
|
| 230 |
+
"""
|
| 231 |
+
if processed_data is None or len(processed_data) == 0:
|
| 232 |
+
return None, None, None, []
|
| 233 |
+
|
| 234 |
+
# Use update function to get depth visualization
|
| 235 |
+
depth_vis = self.update_depth_view(processed_data, 0)
|
| 236 |
+
measure_img, depth_right_half, _ = self.update_measure_view(processed_data, 0)
|
| 237 |
+
|
| 238 |
+
return depth_vis, measure_img, depth_right_half, []
|
| 239 |
+
|
| 240 |
+
def reset_measure(
|
| 241 |
+
self, processed_data: Optional[Dict[int, Dict[str, Any]]]
|
| 242 |
+
) -> Tuple[Optional[np.ndarray], List, str]:
|
| 243 |
+
"""
|
| 244 |
+
Reset measure points.
|
| 245 |
+
|
| 246 |
+
Args:
|
| 247 |
+
processed_data: Processed data dictionary
|
| 248 |
+
|
| 249 |
+
Returns:
|
| 250 |
+
Tuple of (image, measure_points, text)
|
| 251 |
+
"""
|
| 252 |
+
if processed_data is None or len(processed_data) == 0:
|
| 253 |
+
return None, [], ""
|
| 254 |
+
|
| 255 |
+
# Return the first view image
|
| 256 |
+
first_view = list(processed_data.values())[0]
|
| 257 |
+
return first_view["image"], [], ""
|
| 258 |
+
|
| 259 |
+
def measure(
|
| 260 |
+
self,
|
| 261 |
+
processed_data: Optional[Dict[int, Dict[str, Any]]],
|
| 262 |
+
measure_points: List,
|
| 263 |
+
current_view_selector: str,
|
| 264 |
+
event: gr.SelectData,
|
| 265 |
+
) -> List:
|
| 266 |
+
"""
|
| 267 |
+
Handle measurement on images.
|
| 268 |
+
|
| 269 |
+
Args:
|
| 270 |
+
processed_data: Processed data dictionary
|
| 271 |
+
measure_points: List of current measure points
|
| 272 |
+
current_view_selector: Current view selector value
|
| 273 |
+
event: Gradio select event
|
| 274 |
+
|
| 275 |
+
Returns:
|
| 276 |
+
List of [image, depth_right_half, measure_points, text]
|
| 277 |
+
"""
|
| 278 |
+
try:
|
| 279 |
+
print(f"Measure function called with selector: {current_view_selector}")
|
| 280 |
+
|
| 281 |
+
if processed_data is None or len(processed_data) == 0:
|
| 282 |
+
return [None, [], "No data available"]
|
| 283 |
+
|
| 284 |
+
# Use the currently selected view instead of always using the first view
|
| 285 |
+
try:
|
| 286 |
+
current_view_index = int(current_view_selector.split()[1]) - 1
|
| 287 |
+
except: # noqa
|
| 288 |
+
current_view_index = 0
|
| 289 |
+
|
| 290 |
+
print(f"Using view index: {current_view_index}")
|
| 291 |
+
|
| 292 |
+
# Get view data safely
|
| 293 |
+
if current_view_index < 0 or current_view_index >= len(processed_data):
|
| 294 |
+
current_view_index = 0
|
| 295 |
+
|
| 296 |
+
view_keys = list(processed_data.keys())
|
| 297 |
+
current_view = processed_data[view_keys[current_view_index]]
|
| 298 |
+
|
| 299 |
+
if current_view is None:
|
| 300 |
+
return [None, [], "No view data available"]
|
| 301 |
+
|
| 302 |
+
point2d = event.index[0], event.index[1]
|
| 303 |
+
print(f"Clicked point: {point2d}")
|
| 304 |
+
|
| 305 |
+
measure_points.append(point2d)
|
| 306 |
+
|
| 307 |
+
# Get image and depth visualization
|
| 308 |
+
image, depth_right_half, _ = self.update_measure_view(
|
| 309 |
+
processed_data, current_view_index
|
| 310 |
+
)
|
| 311 |
+
if image is None:
|
| 312 |
+
return [None, [], "No image available"]
|
| 313 |
+
|
| 314 |
+
image = image.copy()
|
| 315 |
+
|
| 316 |
+
# Ensure image is in uint8 format for proper cv2 operations
|
| 317 |
+
try:
|
| 318 |
+
if image.dtype != np.uint8:
|
| 319 |
+
if image.max() <= 1.0:
|
| 320 |
+
# Image is in [0, 1] range, convert to [0, 255]
|
| 321 |
+
image = (image * 255).astype(np.uint8)
|
| 322 |
+
else:
|
| 323 |
+
# Image is already in [0, 255] range
|
| 324 |
+
image = image.astype(np.uint8)
|
| 325 |
+
except Exception as e:
|
| 326 |
+
print(f"Image conversion error: {e}")
|
| 327 |
+
return [None, [], f"Image conversion error: {e}"]
|
| 328 |
+
|
| 329 |
+
# Draw circles for points
|
| 330 |
+
try:
|
| 331 |
+
for p in measure_points:
|
| 332 |
+
if 0 <= p[0] < image.shape[1] and 0 <= p[1] < image.shape[0]:
|
| 333 |
+
image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2)
|
| 334 |
+
except Exception as e:
|
| 335 |
+
print(f"Drawing error: {e}")
|
| 336 |
+
return [None, [], f"Drawing error: {e}"]
|
| 337 |
+
|
| 338 |
+
# Get depth information from processed_data
|
| 339 |
+
depth_text = ""
|
| 340 |
+
try:
|
| 341 |
+
for i, p in enumerate(measure_points):
|
| 342 |
+
if (
|
| 343 |
+
current_view["depth"] is not None
|
| 344 |
+
and 0 <= p[1] < current_view["depth"].shape[0]
|
| 345 |
+
and 0 <= p[0] < current_view["depth"].shape[1]
|
| 346 |
+
):
|
| 347 |
+
d = current_view["depth"][p[1], p[0]]
|
| 348 |
+
depth_text += f"- **P{i + 1} depth: {d:.2f}m**\n"
|
| 349 |
+
else:
|
| 350 |
+
depth_text += f"- **P{i + 1}: Click position ({p[0]}, {p[1]}) - No depth information**\n" # noqa: E501
|
| 351 |
+
except Exception as e:
|
| 352 |
+
print(f"Depth text error: {e}")
|
| 353 |
+
depth_text = f"Error computing depth: {e}\n"
|
| 354 |
+
|
| 355 |
+
if len(measure_points) == 2:
|
| 356 |
+
try:
|
| 357 |
+
point1, point2 = measure_points
|
| 358 |
+
# Draw line
|
| 359 |
+
if (
|
| 360 |
+
0 <= point1[0] < image.shape[1]
|
| 361 |
+
and 0 <= point1[1] < image.shape[0]
|
| 362 |
+
and 0 <= point2[0] < image.shape[1]
|
| 363 |
+
and 0 <= point2[1] < image.shape[0]
|
| 364 |
+
):
|
| 365 |
+
image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2)
|
| 366 |
+
|
| 367 |
+
# Compute 3D distance using depth information and camera intrinsics
|
| 368 |
+
distance_text = "- **Distance: Unable to calculate 3D distance**"
|
| 369 |
+
if (
|
| 370 |
+
current_view["depth"] is not None
|
| 371 |
+
and 0 <= point1[1] < current_view["depth"].shape[0]
|
| 372 |
+
and 0 <= point1[0] < current_view["depth"].shape[1]
|
| 373 |
+
and 0 <= point2[1] < current_view["depth"].shape[0]
|
| 374 |
+
and 0 <= point2[0] < current_view["depth"].shape[1]
|
| 375 |
+
):
|
| 376 |
+
try:
|
| 377 |
+
# Get depth values at the two points
|
| 378 |
+
d1 = current_view["depth"][point1[1], point1[0]]
|
| 379 |
+
d2 = current_view["depth"][point2[1], point2[0]]
|
| 380 |
+
|
| 381 |
+
# Convert 2D pixel coordinates to 3D world coordinates
|
| 382 |
+
if current_view["intrinsics"] is not None:
|
| 383 |
+
# Get camera intrinsics
|
| 384 |
+
K = current_view["intrinsics"] # 3x3 intrinsic matrix
|
| 385 |
+
fx, fy = K[0, 0], K[1, 1] # focal lengths
|
| 386 |
+
cx, cy = K[0, 2], K[1, 2] # principal point
|
| 387 |
+
|
| 388 |
+
# Convert pixel coordinates to normalized camera coordinates
|
| 389 |
+
# Point 1: (u1, v1) -> (x1, y1, z1)
|
| 390 |
+
u1, v1 = point1[0], point1[1]
|
| 391 |
+
x1 = (u1 - cx) * d1 / fx
|
| 392 |
+
y1 = (v1 - cy) * d1 / fy
|
| 393 |
+
z1 = d1
|
| 394 |
+
|
| 395 |
+
# Point 2: (u2, v2) -> (x2, y2, z2)
|
| 396 |
+
u2, v2 = point2[0], point2[1]
|
| 397 |
+
x2 = (u2 - cx) * d2 / fx
|
| 398 |
+
y2 = (v2 - cy) * d2 / fy
|
| 399 |
+
z2 = d2
|
| 400 |
+
|
| 401 |
+
# Calculate 3D Euclidean distance
|
| 402 |
+
p1_3d = np.array([x1, y1, z1])
|
| 403 |
+
p2_3d = np.array([x2, y2, z2])
|
| 404 |
+
distance_3d = np.linalg.norm(p1_3d - p2_3d)
|
| 405 |
+
|
| 406 |
+
distance_text = f"- **Distance: {distance_3d:.2f}m**"
|
| 407 |
+
else:
|
| 408 |
+
# Fallback to simplified calculation if no intrinsics
|
| 409 |
+
pixel_distance = np.sqrt(
|
| 410 |
+
(point1[0] - point2[0]) ** 2 + (point1[1] - point2[1]) ** 2
|
| 411 |
+
)
|
| 412 |
+
avg_depth = (d1 + d2) / 2
|
| 413 |
+
scale_factor = avg_depth / 1000 # Rough scaling factor
|
| 414 |
+
estimated_3d_distance = pixel_distance * scale_factor
|
| 415 |
+
distance_text = f"- **Distance: {estimated_3d_distance:.2f}m (estimated, no intrinsics)**" # noqa: E501
|
| 416 |
+
|
| 417 |
+
except Exception as e:
|
| 418 |
+
print(f"Distance computation error: {e}")
|
| 419 |
+
distance_text = f"- **Distance computation error: {e}**"
|
| 420 |
+
|
| 421 |
+
measure_points = []
|
| 422 |
+
text = depth_text + distance_text
|
| 423 |
+
print(f"Measurement complete: {text}")
|
| 424 |
+
return [image, depth_right_half, measure_points, text]
|
| 425 |
+
except Exception as e:
|
| 426 |
+
print(f"Final measurement error: {e}")
|
| 427 |
+
return [None, [], f"Measurement error: {e}"]
|
| 428 |
+
else:
|
| 429 |
+
print(f"Single point measurement: {depth_text}")
|
| 430 |
+
return [image, depth_right_half, measure_points, depth_text]
|
| 431 |
+
|
| 432 |
+
except Exception as e:
|
| 433 |
+
print(f"Overall measure function error: {e}")
|
| 434 |
+
return [None, [], f"Measure function error: {e}"]
|
src/depth_anything_3/cfg.py
ADDED
|
@@ -0,0 +1,144 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Configuration utility functions
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
import importlib
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any, Callable, List, Union
|
| 22 |
+
from omegaconf import DictConfig, ListConfig, OmegaConf
|
| 23 |
+
|
| 24 |
+
try:
|
| 25 |
+
OmegaConf.register_new_resolver("eval", eval)
|
| 26 |
+
except Exception as e:
|
| 27 |
+
# if eval is not available, we can just pass
|
| 28 |
+
print(f"Error registering eval resolver: {e}")
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def load_config(path: str, argv: List[str] = None) -> Union[DictConfig, ListConfig]:
|
| 32 |
+
"""
|
| 33 |
+
Load a configuration. Will resolve inheritance.
|
| 34 |
+
Supports both file paths and module paths (e.g., depth_anything_3.configs.giant).
|
| 35 |
+
"""
|
| 36 |
+
# Check if path is a module path (contains dots but no slashes and doesn't end with .yaml)
|
| 37 |
+
if "." in path and "/" not in path and not path.endswith(".yaml"):
|
| 38 |
+
# It's a module path, load from package resources
|
| 39 |
+
path_parts = path.split(".")[1:]
|
| 40 |
+
config_path = Path(__file__).resolve().parent
|
| 41 |
+
for part in path_parts:
|
| 42 |
+
config_path = config_path.joinpath(part)
|
| 43 |
+
config_path = config_path.with_suffix(".yaml")
|
| 44 |
+
config = OmegaConf.load(str(config_path))
|
| 45 |
+
else:
|
| 46 |
+
# It's a file path (absolute, relative, or with .yaml extension)
|
| 47 |
+
config = OmegaConf.load(path)
|
| 48 |
+
|
| 49 |
+
if argv is not None:
|
| 50 |
+
config_argv = OmegaConf.from_dotlist(argv)
|
| 51 |
+
config = OmegaConf.merge(config, config_argv)
|
| 52 |
+
config = resolve_recursive(config, resolve_inheritance)
|
| 53 |
+
return config
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def resolve_recursive(
|
| 57 |
+
config: Any,
|
| 58 |
+
resolver: Callable[[Union[DictConfig, ListConfig]], Union[DictConfig, ListConfig]],
|
| 59 |
+
) -> Any:
|
| 60 |
+
config = resolver(config)
|
| 61 |
+
if isinstance(config, DictConfig):
|
| 62 |
+
for k in config.keys():
|
| 63 |
+
v = config.get(k)
|
| 64 |
+
if isinstance(v, (DictConfig, ListConfig)):
|
| 65 |
+
config[k] = resolve_recursive(v, resolver)
|
| 66 |
+
if isinstance(config, ListConfig):
|
| 67 |
+
for i in range(len(config)):
|
| 68 |
+
v = config.get(i)
|
| 69 |
+
if isinstance(v, (DictConfig, ListConfig)):
|
| 70 |
+
config[i] = resolve_recursive(v, resolver)
|
| 71 |
+
return config
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def resolve_inheritance(config: Union[DictConfig, ListConfig]) -> Any:
|
| 75 |
+
"""
|
| 76 |
+
Recursively resolve inheritance if the config contains:
|
| 77 |
+
__inherit__: path/to/parent.yaml or a ListConfig of such paths.
|
| 78 |
+
"""
|
| 79 |
+
if isinstance(config, DictConfig):
|
| 80 |
+
inherit = config.pop("__inherit__", None)
|
| 81 |
+
|
| 82 |
+
if inherit:
|
| 83 |
+
inherit_list = inherit if isinstance(inherit, ListConfig) else [inherit]
|
| 84 |
+
|
| 85 |
+
parent_config = None
|
| 86 |
+
for parent_path in inherit_list:
|
| 87 |
+
assert isinstance(parent_path, str)
|
| 88 |
+
parent_config = (
|
| 89 |
+
load_config(parent_path)
|
| 90 |
+
if parent_config is None
|
| 91 |
+
else OmegaConf.merge(parent_config, load_config(parent_path))
|
| 92 |
+
)
|
| 93 |
+
|
| 94 |
+
if len(config.keys()) > 0:
|
| 95 |
+
config = OmegaConf.merge(parent_config, config)
|
| 96 |
+
else:
|
| 97 |
+
config = parent_config
|
| 98 |
+
return config
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def import_item(path: str, name: str) -> Any:
|
| 102 |
+
"""
|
| 103 |
+
Import a python item. Example: import_item("path.to.file", "MyClass") -> MyClass
|
| 104 |
+
"""
|
| 105 |
+
return getattr(importlib.import_module(path), name)
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def create_object(config: DictConfig) -> Any:
|
| 109 |
+
"""
|
| 110 |
+
Create an object from config.
|
| 111 |
+
The config is expected to contains the following:
|
| 112 |
+
__object__:
|
| 113 |
+
path: path.to.module
|
| 114 |
+
name: MyClass
|
| 115 |
+
args: as_config | as_params (default to as_config)
|
| 116 |
+
"""
|
| 117 |
+
config = DictConfig(config)
|
| 118 |
+
item = import_item(
|
| 119 |
+
path=config.__object__.path,
|
| 120 |
+
name=config.__object__.name,
|
| 121 |
+
)
|
| 122 |
+
args = config.__object__.get("args", "as_config")
|
| 123 |
+
if args == "as_config":
|
| 124 |
+
return item(config)
|
| 125 |
+
if args == "as_params":
|
| 126 |
+
config = OmegaConf.to_object(config)
|
| 127 |
+
config.pop("__object__")
|
| 128 |
+
return item(**config)
|
| 129 |
+
raise NotImplementedError(f"Unknown args type: {args}")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def create_dataset(path: str, *args, **kwargs) -> Any:
|
| 133 |
+
"""
|
| 134 |
+
Create a dataset. Requires the file to contain a "create_dataset" function.
|
| 135 |
+
"""
|
| 136 |
+
return import_item(path, "create_dataset")(*args, **kwargs)
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
def to_dict_recursive(config_obj):
|
| 140 |
+
if isinstance(config_obj, DictConfig):
|
| 141 |
+
return {k: to_dict_recursive(v) for k, v in config_obj.items()}
|
| 142 |
+
elif isinstance(config_obj, ListConfig):
|
| 143 |
+
return [to_dict_recursive(item) for item in config_obj]
|
| 144 |
+
return config_obj
|
src/depth_anything_3/cli.py
ADDED
|
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: E402
|
| 2 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""
|
| 16 |
+
Refactored Depth Anything 3 CLI
|
| 17 |
+
Clean, modular command-line interface
|
| 18 |
+
"""
|
| 19 |
+
|
| 20 |
+
from __future__ import annotations
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import typer
|
| 24 |
+
|
| 25 |
+
from depth_anything_3.services import start_server
|
| 26 |
+
from depth_anything_3.services.gallery import gallery as gallery_main
|
| 27 |
+
from depth_anything_3.services.inference_service import run_inference
|
| 28 |
+
from depth_anything_3.services.input_handlers import (
|
| 29 |
+
ColmapHandler,
|
| 30 |
+
ImageHandler,
|
| 31 |
+
ImagesHandler,
|
| 32 |
+
InputHandler,
|
| 33 |
+
VideoHandler,
|
| 34 |
+
parse_export_feat,
|
| 35 |
+
)
|
| 36 |
+
from depth_anything_3.utils.constants import DEFAULT_EXPORT_DIR, DEFAULT_GALLERY_DIR, DEFAULT_GRADIO_DIR, DEFAULT_MODEL
|
| 37 |
+
|
| 38 |
+
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 39 |
+
|
| 40 |
+
app = typer.Typer(help="Depth Anything 3 - Video depth estimation CLI", add_completion=False)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
# ============================================================================
|
| 44 |
+
# Input type detection utilities
|
| 45 |
+
# ============================================================================
|
| 46 |
+
|
| 47 |
+
# Supported file extensions
|
| 48 |
+
IMAGE_EXTENSIONS = {".png", ".jpg", ".jpeg", ".webp", ".bmp", ".tiff", ".tif"}
|
| 49 |
+
VIDEO_EXTENSIONS = {".mp4", ".avi", ".mov", ".mkv", ".flv", ".wmv", ".webm", ".m4v"}
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def detect_input_type(input_path: str) -> str:
|
| 53 |
+
"""
|
| 54 |
+
Detect input type from path.
|
| 55 |
+
|
| 56 |
+
Returns:
|
| 57 |
+
- "image": Single image file
|
| 58 |
+
- "images": Directory containing images
|
| 59 |
+
- "video": Video file
|
| 60 |
+
- "colmap": COLMAP directory structure
|
| 61 |
+
- "unknown": Cannot determine type
|
| 62 |
+
"""
|
| 63 |
+
if not os.path.exists(input_path):
|
| 64 |
+
return "unknown"
|
| 65 |
+
|
| 66 |
+
# Check if it's a file
|
| 67 |
+
if os.path.isfile(input_path):
|
| 68 |
+
ext = os.path.splitext(input_path)[1].lower()
|
| 69 |
+
if ext in IMAGE_EXTENSIONS:
|
| 70 |
+
return "image"
|
| 71 |
+
elif ext in VIDEO_EXTENSIONS:
|
| 72 |
+
return "video"
|
| 73 |
+
return "unknown"
|
| 74 |
+
|
| 75 |
+
# Check if it's a directory
|
| 76 |
+
if os.path.isdir(input_path):
|
| 77 |
+
# Check for COLMAP structure
|
| 78 |
+
images_dir = os.path.join(input_path, "images")
|
| 79 |
+
sparse_dir = os.path.join(input_path, "sparse")
|
| 80 |
+
|
| 81 |
+
if os.path.isdir(images_dir) and os.path.isdir(sparse_dir):
|
| 82 |
+
return "colmap"
|
| 83 |
+
|
| 84 |
+
# Check if directory contains image files
|
| 85 |
+
for item in os.listdir(input_path):
|
| 86 |
+
item_path = os.path.join(input_path, item)
|
| 87 |
+
if os.path.isfile(item_path):
|
| 88 |
+
ext = os.path.splitext(item)[1].lower()
|
| 89 |
+
if ext in IMAGE_EXTENSIONS:
|
| 90 |
+
return "images"
|
| 91 |
+
|
| 92 |
+
return "unknown"
|
| 93 |
+
|
| 94 |
+
return "unknown"
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
# ============================================================================
|
| 98 |
+
# Common parameters and configuration
|
| 99 |
+
# ============================================================================
|
| 100 |
+
|
| 101 |
+
# ============================================================================
|
| 102 |
+
# Inference commands
|
| 103 |
+
# ============================================================================
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
@app.command()
|
| 107 |
+
def auto(
|
| 108 |
+
input_path: str = typer.Argument(
|
| 109 |
+
..., help="Path to input (image, directory, video, or COLMAP)"
|
| 110 |
+
),
|
| 111 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 112 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 113 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 114 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 115 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 116 |
+
backend_url: str = typer.Option(
|
| 117 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 118 |
+
),
|
| 119 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 120 |
+
process_res_method: str = typer.Option(
|
| 121 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 122 |
+
),
|
| 123 |
+
export_feat: str = typer.Option(
|
| 124 |
+
"",
|
| 125 |
+
help="[FEAT_VIS]Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 126 |
+
),
|
| 127 |
+
auto_cleanup: bool = typer.Option(
|
| 128 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 129 |
+
),
|
| 130 |
+
# Video-specific options
|
| 131 |
+
fps: float = typer.Option(1.0, help="[Video] Sampling FPS for frame extraction"),
|
| 132 |
+
# COLMAP-specific options
|
| 133 |
+
sparse_subdir: str = typer.Option(
|
| 134 |
+
"", help="[COLMAP] Sparse reconstruction subdirectory (e.g., '0' for sparse/0/)"
|
| 135 |
+
),
|
| 136 |
+
align_to_input_ext_scale: bool = typer.Option(
|
| 137 |
+
True, help="[COLMAP] Align prediction to input extrinsics scale"
|
| 138 |
+
),
|
| 139 |
+
# GLB export options
|
| 140 |
+
conf_thresh_percentile: float = typer.Option(
|
| 141 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 142 |
+
),
|
| 143 |
+
num_max_points: int = typer.Option(
|
| 144 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 145 |
+
),
|
| 146 |
+
show_cameras: bool = typer.Option(
|
| 147 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 148 |
+
),
|
| 149 |
+
# Feat_vis export options
|
| 150 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 151 |
+
):
|
| 152 |
+
"""
|
| 153 |
+
Automatically detect input type and run appropriate processing.
|
| 154 |
+
|
| 155 |
+
Supports:
|
| 156 |
+
- Single image file (.jpg, .png, etc.)
|
| 157 |
+
- Directory of images
|
| 158 |
+
- Video file (.mp4, .avi, etc.)
|
| 159 |
+
- COLMAP directory (with 'images' and 'sparse' subdirectories)
|
| 160 |
+
"""
|
| 161 |
+
# Detect input type
|
| 162 |
+
input_type = detect_input_type(input_path)
|
| 163 |
+
|
| 164 |
+
if input_type == "unknown":
|
| 165 |
+
typer.echo(f"❌ Error: Cannot determine input type for: {input_path}", err=True)
|
| 166 |
+
typer.echo("Supported inputs:", err=True)
|
| 167 |
+
typer.echo(" - Single image file (.jpg, .png, etc.)", err=True)
|
| 168 |
+
typer.echo(" - Directory containing images", err=True)
|
| 169 |
+
typer.echo(" - Video file (.mp4, .avi, etc.)", err=True)
|
| 170 |
+
typer.echo(" - COLMAP directory (with 'images/' and 'sparse/' subdirectories)", err=True)
|
| 171 |
+
raise typer.Exit(1)
|
| 172 |
+
|
| 173 |
+
# Display detected type
|
| 174 |
+
typer.echo(f"🔍 Detected input type: {input_type.upper()}")
|
| 175 |
+
typer.echo(f"📁 Input path: {input_path}")
|
| 176 |
+
typer.echo()
|
| 177 |
+
|
| 178 |
+
# Determine backend URL based on use_backend flag
|
| 179 |
+
final_backend_url = backend_url if use_backend else None
|
| 180 |
+
|
| 181 |
+
# Parse export_feat parameter
|
| 182 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 183 |
+
|
| 184 |
+
# Route to appropriate handler
|
| 185 |
+
if input_type == "image":
|
| 186 |
+
typer.echo("Processing single image...")
|
| 187 |
+
# Process input
|
| 188 |
+
image_files = ImageHandler.process(input_path)
|
| 189 |
+
|
| 190 |
+
# Handle export directory
|
| 191 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 192 |
+
|
| 193 |
+
# Run inference
|
| 194 |
+
run_inference(
|
| 195 |
+
image_paths=image_files,
|
| 196 |
+
export_dir=export_dir,
|
| 197 |
+
model_dir=model_dir,
|
| 198 |
+
device=device,
|
| 199 |
+
backend_url=final_backend_url,
|
| 200 |
+
export_format=export_format,
|
| 201 |
+
process_res=process_res,
|
| 202 |
+
process_res_method=process_res_method,
|
| 203 |
+
export_feat_layers=export_feat_layers,
|
| 204 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 205 |
+
num_max_points=num_max_points,
|
| 206 |
+
show_cameras=show_cameras,
|
| 207 |
+
feat_vis_fps=feat_vis_fps,
|
| 208 |
+
)
|
| 209 |
+
|
| 210 |
+
elif input_type == "images":
|
| 211 |
+
typer.echo("Processing directory of images...")
|
| 212 |
+
# Process input - use default extensions
|
| 213 |
+
image_files = ImagesHandler.process(input_path, "png,jpg,jpeg")
|
| 214 |
+
|
| 215 |
+
# Handle export directory
|
| 216 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 217 |
+
|
| 218 |
+
# Run inference
|
| 219 |
+
run_inference(
|
| 220 |
+
image_paths=image_files,
|
| 221 |
+
export_dir=export_dir,
|
| 222 |
+
model_dir=model_dir,
|
| 223 |
+
device=device,
|
| 224 |
+
backend_url=final_backend_url,
|
| 225 |
+
export_format=export_format,
|
| 226 |
+
process_res=process_res,
|
| 227 |
+
process_res_method=process_res_method,
|
| 228 |
+
export_feat_layers=export_feat_layers,
|
| 229 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 230 |
+
num_max_points=num_max_points,
|
| 231 |
+
show_cameras=show_cameras,
|
| 232 |
+
feat_vis_fps=feat_vis_fps,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
elif input_type == "video":
|
| 236 |
+
typer.echo(f"Processing video with FPS={fps}...")
|
| 237 |
+
# Handle export directory
|
| 238 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 239 |
+
|
| 240 |
+
# Process input
|
| 241 |
+
image_files = VideoHandler.process(input_path, export_dir, fps)
|
| 242 |
+
|
| 243 |
+
# Run inference
|
| 244 |
+
run_inference(
|
| 245 |
+
image_paths=image_files,
|
| 246 |
+
export_dir=export_dir,
|
| 247 |
+
model_dir=model_dir,
|
| 248 |
+
device=device,
|
| 249 |
+
backend_url=final_backend_url,
|
| 250 |
+
export_format=export_format,
|
| 251 |
+
process_res=process_res,
|
| 252 |
+
process_res_method=process_res_method,
|
| 253 |
+
export_feat_layers=export_feat_layers,
|
| 254 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 255 |
+
num_max_points=num_max_points,
|
| 256 |
+
show_cameras=show_cameras,
|
| 257 |
+
feat_vis_fps=feat_vis_fps,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
elif input_type == "colmap":
|
| 261 |
+
typer.echo(
|
| 262 |
+
f"Processing COLMAP directory (sparse subdirectory: '{sparse_subdir or 'default'}')..."
|
| 263 |
+
)
|
| 264 |
+
# Process input
|
| 265 |
+
image_files, extrinsics, intrinsics = ColmapHandler.process(input_path, sparse_subdir)
|
| 266 |
+
|
| 267 |
+
# Handle export directory
|
| 268 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 269 |
+
|
| 270 |
+
# Run inference
|
| 271 |
+
run_inference(
|
| 272 |
+
image_paths=image_files,
|
| 273 |
+
export_dir=export_dir,
|
| 274 |
+
model_dir=model_dir,
|
| 275 |
+
device=device,
|
| 276 |
+
backend_url=final_backend_url,
|
| 277 |
+
export_format=export_format,
|
| 278 |
+
process_res=process_res,
|
| 279 |
+
process_res_method=process_res_method,
|
| 280 |
+
export_feat_layers=export_feat_layers,
|
| 281 |
+
extrinsics=extrinsics,
|
| 282 |
+
intrinsics=intrinsics,
|
| 283 |
+
align_to_input_ext_scale=align_to_input_ext_scale,
|
| 284 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 285 |
+
num_max_points=num_max_points,
|
| 286 |
+
show_cameras=show_cameras,
|
| 287 |
+
feat_vis_fps=feat_vis_fps,
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
typer.echo()
|
| 291 |
+
typer.echo("✅ Processing completed successfully!")
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
@app.command()
|
| 295 |
+
def image(
|
| 296 |
+
image_path: str = typer.Argument(..., help="Path to input image file"),
|
| 297 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 298 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 299 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 300 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 301 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 302 |
+
backend_url: str = typer.Option(
|
| 303 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 304 |
+
),
|
| 305 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 306 |
+
process_res_method: str = typer.Option(
|
| 307 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 308 |
+
),
|
| 309 |
+
export_feat: str = typer.Option(
|
| 310 |
+
"",
|
| 311 |
+
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 312 |
+
),
|
| 313 |
+
auto_cleanup: bool = typer.Option(
|
| 314 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 315 |
+
),
|
| 316 |
+
# GLB export options
|
| 317 |
+
conf_thresh_percentile: float = typer.Option(
|
| 318 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 319 |
+
),
|
| 320 |
+
num_max_points: int = typer.Option(
|
| 321 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 322 |
+
),
|
| 323 |
+
show_cameras: bool = typer.Option(
|
| 324 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 325 |
+
),
|
| 326 |
+
# Feat_vis export options
|
| 327 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 328 |
+
):
|
| 329 |
+
"""Run camera pose and depth estimation on a single image."""
|
| 330 |
+
# Process input
|
| 331 |
+
image_files = ImageHandler.process(image_path)
|
| 332 |
+
|
| 333 |
+
# Handle export directory
|
| 334 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 335 |
+
|
| 336 |
+
# Parse export_feat parameter
|
| 337 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 338 |
+
|
| 339 |
+
# Determine backend URL based on use_backend flag
|
| 340 |
+
final_backend_url = backend_url if use_backend else None
|
| 341 |
+
|
| 342 |
+
# Run inference
|
| 343 |
+
run_inference(
|
| 344 |
+
image_paths=image_files,
|
| 345 |
+
export_dir=export_dir,
|
| 346 |
+
model_dir=model_dir,
|
| 347 |
+
device=device,
|
| 348 |
+
backend_url=final_backend_url,
|
| 349 |
+
export_format=export_format,
|
| 350 |
+
process_res=process_res,
|
| 351 |
+
process_res_method=process_res_method,
|
| 352 |
+
export_feat_layers=export_feat_layers,
|
| 353 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 354 |
+
num_max_points=num_max_points,
|
| 355 |
+
show_cameras=show_cameras,
|
| 356 |
+
feat_vis_fps=feat_vis_fps,
|
| 357 |
+
)
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
@app.command()
|
| 361 |
+
def images(
|
| 362 |
+
images_dir: str = typer.Argument(..., help="Path to directory containing input images"),
|
| 363 |
+
image_extensions: str = typer.Option(
|
| 364 |
+
"png,jpg,jpeg", help="Comma-separated image file extensions to process"
|
| 365 |
+
),
|
| 366 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 367 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 368 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 369 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 370 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 371 |
+
backend_url: str = typer.Option(
|
| 372 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 373 |
+
),
|
| 374 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 375 |
+
process_res_method: str = typer.Option(
|
| 376 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 377 |
+
),
|
| 378 |
+
export_feat: str = typer.Option(
|
| 379 |
+
"",
|
| 380 |
+
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 381 |
+
),
|
| 382 |
+
auto_cleanup: bool = typer.Option(
|
| 383 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 384 |
+
),
|
| 385 |
+
# GLB export options
|
| 386 |
+
conf_thresh_percentile: float = typer.Option(
|
| 387 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 388 |
+
),
|
| 389 |
+
num_max_points: int = typer.Option(
|
| 390 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 391 |
+
),
|
| 392 |
+
show_cameras: bool = typer.Option(
|
| 393 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 394 |
+
),
|
| 395 |
+
# Feat_vis export options
|
| 396 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 397 |
+
):
|
| 398 |
+
"""Run camera pose and depth estimation on a directory of images."""
|
| 399 |
+
# Process input
|
| 400 |
+
image_files = ImagesHandler.process(images_dir, image_extensions)
|
| 401 |
+
|
| 402 |
+
# Handle export directory
|
| 403 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 404 |
+
|
| 405 |
+
# Parse export_feat parameter
|
| 406 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 407 |
+
|
| 408 |
+
# Determine backend URL based on use_backend flag
|
| 409 |
+
final_backend_url = backend_url if use_backend else None
|
| 410 |
+
|
| 411 |
+
# Run inference
|
| 412 |
+
run_inference(
|
| 413 |
+
image_paths=image_files,
|
| 414 |
+
export_dir=export_dir,
|
| 415 |
+
model_dir=model_dir,
|
| 416 |
+
device=device,
|
| 417 |
+
backend_url=final_backend_url,
|
| 418 |
+
export_format=export_format,
|
| 419 |
+
process_res=process_res,
|
| 420 |
+
process_res_method=process_res_method,
|
| 421 |
+
export_feat_layers=export_feat_layers,
|
| 422 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 423 |
+
num_max_points=num_max_points,
|
| 424 |
+
show_cameras=show_cameras,
|
| 425 |
+
feat_vis_fps=feat_vis_fps,
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
|
| 429 |
+
@app.command()
|
| 430 |
+
def colmap(
|
| 431 |
+
colmap_dir: str = typer.Argument(
|
| 432 |
+
..., help="Path to COLMAP directory containing 'images' and 'sparse' subdirectories"
|
| 433 |
+
),
|
| 434 |
+
sparse_subdir: str = typer.Option(
|
| 435 |
+
"", help="Sparse reconstruction subdirectory (e.g., '0' for sparse/0/, empty for sparse/)"
|
| 436 |
+
),
|
| 437 |
+
align_to_input_ext_scale: bool = typer.Option(
|
| 438 |
+
True, help="Align prediction to input extrinsics scale"
|
| 439 |
+
),
|
| 440 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 441 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 442 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 443 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 444 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 445 |
+
backend_url: str = typer.Option(
|
| 446 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 447 |
+
),
|
| 448 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 449 |
+
process_res_method: str = typer.Option(
|
| 450 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 451 |
+
),
|
| 452 |
+
export_feat: str = typer.Option(
|
| 453 |
+
"",
|
| 454 |
+
help="Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 455 |
+
),
|
| 456 |
+
auto_cleanup: bool = typer.Option(
|
| 457 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 458 |
+
),
|
| 459 |
+
# GLB export options
|
| 460 |
+
conf_thresh_percentile: float = typer.Option(
|
| 461 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 462 |
+
),
|
| 463 |
+
num_max_points: int = typer.Option(
|
| 464 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 465 |
+
),
|
| 466 |
+
show_cameras: bool = typer.Option(
|
| 467 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 468 |
+
),
|
| 469 |
+
# Feat_vis export options
|
| 470 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 471 |
+
):
|
| 472 |
+
"""Run pose conditioned depth estimation on COLMAP data."""
|
| 473 |
+
# Process input
|
| 474 |
+
image_files, extrinsics, intrinsics = ColmapHandler.process(colmap_dir, sparse_subdir)
|
| 475 |
+
|
| 476 |
+
# Handle export directory
|
| 477 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 478 |
+
|
| 479 |
+
# Parse export_feat parameter
|
| 480 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 481 |
+
|
| 482 |
+
# Determine backend URL based on use_backend flag
|
| 483 |
+
final_backend_url = backend_url if use_backend else None
|
| 484 |
+
|
| 485 |
+
# Run inference
|
| 486 |
+
run_inference(
|
| 487 |
+
image_paths=image_files,
|
| 488 |
+
export_dir=export_dir,
|
| 489 |
+
model_dir=model_dir,
|
| 490 |
+
device=device,
|
| 491 |
+
backend_url=final_backend_url,
|
| 492 |
+
export_format=export_format,
|
| 493 |
+
process_res=process_res,
|
| 494 |
+
process_res_method=process_res_method,
|
| 495 |
+
export_feat_layers=export_feat_layers,
|
| 496 |
+
extrinsics=extrinsics,
|
| 497 |
+
intrinsics=intrinsics,
|
| 498 |
+
align_to_input_ext_scale=align_to_input_ext_scale,
|
| 499 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 500 |
+
num_max_points=num_max_points,
|
| 501 |
+
show_cameras=show_cameras,
|
| 502 |
+
feat_vis_fps=feat_vis_fps,
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
|
| 506 |
+
@app.command()
|
| 507 |
+
def video(
|
| 508 |
+
video_path: str = typer.Argument(..., help="Path to input video file"),
|
| 509 |
+
fps: float = typer.Option(1.0, help="Sampling FPS for frame extraction"),
|
| 510 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 511 |
+
export_dir: str = typer.Option(DEFAULT_EXPORT_DIR, help="Export directory"),
|
| 512 |
+
export_format: str = typer.Option("glb", help="Export format"),
|
| 513 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 514 |
+
use_backend: bool = typer.Option(False, help="Use backend service for inference"),
|
| 515 |
+
backend_url: str = typer.Option(
|
| 516 |
+
"http://localhost:8008", help="Backend URL (default: http://localhost:8008)"
|
| 517 |
+
),
|
| 518 |
+
process_res: int = typer.Option(504, help="Processing resolution"),
|
| 519 |
+
process_res_method: str = typer.Option(
|
| 520 |
+
"upper_bound_resize", help="Processing resolution method"
|
| 521 |
+
),
|
| 522 |
+
export_feat: str = typer.Option(
|
| 523 |
+
"",
|
| 524 |
+
help="[FEAT_VIS] Export features from specified layers using comma-separated indices (e.g., '0,1,2').",
|
| 525 |
+
),
|
| 526 |
+
auto_cleanup: bool = typer.Option(
|
| 527 |
+
False, help="Automatically clean export directory if it exists (no prompt)"
|
| 528 |
+
),
|
| 529 |
+
# GLB export options
|
| 530 |
+
conf_thresh_percentile: float = typer.Option(
|
| 531 |
+
40.0, help="[GLB] Lower percentile for adaptive confidence threshold"
|
| 532 |
+
),
|
| 533 |
+
num_max_points: int = typer.Option(
|
| 534 |
+
1_000_000, help="[GLB] Maximum number of points in the point cloud"
|
| 535 |
+
),
|
| 536 |
+
show_cameras: bool = typer.Option(
|
| 537 |
+
True, help="[GLB] Show camera wireframes in the exported scene"
|
| 538 |
+
),
|
| 539 |
+
# Feat_vis export options
|
| 540 |
+
feat_vis_fps: int = typer.Option(15, help="[FEAT_VIS] Frame rate for output video"),
|
| 541 |
+
):
|
| 542 |
+
"""Run depth estimation on video by extracting frames and processing them."""
|
| 543 |
+
# Handle export directory
|
| 544 |
+
export_dir = InputHandler.handle_export_dir(export_dir, auto_cleanup)
|
| 545 |
+
|
| 546 |
+
# Process input
|
| 547 |
+
image_files = VideoHandler.process(video_path, export_dir, fps)
|
| 548 |
+
|
| 549 |
+
# Parse export_feat parameter
|
| 550 |
+
export_feat_layers = parse_export_feat(export_feat)
|
| 551 |
+
|
| 552 |
+
# Determine backend URL based on use_backend flag
|
| 553 |
+
final_backend_url = backend_url if use_backend else None
|
| 554 |
+
|
| 555 |
+
# Run inference
|
| 556 |
+
run_inference(
|
| 557 |
+
image_paths=image_files,
|
| 558 |
+
export_dir=export_dir,
|
| 559 |
+
model_dir=model_dir,
|
| 560 |
+
device=device,
|
| 561 |
+
backend_url=final_backend_url,
|
| 562 |
+
export_format=export_format,
|
| 563 |
+
process_res=process_res,
|
| 564 |
+
process_res_method=process_res_method,
|
| 565 |
+
export_feat_layers=export_feat_layers,
|
| 566 |
+
conf_thresh_percentile=conf_thresh_percentile,
|
| 567 |
+
num_max_points=num_max_points,
|
| 568 |
+
show_cameras=show_cameras,
|
| 569 |
+
feat_vis_fps=feat_vis_fps,
|
| 570 |
+
)
|
| 571 |
+
|
| 572 |
+
|
| 573 |
+
# ============================================================================
|
| 574 |
+
# Service management commands
|
| 575 |
+
# ============================================================================
|
| 576 |
+
|
| 577 |
+
|
| 578 |
+
@app.command()
|
| 579 |
+
def backend(
|
| 580 |
+
model_dir: str = typer.Option(DEFAULT_MODEL, help="Model directory path"),
|
| 581 |
+
device: str = typer.Option("cuda", help="Device to use"),
|
| 582 |
+
host: str = typer.Option("127.0.0.1", help="Host to bind to"),
|
| 583 |
+
port: int = typer.Option(8008, help="Port to bind to"),
|
| 584 |
+
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery directory path (optional)"),
|
| 585 |
+
):
|
| 586 |
+
"""Start model backend service with integrated gallery."""
|
| 587 |
+
typer.echo("=" * 60)
|
| 588 |
+
typer.echo("🚀 Starting Depth Anything 3 Backend Server")
|
| 589 |
+
typer.echo("=" * 60)
|
| 590 |
+
typer.echo(f"Model directory: {model_dir}")
|
| 591 |
+
typer.echo(f"Device: {device}")
|
| 592 |
+
|
| 593 |
+
# Check if gallery directory exists
|
| 594 |
+
if gallery_dir and os.path.exists(gallery_dir):
|
| 595 |
+
typer.echo(f"Gallery directory: {gallery_dir}")
|
| 596 |
+
else:
|
| 597 |
+
gallery_dir = None # Disable gallery if directory doesn't exist
|
| 598 |
+
|
| 599 |
+
typer.echo()
|
| 600 |
+
typer.echo("📡 Server URLs (Ctrl/CMD+Click to open):")
|
| 601 |
+
typer.echo(f" 🏠 Home: http://{host}:{port}")
|
| 602 |
+
typer.echo(f" 📊 Dashboard: http://{host}:{port}/dashboard")
|
| 603 |
+
typer.echo(f" 📈 API Status: http://{host}:{port}/status")
|
| 604 |
+
|
| 605 |
+
if gallery_dir:
|
| 606 |
+
typer.echo(f" 🎨 Gallery: http://{host}:{port}/gallery/")
|
| 607 |
+
|
| 608 |
+
typer.echo("=" * 60)
|
| 609 |
+
|
| 610 |
+
try:
|
| 611 |
+
start_server(model_dir, device, host, port, gallery_dir)
|
| 612 |
+
except KeyboardInterrupt:
|
| 613 |
+
typer.echo("\n👋 Backend server stopped.")
|
| 614 |
+
except Exception as e:
|
| 615 |
+
typer.echo(f"❌ Failed to start backend: {e}")
|
| 616 |
+
raise typer.Exit(1)
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
# ============================================================================
|
| 620 |
+
# Application launch commands
|
| 621 |
+
# ============================================================================
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
@app.command()
|
| 625 |
+
def gradio(
|
| 626 |
+
model_dir: str = typer.Option(DEFAULT_MODEL,help="Model directory path"),
|
| 627 |
+
workspace_dir: str = typer.Option(DEFAULT_GRADIO_DIR,help="Workspace directory path"),
|
| 628 |
+
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR,help="Gallery directory path"),
|
| 629 |
+
host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
|
| 630 |
+
port: int = typer.Option(7860, help="Port number to bind to"),
|
| 631 |
+
share: bool = typer.Option(False, help="Create a public link for the app"),
|
| 632 |
+
debug: bool = typer.Option(False, help="Enable debug mode"),
|
| 633 |
+
cache_examples: bool = typer.Option(
|
| 634 |
+
False, help="Pre-cache all example scenes at startup for faster loading"
|
| 635 |
+
),
|
| 636 |
+
cache_gs_tag: str = typer.Option(
|
| 637 |
+
"",
|
| 638 |
+
help="Tag to match scene names for high-res+3DGS caching (e.g., 'dl3dv'). Scenes containing this tag will use high_res and infer_gs=True; others will use low_res only.",
|
| 639 |
+
),
|
| 640 |
+
):
|
| 641 |
+
"""Launch Depth Anything 3 Gradio interactive web application"""
|
| 642 |
+
from depth_anything_3.app.gradio_app import DepthAnything3App
|
| 643 |
+
|
| 644 |
+
# Create necessary directories
|
| 645 |
+
os.makedirs(workspace_dir, exist_ok=True)
|
| 646 |
+
os.makedirs(gallery_dir, exist_ok=True)
|
| 647 |
+
|
| 648 |
+
typer.echo("Launching Depth Anything 3 Gradio application...")
|
| 649 |
+
typer.echo(f"Model directory: {model_dir}")
|
| 650 |
+
typer.echo(f"Workspace directory: {workspace_dir}")
|
| 651 |
+
typer.echo(f"Gallery directory: {gallery_dir}")
|
| 652 |
+
typer.echo(f"Host: {host}")
|
| 653 |
+
typer.echo(f"Port: {port}")
|
| 654 |
+
typer.echo(f"Share: {share}")
|
| 655 |
+
typer.echo(f"Debug mode: {debug}")
|
| 656 |
+
typer.echo(f"Cache examples: {cache_examples}")
|
| 657 |
+
if cache_examples:
|
| 658 |
+
if cache_gs_tag:
|
| 659 |
+
typer.echo(
|
| 660 |
+
f"Cache GS Tag: '{cache_gs_tag}' (scenes matching this tag will use high-res + 3DGS)"
|
| 661 |
+
)
|
| 662 |
+
else:
|
| 663 |
+
typer.echo(f"Cache GS Tag: None (all scenes will use low-res only)")
|
| 664 |
+
|
| 665 |
+
try:
|
| 666 |
+
# Initialize and launch application
|
| 667 |
+
app = DepthAnything3App(
|
| 668 |
+
model_dir=model_dir, workspace_dir=workspace_dir, gallery_dir=gallery_dir
|
| 669 |
+
)
|
| 670 |
+
|
| 671 |
+
# Pre-cache examples if requested
|
| 672 |
+
if cache_examples:
|
| 673 |
+
typer.echo("\n" + "=" * 60)
|
| 674 |
+
typer.echo("Pre-caching mode enabled")
|
| 675 |
+
if cache_gs_tag:
|
| 676 |
+
typer.echo(f"Scenes containing '{cache_gs_tag}' will use HIGH-RES + 3DGS")
|
| 677 |
+
typer.echo(f"Other scenes will use LOW-RES only")
|
| 678 |
+
else:
|
| 679 |
+
typer.echo(f"All scenes will use LOW-RES only")
|
| 680 |
+
typer.echo("=" * 60)
|
| 681 |
+
app.cache_examples(
|
| 682 |
+
show_cam=True,
|
| 683 |
+
filter_black_bg=False,
|
| 684 |
+
filter_white_bg=False,
|
| 685 |
+
save_percentage=20.0,
|
| 686 |
+
num_max_points=1000,
|
| 687 |
+
cache_gs_tag=cache_gs_tag,
|
| 688 |
+
gs_trj_mode="smooth",
|
| 689 |
+
gs_video_quality="low",
|
| 690 |
+
)
|
| 691 |
+
|
| 692 |
+
# Prepare launch arguments
|
| 693 |
+
launch_kwargs = {"share": share, "debug": debug}
|
| 694 |
+
|
| 695 |
+
app.launch(host=host, port=port, **launch_kwargs)
|
| 696 |
+
|
| 697 |
+
except KeyboardInterrupt:
|
| 698 |
+
typer.echo("\nGradio application stopped.")
|
| 699 |
+
except Exception as e:
|
| 700 |
+
typer.echo(f"Failed to launch Gradio application: {e}")
|
| 701 |
+
raise typer.Exit(1)
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
@app.command()
|
| 705 |
+
def gallery(
|
| 706 |
+
gallery_dir: str = typer.Option(DEFAULT_GALLERY_DIR, help="Gallery root directory"),
|
| 707 |
+
host: str = typer.Option("127.0.0.1", help="Host address to bind to"),
|
| 708 |
+
port: int = typer.Option(8007, help="Port number to bind to"),
|
| 709 |
+
open_browser: bool = typer.Option(False, help="Open browser after launch"),
|
| 710 |
+
):
|
| 711 |
+
"""Launch Depth Anything 3 Gallery server"""
|
| 712 |
+
|
| 713 |
+
# Validate gallery directory
|
| 714 |
+
if not os.path.exists(gallery_dir):
|
| 715 |
+
raise typer.BadParameter(f"Gallery directory not found: {gallery_dir}")
|
| 716 |
+
|
| 717 |
+
typer.echo("Launching Depth Anything 3 Gallery server...")
|
| 718 |
+
typer.echo(f"Gallery directory: {gallery_dir}")
|
| 719 |
+
typer.echo(f"Host: {host}")
|
| 720 |
+
typer.echo(f"Port: {port}")
|
| 721 |
+
typer.echo(f"Auto-open browser: {open_browser}")
|
| 722 |
+
|
| 723 |
+
try:
|
| 724 |
+
# Set command line arguments
|
| 725 |
+
import sys
|
| 726 |
+
|
| 727 |
+
sys.argv = ["gallery", "--dir", gallery_dir, "--host", host, "--port", str(port)]
|
| 728 |
+
if open_browser:
|
| 729 |
+
sys.argv.append("--open")
|
| 730 |
+
|
| 731 |
+
# Launch gallery server
|
| 732 |
+
gallery_main()
|
| 733 |
+
|
| 734 |
+
except KeyboardInterrupt:
|
| 735 |
+
typer.echo("\nGallery server stopped.")
|
| 736 |
+
except Exception as e:
|
| 737 |
+
typer.echo(f"Failed to launch Gallery server: {e}")
|
| 738 |
+
raise typer.Exit(1)
|
| 739 |
+
|
| 740 |
+
|
| 741 |
+
if __name__ == "__main__":
|
| 742 |
+
app()
|
src/depth_anything_3/configs/da3-base.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitb
|
| 13 |
+
out_layers: [5, 7, 9, 11]
|
| 14 |
+
alt_start: 4
|
| 15 |
+
qknorm_start: 4
|
| 16 |
+
rope_start: 4
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 1536
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 128
|
| 28 |
+
out_channels: &head_out_channels [96, 192, 384, 768]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 768
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 1536
|
src/depth_anything_3/configs/da3-giant.yaml
ADDED
|
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitg
|
| 13 |
+
out_layers: [19, 27, 33, 39]
|
| 14 |
+
alt_start: 13
|
| 15 |
+
qknorm_start: 13
|
| 16 |
+
rope_start: 13
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 3072
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 256
|
| 28 |
+
out_channels: &head_out_channels [256, 512, 1024, 1024]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 1536
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 3072
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
gs_head:
|
| 49 |
+
__object__:
|
| 50 |
+
path: depth_anything_3.model.gsdpt
|
| 51 |
+
name: GSDPT
|
| 52 |
+
args: as_params
|
| 53 |
+
|
| 54 |
+
dim_in: *head_dim_in
|
| 55 |
+
output_dim: 38 # should align with gs_adapter's setting, for gs params
|
| 56 |
+
features: *head_features
|
| 57 |
+
out_channels: *head_out_channels
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
gs_adapter:
|
| 61 |
+
__object__:
|
| 62 |
+
path: depth_anything_3.model.gs_adapter
|
| 63 |
+
name: GaussianAdapter
|
| 64 |
+
args: as_params
|
| 65 |
+
|
| 66 |
+
sh_degree: 2
|
| 67 |
+
pred_color: false # predict SH coefficient if false
|
| 68 |
+
pred_offset_depth: true
|
| 69 |
+
pred_offset_xy: true
|
| 70 |
+
gaussian_scale_min: 1e-5
|
| 71 |
+
gaussian_scale_max: 30.0
|
src/depth_anything_3/configs/da3-large.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitl
|
| 13 |
+
out_layers: [11, 15, 19, 23]
|
| 14 |
+
alt_start: 8
|
| 15 |
+
qknorm_start: 8
|
| 16 |
+
rope_start: 8
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 2048
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 256
|
| 28 |
+
out_channels: &head_out_channels [256, 512, 1024, 1024]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 1024
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 2048
|
src/depth_anything_3/configs/da3-small.yaml
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vits
|
| 13 |
+
out_layers: [5, 7, 9, 11]
|
| 14 |
+
alt_start: 4
|
| 15 |
+
qknorm_start: 4
|
| 16 |
+
rope_start: 4
|
| 17 |
+
cat_token: True
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dualdpt
|
| 22 |
+
name: DualDPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: &head_dim_in 768
|
| 26 |
+
output_dim: 2
|
| 27 |
+
features: &head_features 64
|
| 28 |
+
out_channels: &head_out_channels [48, 96, 192, 384]
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
cam_enc:
|
| 32 |
+
__object__:
|
| 33 |
+
path: depth_anything_3.model.cam_enc
|
| 34 |
+
name: CameraEnc
|
| 35 |
+
args: as_params
|
| 36 |
+
|
| 37 |
+
dim_out: 384
|
| 38 |
+
|
| 39 |
+
cam_dec:
|
| 40 |
+
__object__:
|
| 41 |
+
path: depth_anything_3.model.cam_dec
|
| 42 |
+
name: CameraDec
|
| 43 |
+
args: as_params
|
| 44 |
+
|
| 45 |
+
dim_in: 768
|
src/depth_anything_3/configs/da3metric-large.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitl
|
| 13 |
+
out_layers: [4, 11, 17, 23]
|
| 14 |
+
alt_start: -1 # -1 means disable
|
| 15 |
+
qknorm_start: -1
|
| 16 |
+
rope_start: -1
|
| 17 |
+
cat_token: False
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dpt
|
| 22 |
+
name: DPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: 1024
|
| 26 |
+
output_dim: 1
|
| 27 |
+
features: 256
|
| 28 |
+
out_channels: [256, 512, 1024, 1024]
|
src/depth_anything_3/configs/da3mono-large.yaml
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: DepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
net:
|
| 7 |
+
__object__:
|
| 8 |
+
path: depth_anything_3.model.dinov2.dinov2
|
| 9 |
+
name: DinoV2
|
| 10 |
+
args: as_params
|
| 11 |
+
|
| 12 |
+
name: vitl
|
| 13 |
+
out_layers: [4, 11, 17, 23]
|
| 14 |
+
alt_start: -1 # -1 means disable
|
| 15 |
+
qknorm_start: -1
|
| 16 |
+
rope_start: -1
|
| 17 |
+
cat_token: False
|
| 18 |
+
|
| 19 |
+
head:
|
| 20 |
+
__object__:
|
| 21 |
+
path: depth_anything_3.model.dpt
|
| 22 |
+
name: DPT
|
| 23 |
+
args: as_params
|
| 24 |
+
|
| 25 |
+
dim_in: 1024
|
| 26 |
+
output_dim: 1
|
| 27 |
+
features: 256
|
| 28 |
+
out_channels: [256, 512, 1024, 1024]
|
src/depth_anything_3/configs/da3nested-giant-large.yaml
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
__object__:
|
| 2 |
+
path: depth_anything_3.model.da3
|
| 3 |
+
name: NestedDepthAnything3Net
|
| 4 |
+
args: as_params
|
| 5 |
+
|
| 6 |
+
anyview:
|
| 7 |
+
__inherit__: depth_anything_3.configs.da3-giant
|
| 8 |
+
|
| 9 |
+
metric:
|
| 10 |
+
__inherit__: depth_anything_3.configs.da3metric-large
|
src/depth_anything_3/model/__init__.py
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from depth_anything_3.model.da3 import DepthAnything3Net, NestedDepthAnything3Net
|
| 16 |
+
|
| 17 |
+
__export__ = [
|
| 18 |
+
NestedDepthAnything3Net,
|
| 19 |
+
DepthAnything3Net,
|
| 20 |
+
]
|
src/depth_anything_3/model/cam_dec.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CameraDec(nn.Module):
|
| 20 |
+
def __init__(self, dim_in=1536):
|
| 21 |
+
super().__init__()
|
| 22 |
+
output_dim = dim_in
|
| 23 |
+
self.backbone = nn.Sequential(
|
| 24 |
+
nn.Linear(output_dim, output_dim),
|
| 25 |
+
nn.ReLU(),
|
| 26 |
+
nn.Linear(output_dim, output_dim),
|
| 27 |
+
nn.ReLU(),
|
| 28 |
+
)
|
| 29 |
+
self.fc_t = nn.Linear(output_dim, 3)
|
| 30 |
+
self.fc_qvec = nn.Linear(output_dim, 4)
|
| 31 |
+
self.fc_fov = nn.Sequential(nn.Linear(output_dim, 2), nn.ReLU())
|
| 32 |
+
|
| 33 |
+
def forward(self, feat, camera_encoding=None, *args, **kwargs):
|
| 34 |
+
B, N = feat.shape[:2]
|
| 35 |
+
feat = feat.reshape(B * N, -1)
|
| 36 |
+
feat = self.backbone(feat)
|
| 37 |
+
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
|
| 38 |
+
if camera_encoding is None:
|
| 39 |
+
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
|
| 40 |
+
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
|
| 41 |
+
else:
|
| 42 |
+
out_qvec = camera_encoding[..., 3:7]
|
| 43 |
+
out_fov = camera_encoding[..., -2:]
|
| 44 |
+
pose_enc = torch.cat([out_t, out_qvec, out_fov], dim=-1)
|
| 45 |
+
return pose_enc
|
src/depth_anything_3/model/cam_enc.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
from depth_anything_3.model.utils.attention import Mlp
|
| 18 |
+
from depth_anything_3.model.utils.block import Block
|
| 19 |
+
from depth_anything_3.model.utils.transform import extri_intri_to_pose_encoding
|
| 20 |
+
from depth_anything_3.utils.geometry import affine_inverse
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CameraEnc(nn.Module):
|
| 24 |
+
"""
|
| 25 |
+
CameraHead predicts camera parameters from token representations using iterative refinement.
|
| 26 |
+
|
| 27 |
+
It applies a series of transformer blocks (the "trunk") to dedicated camera tokens.
|
| 28 |
+
"""
|
| 29 |
+
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
dim_out: int = 1024,
|
| 33 |
+
dim_in: int = 9,
|
| 34 |
+
trunk_depth: int = 4,
|
| 35 |
+
target_dim: int = 9,
|
| 36 |
+
num_heads: int = 16,
|
| 37 |
+
mlp_ratio: int = 4,
|
| 38 |
+
init_values: float = 0.01,
|
| 39 |
+
**kwargs,
|
| 40 |
+
):
|
| 41 |
+
super().__init__()
|
| 42 |
+
self.target_dim = target_dim
|
| 43 |
+
self.trunk_depth = trunk_depth
|
| 44 |
+
self.trunk = nn.Sequential(
|
| 45 |
+
*[
|
| 46 |
+
Block(
|
| 47 |
+
dim=dim_out,
|
| 48 |
+
num_heads=num_heads,
|
| 49 |
+
mlp_ratio=mlp_ratio,
|
| 50 |
+
init_values=init_values,
|
| 51 |
+
)
|
| 52 |
+
for _ in range(trunk_depth)
|
| 53 |
+
]
|
| 54 |
+
)
|
| 55 |
+
self.token_norm = nn.LayerNorm(dim_out)
|
| 56 |
+
self.trunk_norm = nn.LayerNorm(dim_out)
|
| 57 |
+
self.pose_branch = Mlp(
|
| 58 |
+
in_features=dim_in,
|
| 59 |
+
hidden_features=dim_out // 2,
|
| 60 |
+
out_features=dim_out,
|
| 61 |
+
drop=0,
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def forward(
|
| 65 |
+
self,
|
| 66 |
+
ext,
|
| 67 |
+
ixt,
|
| 68 |
+
image_size,
|
| 69 |
+
) -> tuple:
|
| 70 |
+
c2ws = affine_inverse(ext)
|
| 71 |
+
pose_encoding = extri_intri_to_pose_encoding(
|
| 72 |
+
c2ws,
|
| 73 |
+
ixt,
|
| 74 |
+
image_size,
|
| 75 |
+
)
|
| 76 |
+
pose_tokens = self.pose_branch(pose_encoding)
|
| 77 |
+
pose_tokens = self.token_norm(pose_tokens)
|
| 78 |
+
pose_tokens = self.trunk(pose_tokens)
|
| 79 |
+
pose_tokens = self.trunk_norm(pose_tokens)
|
| 80 |
+
return pose_tokens
|
src/depth_anything_3/model/da3.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from __future__ import annotations
|
| 16 |
+
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from addict import Dict
|
| 20 |
+
from omegaconf import DictConfig, OmegaConf
|
| 21 |
+
|
| 22 |
+
from depth_anything_3.cfg import create_object
|
| 23 |
+
from depth_anything_3.model.utils.transform import pose_encoding_to_extri_intri
|
| 24 |
+
from depth_anything_3.utils.alignment import (
|
| 25 |
+
apply_metric_scaling,
|
| 26 |
+
compute_alignment_mask,
|
| 27 |
+
compute_sky_mask,
|
| 28 |
+
least_squares_scale_scalar,
|
| 29 |
+
sample_tensor_for_quantile,
|
| 30 |
+
set_sky_regions_to_max_depth,
|
| 31 |
+
)
|
| 32 |
+
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, map_pdf_to_opacity
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def _wrap_cfg(cfg_obj):
|
| 36 |
+
return OmegaConf.create(cfg_obj)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class DepthAnything3Net(nn.Module):
|
| 40 |
+
"""
|
| 41 |
+
Depth Anything 3 network for depth estimation and camera pose estimation.
|
| 42 |
+
|
| 43 |
+
This network consists of:
|
| 44 |
+
- Backbone: DinoV2 feature extractor
|
| 45 |
+
- Head: DPT or DualDPT for depth prediction
|
| 46 |
+
- Optional camera decoders for pose estimation
|
| 47 |
+
- Optional GSDPT for 3DGS prediction
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
preset: Configuration preset containing network dimensions and settings
|
| 51 |
+
|
| 52 |
+
Returns:
|
| 53 |
+
Dictionary containing:
|
| 54 |
+
- depth: Predicted depth map (B, H, W)
|
| 55 |
+
- depth_conf: Depth confidence map (B, H, W)
|
| 56 |
+
- extrinsics: Camera extrinsics (B, N, 4, 4)
|
| 57 |
+
- intrinsics: Camera intrinsics (B, N, 3, 3)
|
| 58 |
+
- gaussians: 3D Gaussian Splats (world space), type: model.gs_adapter.Gaussians
|
| 59 |
+
- aux: Auxiliary features for specified layers
|
| 60 |
+
"""
|
| 61 |
+
|
| 62 |
+
# Patch size for feature extraction
|
| 63 |
+
PATCH_SIZE = 14
|
| 64 |
+
|
| 65 |
+
def __init__(self, net, head, cam_dec=None, cam_enc=None, gs_head=None, gs_adapter=None):
|
| 66 |
+
"""
|
| 67 |
+
Initialize DepthAnything3Net with given yaml-initialized configuration.
|
| 68 |
+
"""
|
| 69 |
+
super().__init__()
|
| 70 |
+
self.backbone = net if isinstance(net, nn.Module) else create_object(_wrap_cfg(net))
|
| 71 |
+
self.head = head if isinstance(head, nn.Module) else create_object(_wrap_cfg(head))
|
| 72 |
+
self.cam_dec, self.cam_enc = None, None
|
| 73 |
+
if cam_dec is not None:
|
| 74 |
+
self.cam_dec = (
|
| 75 |
+
cam_dec if isinstance(cam_dec, nn.Module) else create_object(_wrap_cfg(cam_dec))
|
| 76 |
+
)
|
| 77 |
+
self.cam_enc = (
|
| 78 |
+
cam_dec if isinstance(cam_enc, nn.Module) else create_object(_wrap_cfg(cam_enc))
|
| 79 |
+
)
|
| 80 |
+
self.gs_adapter, self.gs_head = None, None
|
| 81 |
+
if gs_head is not None and gs_adapter is not None:
|
| 82 |
+
self.gs_adapter = (
|
| 83 |
+
gs_adapter
|
| 84 |
+
if isinstance(gs_adapter, nn.Module)
|
| 85 |
+
else create_object(_wrap_cfg(gs_adapter))
|
| 86 |
+
)
|
| 87 |
+
gs_out_dim = self.gs_adapter.d_in + 1
|
| 88 |
+
if isinstance(gs_head, nn.Module):
|
| 89 |
+
assert (
|
| 90 |
+
gs_head.out_dim == gs_out_dim
|
| 91 |
+
), f"gs_head.out_dim should be {gs_out_dim}, got {gs_head.out_dim}"
|
| 92 |
+
self.gs_head = gs_head
|
| 93 |
+
else:
|
| 94 |
+
assert (
|
| 95 |
+
gs_head["output_dim"] == gs_out_dim
|
| 96 |
+
), f"gs_head output_dim should set to {gs_out_dim}, got {gs_head['output_dim']}"
|
| 97 |
+
self.gs_head = create_object(_wrap_cfg(gs_head))
|
| 98 |
+
|
| 99 |
+
def forward(
|
| 100 |
+
self,
|
| 101 |
+
x: torch.Tensor,
|
| 102 |
+
extrinsics: torch.Tensor | None = None,
|
| 103 |
+
intrinsics: torch.Tensor | None = None,
|
| 104 |
+
export_feat_layers: list[int] | None = [],
|
| 105 |
+
infer_gs: bool = False,
|
| 106 |
+
) -> Dict[str, torch.Tensor]:
|
| 107 |
+
"""
|
| 108 |
+
Forward pass through the network.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
x: Input images (B, N, 3, H, W)
|
| 112 |
+
extrinsics: Camera extrinsics (B, N, 4, 4) - unused
|
| 113 |
+
intrinsics: Camera intrinsics (B, N, 3, 3) - unused
|
| 114 |
+
feat_layers: List of layer indices to extract features from
|
| 115 |
+
|
| 116 |
+
Returns:
|
| 117 |
+
Dictionary containing predictions and auxiliary features
|
| 118 |
+
"""
|
| 119 |
+
# Extract features using backbone
|
| 120 |
+
if extrinsics is not None:
|
| 121 |
+
with torch.autocast(device_type=x.device.type, enabled=False):
|
| 122 |
+
cam_token = self.cam_enc(extrinsics, intrinsics, x.shape[-2:])
|
| 123 |
+
else:
|
| 124 |
+
cam_token = None
|
| 125 |
+
|
| 126 |
+
feats, aux_feats = self.backbone(
|
| 127 |
+
x, cam_token=cam_token, export_feat_layers=export_feat_layers
|
| 128 |
+
)
|
| 129 |
+
# feats = [[item for item in feat] for feat in feats]
|
| 130 |
+
H, W = x.shape[-2], x.shape[-1]
|
| 131 |
+
|
| 132 |
+
# Process features through depth head
|
| 133 |
+
with torch.autocast(device_type=x.device.type, enabled=False):
|
| 134 |
+
output = self._process_depth_head(feats, H, W)
|
| 135 |
+
output = self._process_camera_estimation(feats, H, W, output)
|
| 136 |
+
if infer_gs:
|
| 137 |
+
output = self._process_gs_head(feats, H, W, output, x, extrinsics, intrinsics)
|
| 138 |
+
|
| 139 |
+
# Extract auxiliary features if requested
|
| 140 |
+
output.aux = self._extract_auxiliary_features(aux_feats, export_feat_layers, H, W)
|
| 141 |
+
|
| 142 |
+
return output
|
| 143 |
+
|
| 144 |
+
def _process_depth_head(
|
| 145 |
+
self, feats: list[torch.Tensor], H: int, W: int
|
| 146 |
+
) -> Dict[str, torch.Tensor]:
|
| 147 |
+
"""Process features through the depth prediction head."""
|
| 148 |
+
return self.head(feats, H, W, patch_start_idx=0)
|
| 149 |
+
|
| 150 |
+
def _process_camera_estimation(
|
| 151 |
+
self, feats: list[torch.Tensor], H: int, W: int, output: Dict[str, torch.Tensor]
|
| 152 |
+
) -> Dict[str, torch.Tensor]:
|
| 153 |
+
"""Process camera pose estimation if camera decoder is available."""
|
| 154 |
+
if self.cam_dec is not None:
|
| 155 |
+
pose_enc = self.cam_dec(feats[-1][1])
|
| 156 |
+
# Remove ray information as it's not needed for pose estimation
|
| 157 |
+
if "ray" in output:
|
| 158 |
+
del output.ray
|
| 159 |
+
if "ray_conf" in output:
|
| 160 |
+
del output.ray_conf
|
| 161 |
+
|
| 162 |
+
# Convert pose encoding to extrinsics and intrinsics
|
| 163 |
+
c2w, ixt = pose_encoding_to_extri_intri(pose_enc, (H, W))
|
| 164 |
+
output.extrinsics = affine_inverse(c2w)
|
| 165 |
+
output.intrinsics = ixt
|
| 166 |
+
|
| 167 |
+
return output
|
| 168 |
+
|
| 169 |
+
def _process_gs_head(
|
| 170 |
+
self,
|
| 171 |
+
feats: list[torch.Tensor],
|
| 172 |
+
H: int,
|
| 173 |
+
W: int,
|
| 174 |
+
output: Dict[str, torch.Tensor],
|
| 175 |
+
in_images: torch.Tensor,
|
| 176 |
+
extrinsics: torch.Tensor | None = None,
|
| 177 |
+
intrinsics: torch.Tensor | None = None,
|
| 178 |
+
) -> Dict[str, torch.Tensor]:
|
| 179 |
+
"""Process 3DGS parameters estimation if 3DGS head is available."""
|
| 180 |
+
if self.gs_head is None or self.gs_adapter is None:
|
| 181 |
+
return output
|
| 182 |
+
assert output.get("depth", None) is not None, "must provide MV depth for the GS head."
|
| 183 |
+
|
| 184 |
+
# if GT camera poses are provided, use them
|
| 185 |
+
if extrinsics is not None and intrinsics is not None:
|
| 186 |
+
ctx_extr = extrinsics
|
| 187 |
+
ctx_intr = intrinsics
|
| 188 |
+
else:
|
| 189 |
+
ctx_extr = output.get("extrinsics", None)
|
| 190 |
+
ctx_intr = output.get("intrinsics", None)
|
| 191 |
+
assert (
|
| 192 |
+
ctx_extr is not None and ctx_intr is not None
|
| 193 |
+
), "must process camera info first if GT is not available"
|
| 194 |
+
gt_extr = extrinsics
|
| 195 |
+
# homo the extr if needed
|
| 196 |
+
ctx_extr = as_homogeneous(ctx_extr)
|
| 197 |
+
if gt_extr is not None:
|
| 198 |
+
gt_extr = as_homogeneous(gt_extr)
|
| 199 |
+
|
| 200 |
+
# forward through the gs_dpt head to get 'camera space' parameters
|
| 201 |
+
gs_outs = self.gs_head(
|
| 202 |
+
feats=feats,
|
| 203 |
+
H=H,
|
| 204 |
+
W=W,
|
| 205 |
+
patch_start_idx=0,
|
| 206 |
+
images=in_images,
|
| 207 |
+
)
|
| 208 |
+
raw_gaussians = gs_outs.raw_gs
|
| 209 |
+
densities = gs_outs.raw_gs_conf
|
| 210 |
+
|
| 211 |
+
# convert to 'world space' 3DGS parameters; ready to export and render
|
| 212 |
+
# gt_extr could be None, and will be used to align the pose scale if available
|
| 213 |
+
gs_world = self.gs_adapter(
|
| 214 |
+
extrinsics=ctx_extr,
|
| 215 |
+
intrinsics=ctx_intr,
|
| 216 |
+
depths=output.depth,
|
| 217 |
+
opacities=map_pdf_to_opacity(densities),
|
| 218 |
+
raw_gaussians=raw_gaussians,
|
| 219 |
+
image_shape=(H, W),
|
| 220 |
+
gt_extrinsics=gt_extr,
|
| 221 |
+
)
|
| 222 |
+
output.gaussians = gs_world
|
| 223 |
+
|
| 224 |
+
return output
|
| 225 |
+
|
| 226 |
+
def _extract_auxiliary_features(
|
| 227 |
+
self, feats: list[torch.Tensor], feat_layers: list[int], H: int, W: int
|
| 228 |
+
) -> Dict[str, torch.Tensor]:
|
| 229 |
+
"""Extract auxiliary features from specified layers."""
|
| 230 |
+
aux_features = Dict()
|
| 231 |
+
assert len(feats) == len(feat_layers)
|
| 232 |
+
for feat, feat_layer in zip(feats, feat_layers):
|
| 233 |
+
# Reshape features to spatial dimensions
|
| 234 |
+
feat_reshaped = feat.reshape(
|
| 235 |
+
[
|
| 236 |
+
feat.shape[0],
|
| 237 |
+
feat.shape[1],
|
| 238 |
+
H // self.PATCH_SIZE,
|
| 239 |
+
W // self.PATCH_SIZE,
|
| 240 |
+
feat.shape[-1],
|
| 241 |
+
]
|
| 242 |
+
)
|
| 243 |
+
aux_features[f"feat_layer_{feat_layer}"] = feat_reshaped
|
| 244 |
+
|
| 245 |
+
return aux_features
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
class NestedDepthAnything3Net(nn.Module):
|
| 249 |
+
"""
|
| 250 |
+
Nested Depth Anything 3 network with metric scaling capabilities.
|
| 251 |
+
|
| 252 |
+
This network combines two DepthAnything3Net branches:
|
| 253 |
+
- Main branch: Standard depth estimation
|
| 254 |
+
- Metric branch: Metric depth estimation for scaling alignment
|
| 255 |
+
|
| 256 |
+
The network performs depth alignment using least squares scaling
|
| 257 |
+
and handles sky region masking for improved depth estimation.
|
| 258 |
+
|
| 259 |
+
Args:
|
| 260 |
+
preset: Configuration for the main depth estimation branch
|
| 261 |
+
second_preset: Configuration for the metric depth branch
|
| 262 |
+
"""
|
| 263 |
+
|
| 264 |
+
def __init__(self, anyview: DictConfig, metric: DictConfig):
|
| 265 |
+
"""
|
| 266 |
+
Initialize NestedDepthAnything3Net with two branches.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
preset: Configuration for main depth estimation branch
|
| 270 |
+
second_preset: Configuration for metric depth branch
|
| 271 |
+
"""
|
| 272 |
+
super().__init__()
|
| 273 |
+
self.da3 = create_object(anyview)
|
| 274 |
+
self.da3_metric = create_object(metric)
|
| 275 |
+
|
| 276 |
+
def forward(
|
| 277 |
+
self,
|
| 278 |
+
x: torch.Tensor,
|
| 279 |
+
extrinsics: torch.Tensor | None = None,
|
| 280 |
+
intrinsics: torch.Tensor | None = None,
|
| 281 |
+
export_feat_layers: list[int] | None = [],
|
| 282 |
+
infer_gs: bool = False,
|
| 283 |
+
) -> Dict[str, torch.Tensor]:
|
| 284 |
+
"""
|
| 285 |
+
Forward pass through both branches with metric scaling alignment.
|
| 286 |
+
|
| 287 |
+
Args:
|
| 288 |
+
x: Input images (B, N, 3, H, W)
|
| 289 |
+
extrinsics: Camera extrinsics (B, N, 4, 4) - unused
|
| 290 |
+
intrinsics: Camera intrinsics (B, N, 3, 3) - unused
|
| 291 |
+
feat_layers: List of layer indices to extract features from
|
| 292 |
+
metric_feat: Whether to use metric features (unused)
|
| 293 |
+
|
| 294 |
+
Returns:
|
| 295 |
+
Dictionary containing aligned depth predictions and camera parameters
|
| 296 |
+
"""
|
| 297 |
+
# Get predictions from both branches
|
| 298 |
+
output = self.da3(
|
| 299 |
+
x, extrinsics, intrinsics, export_feat_layers=export_feat_layers, infer_gs=infer_gs
|
| 300 |
+
)
|
| 301 |
+
metric_output = self.da3_metric(x, infer_gs=infer_gs)
|
| 302 |
+
|
| 303 |
+
# Apply metric scaling and alignment
|
| 304 |
+
output = self._apply_metric_scaling(output, metric_output)
|
| 305 |
+
output = self._apply_depth_alignment(output, metric_output)
|
| 306 |
+
output = self._handle_sky_regions(output, metric_output)
|
| 307 |
+
|
| 308 |
+
return output
|
| 309 |
+
|
| 310 |
+
def _apply_metric_scaling(
|
| 311 |
+
self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
|
| 312 |
+
) -> Dict[str, torch.Tensor]:
|
| 313 |
+
"""Apply metric scaling to the metric depth output."""
|
| 314 |
+
# Scale metric depth based on camera intrinsics
|
| 315 |
+
metric_output.depth = apply_metric_scaling(
|
| 316 |
+
metric_output.depth,
|
| 317 |
+
output.intrinsics,
|
| 318 |
+
)
|
| 319 |
+
return output
|
| 320 |
+
|
| 321 |
+
def _apply_depth_alignment(
|
| 322 |
+
self, output: Dict[str, torch.Tensor], metric_output: Dict[str, torch.Tensor]
|
| 323 |
+
) -> Dict[str, torch.Tensor]:
|
| 324 |
+
"""Apply depth alignment using least squares scaling."""
|
| 325 |
+
# Compute non-sky mask
|
| 326 |
+
non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)
|
| 327 |
+
|
| 328 |
+
# Ensure we have enough non-sky pixels
|
| 329 |
+
assert non_sky_mask.sum() > 10, "Insufficient non-sky pixels for alignment"
|
| 330 |
+
|
| 331 |
+
# Sample depth confidence for quantile computation
|
| 332 |
+
depth_conf_ns = output.depth_conf[non_sky_mask]
|
| 333 |
+
depth_conf_sampled = sample_tensor_for_quantile(depth_conf_ns, max_samples=100000)
|
| 334 |
+
median_conf = torch.quantile(depth_conf_sampled, 0.5)
|
| 335 |
+
|
| 336 |
+
# Compute alignment mask
|
| 337 |
+
align_mask = compute_alignment_mask(
|
| 338 |
+
output.depth_conf, non_sky_mask, output.depth, metric_output.depth, median_conf
|
| 339 |
+
)
|
| 340 |
+
|
| 341 |
+
# Compute scale factor using least squares
|
| 342 |
+
valid_depth = output.depth[align_mask]
|
| 343 |
+
valid_metric_depth = metric_output.depth[align_mask]
|
| 344 |
+
scale_factor = least_squares_scale_scalar(valid_metric_depth, valid_depth)
|
| 345 |
+
|
| 346 |
+
# Apply scaling to depth and extrinsics
|
| 347 |
+
output.depth *= scale_factor
|
| 348 |
+
output.extrinsics[:, :, :3, 3] *= scale_factor
|
| 349 |
+
output.is_metric = 1
|
| 350 |
+
output.scale_factor = scale_factor.item()
|
| 351 |
+
|
| 352 |
+
return output
|
| 353 |
+
|
| 354 |
+
def _handle_sky_regions(
|
| 355 |
+
self,
|
| 356 |
+
output: Dict[str, torch.Tensor],
|
| 357 |
+
metric_output: Dict[str, torch.Tensor],
|
| 358 |
+
sky_depth_def: float = 200.0,
|
| 359 |
+
) -> Dict[str, torch.Tensor]:
|
| 360 |
+
"""Handle sky regions by setting them to maximum depth."""
|
| 361 |
+
non_sky_mask = compute_sky_mask(metric_output.sky, threshold=0.3)
|
| 362 |
+
|
| 363 |
+
# Compute maximum depth for non-sky regions
|
| 364 |
+
# Use sampling to safely compute quantile on large tensors
|
| 365 |
+
non_sky_depth = output.depth[non_sky_mask]
|
| 366 |
+
if non_sky_depth.numel() > 100000:
|
| 367 |
+
idx = torch.randint(0, non_sky_depth.numel(), (100000,), device=non_sky_depth.device)
|
| 368 |
+
sampled_depth = non_sky_depth[idx]
|
| 369 |
+
else:
|
| 370 |
+
sampled_depth = non_sky_depth
|
| 371 |
+
non_sky_max = min(torch.quantile(sampled_depth, 0.99), sky_depth_def)
|
| 372 |
+
|
| 373 |
+
# Set sky regions to maximum depth and high confidence
|
| 374 |
+
output.depth, output.depth_conf = set_sky_regions_to_max_depth(
|
| 375 |
+
output.depth, output.depth_conf, non_sky_mask, max_depth=non_sky_max
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
return output
|
src/depth_anything_3/model/dinov2/dinov2.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import List
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
from depth_anything_3.model.dinov2.vision_transformer import (
|
| 15 |
+
vit_base,
|
| 16 |
+
vit_giant2,
|
| 17 |
+
vit_large,
|
| 18 |
+
vit_small,
|
| 19 |
+
)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class DinoV2(nn.Module):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
name: str,
|
| 26 |
+
out_layers: List[int],
|
| 27 |
+
alt_start: int = -1,
|
| 28 |
+
qknorm_start: int = -1,
|
| 29 |
+
rope_start: int = -1,
|
| 30 |
+
cat_token: bool = True,
|
| 31 |
+
**kwargs,
|
| 32 |
+
):
|
| 33 |
+
super().__init__()
|
| 34 |
+
assert name in {"vits", "vitb", "vitl", "vitg"}
|
| 35 |
+
self.name = name
|
| 36 |
+
self.out_layers = out_layers
|
| 37 |
+
self.alt_start = alt_start
|
| 38 |
+
self.qknorm_start = qknorm_start
|
| 39 |
+
self.rope_start = rope_start
|
| 40 |
+
self.cat_token = cat_token
|
| 41 |
+
encoder_map = {
|
| 42 |
+
"vits": vit_small,
|
| 43 |
+
"vitb": vit_base,
|
| 44 |
+
"vitl": vit_large,
|
| 45 |
+
"vitg": vit_giant2,
|
| 46 |
+
}
|
| 47 |
+
encoder_fn = encoder_map[self.name]
|
| 48 |
+
ffn_layer = "swiglufused" if self.name == "vitg" else "mlp"
|
| 49 |
+
self.pretrained = encoder_fn(
|
| 50 |
+
img_size=518,
|
| 51 |
+
patch_size=14,
|
| 52 |
+
ffn_layer=ffn_layer,
|
| 53 |
+
alt_start=alt_start,
|
| 54 |
+
qknorm_start=qknorm_start,
|
| 55 |
+
rope_start=rope_start,
|
| 56 |
+
cat_token=cat_token,
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
def forward(self, x, **kwargs):
|
| 60 |
+
return self.pretrained.get_intermediate_layers(
|
| 61 |
+
x,
|
| 62 |
+
self.out_layers,
|
| 63 |
+
**kwargs,
|
| 64 |
+
)
|
src/depth_anything_3/model/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# from .attention import MemEffAttention
|
| 8 |
+
from .block import Block
|
| 9 |
+
from .layer_scale import LayerScale
|
| 10 |
+
from .mlp import Mlp
|
| 11 |
+
from .patch_embed import PatchEmbed
|
| 12 |
+
from .rope import PositionGetter, RotaryPositionEmbedding2D
|
| 13 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
Mlp,
|
| 17 |
+
PatchEmbed,
|
| 18 |
+
SwiGLUFFN,
|
| 19 |
+
SwiGLUFFNFused,
|
| 20 |
+
Block,
|
| 21 |
+
# MemEffAttention,
|
| 22 |
+
LayerScale,
|
| 23 |
+
PositionGetter,
|
| 24 |
+
RotaryPositionEmbedding2D,
|
| 25 |
+
]
|
src/depth_anything_3/model/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
logger = logging.getLogger("dinov2")
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class Attention(nn.Module):
|
| 19 |
+
def __init__(
|
| 20 |
+
self,
|
| 21 |
+
dim: int,
|
| 22 |
+
num_heads: int = 8,
|
| 23 |
+
qkv_bias: bool = False,
|
| 24 |
+
proj_bias: bool = True,
|
| 25 |
+
attn_drop: float = 0.0,
|
| 26 |
+
proj_drop: float = 0.0,
|
| 27 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 28 |
+
qk_norm: bool = False,
|
| 29 |
+
fused_attn: bool = True, # use F.scaled_dot_product_attention or not
|
| 30 |
+
rope=None,
|
| 31 |
+
) -> None:
|
| 32 |
+
super().__init__()
|
| 33 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 34 |
+
self.num_heads = num_heads
|
| 35 |
+
head_dim = dim // num_heads
|
| 36 |
+
self.scale = head_dim**-0.5
|
| 37 |
+
self.fused_attn = fused_attn
|
| 38 |
+
|
| 39 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 40 |
+
self.q_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
|
| 41 |
+
self.k_norm = norm_layer(head_dim) if qk_norm else nn.Identity()
|
| 42 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 43 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 44 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 45 |
+
self.rope = rope
|
| 46 |
+
|
| 47 |
+
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
|
| 48 |
+
B, N, C = x.shape
|
| 49 |
+
qkv = (
|
| 50 |
+
self.qkv(x)
|
| 51 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 52 |
+
.permute(2, 0, 3, 1, 4)
|
| 53 |
+
)
|
| 54 |
+
q, k, v = qkv[0], qkv[1], qkv[2]
|
| 55 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 56 |
+
if self.rope is not None and pos is not None:
|
| 57 |
+
q = self.rope(q, pos)
|
| 58 |
+
k = self.rope(k, pos)
|
| 59 |
+
if self.fused_attn:
|
| 60 |
+
x = F.scaled_dot_product_attention(
|
| 61 |
+
q,
|
| 62 |
+
k,
|
| 63 |
+
v,
|
| 64 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 65 |
+
attn_mask=(
|
| 66 |
+
(attn_mask)[:, None].repeat(1, self.num_heads, 1, 1)
|
| 67 |
+
if attn_mask is not None
|
| 68 |
+
else None
|
| 69 |
+
),
|
| 70 |
+
)
|
| 71 |
+
else:
|
| 72 |
+
q = q * self.scale
|
| 73 |
+
attn = q @ k.transpose(-2, -1)
|
| 74 |
+
attn = attn.softmax(dim=-1)
|
| 75 |
+
attn = self.attn_drop(attn)
|
| 76 |
+
x = attn @ v
|
| 77 |
+
|
| 78 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
x = self.proj_drop(x)
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
def _forward(self, x: Tensor) -> Tensor:
|
| 84 |
+
B, N, C = x.shape
|
| 85 |
+
qkv = (
|
| 86 |
+
self.qkv(x)
|
| 87 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 88 |
+
.permute(2, 0, 3, 1, 4)
|
| 89 |
+
)
|
| 90 |
+
|
| 91 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 92 |
+
attn = q @ k.transpose(-2, -1)
|
| 93 |
+
|
| 94 |
+
attn = attn.softmax(dim=-1)
|
| 95 |
+
attn = self.attn_drop(attn)
|
| 96 |
+
|
| 97 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 98 |
+
x = self.proj(x)
|
| 99 |
+
x = self.proj_drop(x)
|
| 100 |
+
return x
|
src/depth_anything_3/model/dinov2/layers/block.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa: F821
|
| 2 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 3 |
+
# All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# This source code is licensed under the license found in the
|
| 6 |
+
# LICENSE file in the root directory of this source tree.
|
| 7 |
+
|
| 8 |
+
# References:
|
| 9 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 10 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Callable, Optional
|
| 14 |
+
import torch
|
| 15 |
+
from torch import Tensor, nn
|
| 16 |
+
|
| 17 |
+
from .attention import Attention
|
| 18 |
+
from .drop_path import DropPath
|
| 19 |
+
from .layer_scale import LayerScale
|
| 20 |
+
from .mlp import Mlp
|
| 21 |
+
|
| 22 |
+
logger = logging.getLogger("dinov2")
|
| 23 |
+
XFORMERS_AVAILABLE = True
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class Block(nn.Module):
|
| 27 |
+
def __init__(
|
| 28 |
+
self,
|
| 29 |
+
dim: int,
|
| 30 |
+
num_heads: int,
|
| 31 |
+
mlp_ratio: float = 4.0,
|
| 32 |
+
qkv_bias: bool = False,
|
| 33 |
+
proj_bias: bool = True,
|
| 34 |
+
ffn_bias: bool = True,
|
| 35 |
+
drop: float = 0.0,
|
| 36 |
+
attn_drop: float = 0.0,
|
| 37 |
+
init_values=None,
|
| 38 |
+
drop_path: float = 0.0,
|
| 39 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 40 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 41 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 42 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 43 |
+
qk_norm: bool = False,
|
| 44 |
+
rope=None,
|
| 45 |
+
ln_eps: float = 1e-6,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 49 |
+
self.norm1 = norm_layer(dim, eps=ln_eps)
|
| 50 |
+
self.attn = attn_class(
|
| 51 |
+
dim,
|
| 52 |
+
num_heads=num_heads,
|
| 53 |
+
qkv_bias=qkv_bias,
|
| 54 |
+
proj_bias=proj_bias,
|
| 55 |
+
attn_drop=attn_drop,
|
| 56 |
+
proj_drop=drop,
|
| 57 |
+
qk_norm=qk_norm,
|
| 58 |
+
rope=rope,
|
| 59 |
+
)
|
| 60 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 61 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 62 |
+
|
| 63 |
+
self.norm2 = norm_layer(dim, eps=ln_eps)
|
| 64 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 65 |
+
self.mlp = ffn_layer(
|
| 66 |
+
in_features=dim,
|
| 67 |
+
hidden_features=mlp_hidden_dim,
|
| 68 |
+
act_layer=act_layer,
|
| 69 |
+
drop=drop,
|
| 70 |
+
bias=ffn_bias,
|
| 71 |
+
)
|
| 72 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 73 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 74 |
+
|
| 75 |
+
self.sample_drop_ratio = drop_path
|
| 76 |
+
|
| 77 |
+
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
|
| 78 |
+
def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor:
|
| 79 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
|
| 80 |
+
|
| 81 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 82 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 83 |
+
|
| 84 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 85 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 86 |
+
x = drop_add_residual_stochastic_depth(
|
| 87 |
+
x,
|
| 88 |
+
residual_func=attn_residual_func,
|
| 89 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 90 |
+
pos=pos,
|
| 91 |
+
)
|
| 92 |
+
x = drop_add_residual_stochastic_depth(
|
| 93 |
+
x,
|
| 94 |
+
residual_func=ffn_residual_func,
|
| 95 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 96 |
+
)
|
| 97 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 98 |
+
x = x + self.drop_path1(attn_residual_func(x, pos=pos, attn_mask=attn_mask))
|
| 99 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 100 |
+
else:
|
| 101 |
+
x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
|
| 102 |
+
x = x + ffn_residual_func(x)
|
| 103 |
+
return x
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def drop_add_residual_stochastic_depth(
|
| 107 |
+
x: Tensor,
|
| 108 |
+
residual_func: Callable[[Tensor], Tensor],
|
| 109 |
+
sample_drop_ratio: float = 0.0,
|
| 110 |
+
pos: Optional[Tensor] = None,
|
| 111 |
+
) -> Tensor:
|
| 112 |
+
# 1) extract subset using permutation
|
| 113 |
+
b, n, d = x.shape
|
| 114 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 115 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 116 |
+
x_subset = x[brange]
|
| 117 |
+
|
| 118 |
+
# 2) apply residual_func to get residual
|
| 119 |
+
if pos is not None:
|
| 120 |
+
# if necessary, apply rope to the subset
|
| 121 |
+
pos = pos[brange]
|
| 122 |
+
residual = residual_func(x_subset, pos=pos)
|
| 123 |
+
else:
|
| 124 |
+
residual = residual_func(x_subset)
|
| 125 |
+
|
| 126 |
+
x_flat = x.flatten(1)
|
| 127 |
+
residual = residual.flatten(1)
|
| 128 |
+
|
| 129 |
+
residual_scale_factor = b / sample_subset_size
|
| 130 |
+
|
| 131 |
+
# 3) add the residual
|
| 132 |
+
x_plus_residual = torch.index_add(
|
| 133 |
+
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
| 134 |
+
)
|
| 135 |
+
return x_plus_residual.view_as(x)
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 139 |
+
b, n, d = x.shape
|
| 140 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 141 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 142 |
+
residual_scale_factor = b / sample_subset_size
|
| 143 |
+
return brange, residual_scale_factor
|
src/depth_anything_3/model/dinov2/layers/drop_path.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 16 |
+
if drop_prob == 0.0 or not training:
|
| 17 |
+
return x
|
| 18 |
+
keep_prob = 1 - drop_prob
|
| 19 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
| 20 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 21 |
+
if keep_prob > 0.0:
|
| 22 |
+
random_tensor.div_(keep_prob)
|
| 23 |
+
output = x * random_tensor
|
| 24 |
+
return output
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class DropPath(nn.Module):
|
| 28 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 29 |
+
|
| 30 |
+
def __init__(self, drop_prob=None):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.drop_prob = drop_prob
|
| 33 |
+
|
| 34 |
+
def forward(self, x):
|
| 35 |
+
return drop_path(x, self.drop_prob, self.training)
|
src/depth_anything_3/model/dinov2/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa: E501
|
| 8 |
+
|
| 9 |
+
from typing import Union
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor, nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LayerScale(nn.Module):
|
| 15 |
+
def __init__(
|
| 16 |
+
self,
|
| 17 |
+
dim: int,
|
| 18 |
+
init_values: Union[float, Tensor] = 1e-5,
|
| 19 |
+
inplace: bool = False,
|
| 20 |
+
) -> None:
|
| 21 |
+
super().__init__()
|
| 22 |
+
self.dim = dim
|
| 23 |
+
self.inplace = inplace
|
| 24 |
+
self.init_values = init_values
|
| 25 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 26 |
+
|
| 27 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 28 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 29 |
+
|
| 30 |
+
def extra_repr(self) -> str:
|
| 31 |
+
return f"{self.dim}, init_values={self.init_values}, inplace={self.inplace}"
|
src/depth_anything_3/model/dinov2/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from typing import Callable, Optional
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
in_features: int,
|
| 20 |
+
hidden_features: Optional[int] = None,
|
| 21 |
+
out_features: Optional[int] = None,
|
| 22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 23 |
+
drop: float = 0.0,
|
| 24 |
+
bias: bool = True,
|
| 25 |
+
) -> None:
|
| 26 |
+
super().__init__()
|
| 27 |
+
out_features = out_features or in_features
|
| 28 |
+
hidden_features = hidden_features or in_features
|
| 29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 30 |
+
self.act = act_layer()
|
| 31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 32 |
+
self.drop = nn.Dropout(drop)
|
| 33 |
+
|
| 34 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 35 |
+
x = self.fc1(x)
|
| 36 |
+
x = self.act(x)
|
| 37 |
+
x = self.drop(x)
|
| 38 |
+
x = self.fc2(x)
|
| 39 |
+
x = self.drop(x)
|
| 40 |
+
return x
|
src/depth_anything_3/model/dinov2/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional, Tuple, Union
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
|
| 17 |
+
if isinstance(x, tuple):
|
| 18 |
+
assert len(x) == 2
|
| 19 |
+
return x
|
| 20 |
+
|
| 21 |
+
assert isinstance(x, int)
|
| 22 |
+
return (x, x)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class PatchEmbed(nn.Module):
|
| 26 |
+
"""
|
| 27 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 28 |
+
|
| 29 |
+
Args:
|
| 30 |
+
img_size: Image size.
|
| 31 |
+
patch_size: Patch token size.
|
| 32 |
+
in_chans: Number of input image channels.
|
| 33 |
+
embed_dim: Number of linear projection output channels.
|
| 34 |
+
norm_layer: Normalization layer.
|
| 35 |
+
"""
|
| 36 |
+
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 40 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 41 |
+
in_chans: int = 3,
|
| 42 |
+
embed_dim: int = 768,
|
| 43 |
+
norm_layer: Optional[Callable] = None,
|
| 44 |
+
flatten_embedding: bool = True,
|
| 45 |
+
) -> None:
|
| 46 |
+
super().__init__()
|
| 47 |
+
|
| 48 |
+
image_HW = make_2tuple(img_size)
|
| 49 |
+
patch_HW = make_2tuple(patch_size)
|
| 50 |
+
patch_grid_size = (
|
| 51 |
+
image_HW[0] // patch_HW[0],
|
| 52 |
+
image_HW[1] // patch_HW[1],
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
self.img_size = image_HW
|
| 56 |
+
self.patch_size = patch_HW
|
| 57 |
+
self.patches_resolution = patch_grid_size
|
| 58 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 59 |
+
|
| 60 |
+
self.in_chans = in_chans
|
| 61 |
+
self.embed_dim = embed_dim
|
| 62 |
+
|
| 63 |
+
self.flatten_embedding = flatten_embedding
|
| 64 |
+
|
| 65 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
| 66 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 67 |
+
|
| 68 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 69 |
+
_, _, H, W = x.shape
|
| 70 |
+
patch_H, patch_W = self.patch_size
|
| 71 |
+
|
| 72 |
+
assert (
|
| 73 |
+
H % patch_H == 0
|
| 74 |
+
), f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 75 |
+
assert (
|
| 76 |
+
W % patch_W == 0
|
| 77 |
+
), f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 78 |
+
|
| 79 |
+
x = self.proj(x) # B C H W
|
| 80 |
+
H, W = x.size(2), x.size(3)
|
| 81 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 82 |
+
x = self.norm(x)
|
| 83 |
+
if not self.flatten_embedding:
|
| 84 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 85 |
+
return x
|
| 86 |
+
|
| 87 |
+
def flops(self) -> float:
|
| 88 |
+
Ho, Wo = self.patches_resolution
|
| 89 |
+
flops = (
|
| 90 |
+
Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
| 91 |
+
)
|
| 92 |
+
if self.norm is not None:
|
| 93 |
+
flops += Ho * Wo * self.embed_dim
|
| 94 |
+
return flops
|
src/depth_anything_3/model/dinov2/layers/rope.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
# Implementation of 2D Rotary Position Embeddings (RoPE).
|
| 8 |
+
|
| 9 |
+
# This module provides a clean implementation of 2D Rotary Position Embeddings,
|
| 10 |
+
# which extends the original RoPE concept to handle 2D spatial positions.
|
| 11 |
+
|
| 12 |
+
# Inspired by:
|
| 13 |
+
# https://github.com/meta-llama/codellama/blob/main/llama/model.py
|
| 14 |
+
# https://github.com/naver-ai/rope-vit
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
from typing import Dict, Tuple
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
import torch.nn.functional as F
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class PositionGetter:
|
| 24 |
+
"""Generates and caches 2D spatial positions for patches in a grid.
|
| 25 |
+
|
| 26 |
+
This class efficiently manages the generation of spatial coordinates for patches
|
| 27 |
+
in a 2D grid, caching results to avoid redundant computations.
|
| 28 |
+
|
| 29 |
+
Attributes:
|
| 30 |
+
position_cache: Dictionary storing precomputed position tensors for different
|
| 31 |
+
grid dimensions.
|
| 32 |
+
"""
|
| 33 |
+
|
| 34 |
+
def __init__(self):
|
| 35 |
+
"""Initializes the position generator with an empty cache."""
|
| 36 |
+
self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {}
|
| 37 |
+
|
| 38 |
+
def __call__(
|
| 39 |
+
self, batch_size: int, height: int, width: int, device: torch.device
|
| 40 |
+
) -> torch.Tensor:
|
| 41 |
+
"""Generates spatial positions for a batch of patches.
|
| 42 |
+
|
| 43 |
+
Args:
|
| 44 |
+
batch_size: Number of samples in the batch.
|
| 45 |
+
height: Height of the grid in patches.
|
| 46 |
+
width: Width of the grid in patches.
|
| 47 |
+
device: Target device for the position tensor.
|
| 48 |
+
|
| 49 |
+
Returns:
|
| 50 |
+
Tensor of shape (batch_size, height*width, 2) containing y,x coordinates
|
| 51 |
+
for each position in the grid, repeated for each batch item.
|
| 52 |
+
"""
|
| 53 |
+
if (height, width) not in self.position_cache:
|
| 54 |
+
y_coords = torch.arange(height, device=device)
|
| 55 |
+
x_coords = torch.arange(width, device=device)
|
| 56 |
+
positions = torch.cartesian_prod(y_coords, x_coords)
|
| 57 |
+
self.position_cache[height, width] = positions
|
| 58 |
+
|
| 59 |
+
cached_positions = self.position_cache[height, width]
|
| 60 |
+
return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class RotaryPositionEmbedding2D(nn.Module):
|
| 64 |
+
"""2D Rotary Position Embedding implementation.
|
| 65 |
+
|
| 66 |
+
This module applies rotary position embeddings to input tokens based on their
|
| 67 |
+
2D spatial positions. It handles the position-dependent rotation of features
|
| 68 |
+
separately for vertical and horizontal dimensions.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
frequency: Base frequency for the position embeddings. Default: 100.0
|
| 72 |
+
scaling_factor: Scaling factor for frequency computation. Default: 1.0
|
| 73 |
+
|
| 74 |
+
Attributes:
|
| 75 |
+
base_frequency: Base frequency for computing position embeddings.
|
| 76 |
+
scaling_factor: Factor to scale the computed frequencies.
|
| 77 |
+
frequency_cache: Cache for storing precomputed frequency components.
|
| 78 |
+
"""
|
| 79 |
+
|
| 80 |
+
def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0):
|
| 81 |
+
"""Initializes the 2D RoPE module."""
|
| 82 |
+
super().__init__()
|
| 83 |
+
self.base_frequency = frequency
|
| 84 |
+
self.scaling_factor = scaling_factor
|
| 85 |
+
self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {}
|
| 86 |
+
|
| 87 |
+
def _compute_frequency_components(
|
| 88 |
+
self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype
|
| 89 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
| 90 |
+
"""Computes frequency components for rotary embeddings.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
dim: Feature dimension (must be even).
|
| 94 |
+
seq_len: Maximum sequence length.
|
| 95 |
+
device: Target device for computations.
|
| 96 |
+
dtype: Data type for the computed tensors.
|
| 97 |
+
|
| 98 |
+
Returns:
|
| 99 |
+
Tuple of (cosine, sine) tensors for frequency components.
|
| 100 |
+
"""
|
| 101 |
+
cache_key = (dim, seq_len, device, dtype)
|
| 102 |
+
if cache_key not in self.frequency_cache:
|
| 103 |
+
# Compute frequency bands
|
| 104 |
+
exponents = torch.arange(0, dim, 2, device=device).float() / dim
|
| 105 |
+
inv_freq = 1.0 / (self.base_frequency**exponents)
|
| 106 |
+
|
| 107 |
+
# Generate position-dependent frequencies
|
| 108 |
+
positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
| 109 |
+
angles = torch.einsum("i,j->ij", positions, inv_freq)
|
| 110 |
+
|
| 111 |
+
# Compute and cache frequency components
|
| 112 |
+
angles = angles.to(dtype)
|
| 113 |
+
angles = torch.cat((angles, angles), dim=-1)
|
| 114 |
+
cos_components = angles.cos().to(dtype)
|
| 115 |
+
sin_components = angles.sin().to(dtype)
|
| 116 |
+
self.frequency_cache[cache_key] = (cos_components, sin_components)
|
| 117 |
+
|
| 118 |
+
return self.frequency_cache[cache_key]
|
| 119 |
+
|
| 120 |
+
@staticmethod
|
| 121 |
+
def _rotate_features(x: torch.Tensor) -> torch.Tensor:
|
| 122 |
+
"""Performs feature rotation by splitting and recombining feature dimensions.
|
| 123 |
+
|
| 124 |
+
Args:
|
| 125 |
+
x: Input tensor to rotate.
|
| 126 |
+
|
| 127 |
+
Returns:
|
| 128 |
+
Rotated feature tensor.
|
| 129 |
+
"""
|
| 130 |
+
feature_dim = x.shape[-1]
|
| 131 |
+
x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :]
|
| 132 |
+
return torch.cat((-x2, x1), dim=-1)
|
| 133 |
+
|
| 134 |
+
def _apply_1d_rope(
|
| 135 |
+
self,
|
| 136 |
+
tokens: torch.Tensor,
|
| 137 |
+
positions: torch.Tensor,
|
| 138 |
+
cos_comp: torch.Tensor,
|
| 139 |
+
sin_comp: torch.Tensor,
|
| 140 |
+
) -> torch.Tensor:
|
| 141 |
+
"""Applies 1D rotary position embeddings along one dimension.
|
| 142 |
+
|
| 143 |
+
Args:
|
| 144 |
+
tokens: Input token features.
|
| 145 |
+
positions: Position indices.
|
| 146 |
+
cos_comp: Cosine components for rotation.
|
| 147 |
+
sin_comp: Sine components for rotation.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
Tokens with applied rotary position embeddings.
|
| 151 |
+
"""
|
| 152 |
+
# Embed positions with frequency components
|
| 153 |
+
cos = F.embedding(positions, cos_comp)[:, None, :, :]
|
| 154 |
+
sin = F.embedding(positions, sin_comp)[:, None, :, :]
|
| 155 |
+
# Apply rotation
|
| 156 |
+
return (tokens * cos) + (self._rotate_features(tokens) * sin)
|
| 157 |
+
|
| 158 |
+
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
| 159 |
+
"""Applies 2D rotary position embeddings to input tokens.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim).
|
| 163 |
+
The feature dimension (dim) must be divisible by 4.
|
| 164 |
+
positions: Position tensor of shape (batch_size, n_tokens, 2) containing
|
| 165 |
+
the y and x coordinates for each token.
|
| 166 |
+
|
| 167 |
+
Returns:
|
| 168 |
+
Tensor of same shape as input with applied 2D rotary position embeddings.
|
| 169 |
+
|
| 170 |
+
Raises:
|
| 171 |
+
AssertionError: If input dimensions are invalid or positions are malformed.
|
| 172 |
+
"""
|
| 173 |
+
# Validate inputs
|
| 174 |
+
assert tokens.size(-1) % 2 == 0, "Feature dimension must be even"
|
| 175 |
+
assert (
|
| 176 |
+
positions.ndim == 3 and positions.shape[-1] == 2
|
| 177 |
+
), "Positions must have shape (batch_size, n_tokens, 2)"
|
| 178 |
+
|
| 179 |
+
# Compute feature dimension for each spatial direction
|
| 180 |
+
feature_dim = tokens.size(-1) // 2
|
| 181 |
+
|
| 182 |
+
# Get frequency components
|
| 183 |
+
max_position = int(positions.max()) + 1
|
| 184 |
+
cos_comp, sin_comp = self._compute_frequency_components(
|
| 185 |
+
feature_dim, max_position, tokens.device, tokens.dtype
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Split features for vertical and horizontal processing
|
| 189 |
+
vertical_features, horizontal_features = tokens.chunk(2, dim=-1)
|
| 190 |
+
|
| 191 |
+
# Apply RoPE separately for each dimension
|
| 192 |
+
vertical_features = self._apply_1d_rope(
|
| 193 |
+
vertical_features, positions[..., 0], cos_comp, sin_comp
|
| 194 |
+
)
|
| 195 |
+
horizontal_features = self._apply_1d_rope(
|
| 196 |
+
horizontal_features, positions[..., 1], cos_comp, sin_comp
|
| 197 |
+
)
|
| 198 |
+
|
| 199 |
+
# Combine processed features
|
| 200 |
+
return torch.cat((vertical_features, horizontal_features), dim=-1)
|
src/depth_anything_3/model/dinov2/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SwiGLUFFN(nn.Module):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
in_features: int,
|
| 16 |
+
hidden_features: Optional[int] = None,
|
| 17 |
+
out_features: Optional[int] = None,
|
| 18 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 19 |
+
drop: float = 0.0,
|
| 20 |
+
bias: bool = True,
|
| 21 |
+
) -> None:
|
| 22 |
+
super().__init__()
|
| 23 |
+
out_features = out_features or in_features
|
| 24 |
+
hidden_features = hidden_features or in_features
|
| 25 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
| 26 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 27 |
+
|
| 28 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 29 |
+
x12 = self.w12(x)
|
| 30 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 31 |
+
hidden = F.silu(x1) * x2
|
| 32 |
+
return self.w3(hidden)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
from xformers.ops import SwiGLU
|
| 37 |
+
|
| 38 |
+
XFORMERS_AVAILABLE = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
SwiGLU = SwiGLUFFN
|
| 41 |
+
XFORMERS_AVAILABLE = False
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SwiGLUFFNFused(SwiGLU):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
in_features: int,
|
| 48 |
+
hidden_features: Optional[int] = None,
|
| 49 |
+
out_features: Optional[int] = None,
|
| 50 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 51 |
+
drop: float = 0.0,
|
| 52 |
+
bias: bool = True,
|
| 53 |
+
) -> None:
|
| 54 |
+
out_features = out_features or in_features
|
| 55 |
+
hidden_features = hidden_features or in_features
|
| 56 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 57 |
+
super().__init__(
|
| 58 |
+
in_features=in_features,
|
| 59 |
+
hidden_features=hidden_features,
|
| 60 |
+
out_features=out_features,
|
| 61 |
+
bias=bias,
|
| 62 |
+
)
|
src/depth_anything_3/model/dinov2/vision_transformer.py
ADDED
|
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import math
|
| 11 |
+
from typing import Callable, List, Sequence, Tuple, Union
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
import torch.utils.checkpoint
|
| 16 |
+
from einops import rearrange
|
| 17 |
+
|
| 18 |
+
from depth_anything_3.utils.logger import logger
|
| 19 |
+
|
| 20 |
+
from .layers import LayerScale # noqa: F401
|
| 21 |
+
from .layers import Mlp # noqa: F401
|
| 22 |
+
from .layers import ( # noqa: F401
|
| 23 |
+
Block,
|
| 24 |
+
PatchEmbed,
|
| 25 |
+
PositionGetter,
|
| 26 |
+
RotaryPositionEmbedding2D,
|
| 27 |
+
SwiGLUFFNFused,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
# logger = logging.getLogger("dinov2")
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 34 |
+
"""
|
| 35 |
+
embed_dim: output dimension for each position
|
| 36 |
+
pos: a list of positions to be encoded: size (M,)
|
| 37 |
+
out: (M, D)
|
| 38 |
+
"""
|
| 39 |
+
assert embed_dim % 2 == 0
|
| 40 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
| 41 |
+
omega /= embed_dim / 2.0
|
| 42 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 43 |
+
|
| 44 |
+
pos = pos.reshape(-1) # (M,)
|
| 45 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 46 |
+
|
| 47 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 48 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 49 |
+
|
| 50 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 51 |
+
return emb
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def named_apply(
|
| 55 |
+
fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
|
| 56 |
+
) -> nn.Module:
|
| 57 |
+
if not depth_first and include_root:
|
| 58 |
+
fn(module=module, name=name)
|
| 59 |
+
for child_name, child_module in module.named_children():
|
| 60 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 61 |
+
named_apply(
|
| 62 |
+
fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True
|
| 63 |
+
)
|
| 64 |
+
if depth_first and include_root:
|
| 65 |
+
fn(module=module, name=name)
|
| 66 |
+
return module
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class BlockChunk(nn.ModuleList):
|
| 70 |
+
def forward(self, x):
|
| 71 |
+
for b in self:
|
| 72 |
+
x = b(x)
|
| 73 |
+
return x
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class DinoVisionTransformer(nn.Module):
|
| 77 |
+
def __init__(
|
| 78 |
+
self,
|
| 79 |
+
img_size=224,
|
| 80 |
+
patch_size=16,
|
| 81 |
+
in_chans=3,
|
| 82 |
+
embed_dim=768,
|
| 83 |
+
depth=12,
|
| 84 |
+
num_heads=12,
|
| 85 |
+
mlp_ratio=4.0,
|
| 86 |
+
qkv_bias=True,
|
| 87 |
+
ffn_bias=True,
|
| 88 |
+
proj_bias=True,
|
| 89 |
+
drop_path_rate=0.0,
|
| 90 |
+
drop_path_uniform=False,
|
| 91 |
+
init_values=1.0, # for layerscale: None or 0 => no layerscale
|
| 92 |
+
embed_layer=PatchEmbed,
|
| 93 |
+
act_layer=nn.GELU,
|
| 94 |
+
block_fn=Block,
|
| 95 |
+
ffn_layer="mlp",
|
| 96 |
+
block_chunks=1,
|
| 97 |
+
num_register_tokens=0,
|
| 98 |
+
interpolate_antialias=False,
|
| 99 |
+
interpolate_offset=0.1,
|
| 100 |
+
alt_start=-1,
|
| 101 |
+
qknorm_start=-1,
|
| 102 |
+
rope_start=-1,
|
| 103 |
+
rope_freq=100,
|
| 104 |
+
plus_cam_token=False,
|
| 105 |
+
cat_token=True,
|
| 106 |
+
):
|
| 107 |
+
"""
|
| 108 |
+
Args:
|
| 109 |
+
img_size (int, tuple): input image size
|
| 110 |
+
patch_size (int, tuple): patch size
|
| 111 |
+
in_chans (int): number of input channels
|
| 112 |
+
embed_dim (int): embedding dimension
|
| 113 |
+
depth (int): depth of transformer
|
| 114 |
+
num_heads (int): number of attention heads
|
| 115 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 116 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 117 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 118 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 119 |
+
weight_init (str): weight init scheme
|
| 120 |
+
init_values (float): layer-scale init values
|
| 121 |
+
embed_layer (nn.Module): patch embedding layer
|
| 122 |
+
act_layer (nn.Module): MLP activation layer
|
| 123 |
+
block_fn (nn.Module): transformer block class
|
| 124 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 125 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 126 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 127 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating
|
| 128 |
+
positional embeddings
|
| 129 |
+
interpolate_offset: (float) work-around offset to apply when interpolating
|
| 130 |
+
positional embeddings
|
| 131 |
+
block_prompt: (bool) whether to add ray embeddings to the block input
|
| 132 |
+
"""
|
| 133 |
+
super().__init__()
|
| 134 |
+
self.patch_start_idx = 1
|
| 135 |
+
norm_layer = nn.LayerNorm
|
| 136 |
+
self.num_features = self.embed_dim = (
|
| 137 |
+
embed_dim # num_features for consistency with other models
|
| 138 |
+
)
|
| 139 |
+
self.alt_start = alt_start
|
| 140 |
+
self.qknorm_start = qknorm_start
|
| 141 |
+
self.rope_start = rope_start
|
| 142 |
+
self.cat_token = cat_token
|
| 143 |
+
self.num_tokens = 1
|
| 144 |
+
self.n_blocks = depth
|
| 145 |
+
self.num_heads = num_heads
|
| 146 |
+
self.patch_size = patch_size
|
| 147 |
+
self.num_register_tokens = num_register_tokens
|
| 148 |
+
self.interpolate_antialias = interpolate_antialias
|
| 149 |
+
self.interpolate_offset = interpolate_offset
|
| 150 |
+
|
| 151 |
+
self.patch_embed = embed_layer(
|
| 152 |
+
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim
|
| 153 |
+
)
|
| 154 |
+
num_patches = self.patch_embed.num_patches
|
| 155 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 156 |
+
if self.alt_start != -1:
|
| 157 |
+
self.camera_token = nn.Parameter(torch.randn(1, 2, embed_dim))
|
| 158 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 159 |
+
assert num_register_tokens >= 0
|
| 160 |
+
self.register_tokens = (
|
| 161 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim))
|
| 162 |
+
if num_register_tokens
|
| 163 |
+
else None
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
if drop_path_uniform is True:
|
| 167 |
+
dpr = [drop_path_rate] * depth
|
| 168 |
+
else:
|
| 169 |
+
dpr = [
|
| 170 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 171 |
+
] # stochastic depth decay rule
|
| 172 |
+
if ffn_layer == "mlp":
|
| 173 |
+
logger.info("using MLP layer as FFN")
|
| 174 |
+
ffn_layer = Mlp
|
| 175 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 176 |
+
logger.info("using SwiGLU layer as FFN")
|
| 177 |
+
ffn_layer = SwiGLUFFNFused
|
| 178 |
+
elif ffn_layer == "identity":
|
| 179 |
+
logger.info("using Identity layer as FFN")
|
| 180 |
+
|
| 181 |
+
def f(*args, **kwargs):
|
| 182 |
+
return nn.Identity()
|
| 183 |
+
|
| 184 |
+
ffn_layer = f
|
| 185 |
+
else:
|
| 186 |
+
raise NotImplementedError
|
| 187 |
+
|
| 188 |
+
if self.rope_start != -1:
|
| 189 |
+
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None
|
| 190 |
+
self.position_getter = PositionGetter() if self.rope is not None else None
|
| 191 |
+
else:
|
| 192 |
+
self.rope = None
|
| 193 |
+
blocks_list = [
|
| 194 |
+
block_fn(
|
| 195 |
+
dim=embed_dim,
|
| 196 |
+
num_heads=num_heads,
|
| 197 |
+
mlp_ratio=mlp_ratio,
|
| 198 |
+
qkv_bias=qkv_bias,
|
| 199 |
+
proj_bias=proj_bias,
|
| 200 |
+
ffn_bias=ffn_bias,
|
| 201 |
+
drop_path=dpr[i],
|
| 202 |
+
norm_layer=norm_layer,
|
| 203 |
+
act_layer=act_layer,
|
| 204 |
+
ffn_layer=ffn_layer,
|
| 205 |
+
init_values=init_values,
|
| 206 |
+
qk_norm=i >= qknorm_start if qknorm_start != -1 else False,
|
| 207 |
+
rope=self.rope if i >= rope_start and rope_start != -1 else None,
|
| 208 |
+
)
|
| 209 |
+
for i in range(depth)
|
| 210 |
+
]
|
| 211 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 212 |
+
self.norm = norm_layer(embed_dim)
|
| 213 |
+
|
| 214 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 215 |
+
previous_dtype = x.dtype
|
| 216 |
+
npatch = x.shape[1] - 1
|
| 217 |
+
N = self.pos_embed.shape[1] - 1
|
| 218 |
+
if npatch == N and w == h:
|
| 219 |
+
return self.pos_embed
|
| 220 |
+
pos_embed = self.pos_embed.float()
|
| 221 |
+
class_pos_embed = pos_embed[:, 0]
|
| 222 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 223 |
+
dim = x.shape[-1]
|
| 224 |
+
w0 = w // self.patch_size
|
| 225 |
+
h0 = h // self.patch_size
|
| 226 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
| 227 |
+
assert N == M * M
|
| 228 |
+
kwargs = {}
|
| 229 |
+
if self.interpolate_offset:
|
| 230 |
+
# Historical kludge: add a small number to avoid floating point error in the
|
| 231 |
+
# interpolation, see https://github.com/facebookresearch/dino/issues/8
|
| 232 |
+
# Note: still needed for backward-compatibility, the underlying operators are using
|
| 233 |
+
# both output size and scale factors
|
| 234 |
+
sx = float(w0 + self.interpolate_offset) / M
|
| 235 |
+
sy = float(h0 + self.interpolate_offset) / M
|
| 236 |
+
kwargs["scale_factor"] = (sx, sy)
|
| 237 |
+
else:
|
| 238 |
+
# Simply specify an output size instead of a scale factor
|
| 239 |
+
kwargs["size"] = (w0, h0)
|
| 240 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 241 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
| 242 |
+
mode="bicubic",
|
| 243 |
+
antialias=self.interpolate_antialias,
|
| 244 |
+
**kwargs,
|
| 245 |
+
)
|
| 246 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
| 247 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 248 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
| 249 |
+
|
| 250 |
+
def prepare_cls_token(self, B, S):
|
| 251 |
+
cls_token = self.cls_token.expand(B, S, -1)
|
| 252 |
+
cls_token = cls_token.reshape(B * S, -1, self.embed_dim)
|
| 253 |
+
return cls_token
|
| 254 |
+
|
| 255 |
+
def prepare_tokens_with_masks(self, x, masks=None, cls_token=None, **kwargs):
|
| 256 |
+
B, S, nc, w, h = x.shape
|
| 257 |
+
x = rearrange(x, "b s c h w -> (b s) c h w")
|
| 258 |
+
x = self.patch_embed(x)
|
| 259 |
+
if masks is not None:
|
| 260 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 261 |
+
cls_token = self.prepare_cls_token(B, S)
|
| 262 |
+
x = torch.cat((cls_token, x), dim=1)
|
| 263 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 264 |
+
if self.register_tokens is not None:
|
| 265 |
+
x = torch.cat(
|
| 266 |
+
(
|
| 267 |
+
x[:, :1],
|
| 268 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
| 269 |
+
x[:, 1:],
|
| 270 |
+
),
|
| 271 |
+
dim=1,
|
| 272 |
+
)
|
| 273 |
+
x = rearrange(x, "(b s) n c -> b s n c", b=B, s=S)
|
| 274 |
+
return x
|
| 275 |
+
|
| 276 |
+
def _prepare_rope(self, B, S, H, W, device):
|
| 277 |
+
pos = None
|
| 278 |
+
pos_nodiff = None
|
| 279 |
+
if self.rope is not None:
|
| 280 |
+
pos = self.position_getter(
|
| 281 |
+
B * S, H // self.patch_size, W // self.patch_size, device=device
|
| 282 |
+
)
|
| 283 |
+
pos = rearrange(pos, "(b s) n c -> b s n c", b=B)
|
| 284 |
+
pos_nodiff = torch.zeros_like(pos).to(pos.dtype)
|
| 285 |
+
if self.patch_start_idx > 0:
|
| 286 |
+
pos = pos + 1
|
| 287 |
+
pos_special = torch.zeros(B * S, self.patch_start_idx, 2).to(device).to(pos.dtype)
|
| 288 |
+
pos_special = rearrange(pos_special, "(b s) n c -> b s n c", b=B)
|
| 289 |
+
pos = torch.cat([pos_special, pos], dim=2)
|
| 290 |
+
pos_nodiff = pos_nodiff + 1
|
| 291 |
+
pos_nodiff = torch.cat([pos_special, pos_nodiff], dim=2)
|
| 292 |
+
return pos, pos_nodiff
|
| 293 |
+
|
| 294 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1, export_feat_layers=[], **kwargs):
|
| 295 |
+
B, S, _, H, W = x.shape
|
| 296 |
+
x = self.prepare_tokens_with_masks(x)
|
| 297 |
+
output, total_block_len, aux_output = [], len(self.blocks), []
|
| 298 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 299 |
+
pos, pos_nodiff = self._prepare_rope(B, S, H, W, x.device)
|
| 300 |
+
|
| 301 |
+
for i, blk in enumerate(self.blocks):
|
| 302 |
+
if i < self.rope_start or self.rope is None:
|
| 303 |
+
g_pos, l_pos = None, None
|
| 304 |
+
else:
|
| 305 |
+
g_pos = pos_nodiff
|
| 306 |
+
l_pos = pos
|
| 307 |
+
if self.alt_start != -1 and i == self.alt_start:
|
| 308 |
+
if kwargs.get("cam_token", None) is not None:
|
| 309 |
+
logger.info("Using camera conditions provided by the user")
|
| 310 |
+
cam_token = kwargs.get("cam_token")
|
| 311 |
+
else:
|
| 312 |
+
ref_token = self.camera_token[:, :1].expand(B, -1, -1)
|
| 313 |
+
src_token = self.camera_token[:, 1:].expand(B, S - 1, -1)
|
| 314 |
+
cam_token = torch.cat([ref_token, src_token], dim=1)
|
| 315 |
+
x[:, :, 0] = cam_token
|
| 316 |
+
|
| 317 |
+
if self.alt_start != -1 and i >= self.alt_start and i % 2 == 1:
|
| 318 |
+
x = self.process_attention(
|
| 319 |
+
x, blk, "global", pos=g_pos, attn_mask=kwargs.get("attn_mask", None)
|
| 320 |
+
)
|
| 321 |
+
else:
|
| 322 |
+
x = self.process_attention(x, blk, "local", pos=l_pos)
|
| 323 |
+
local_x = x
|
| 324 |
+
|
| 325 |
+
if i in blocks_to_take:
|
| 326 |
+
out_x = torch.cat([local_x, x], dim=-1) if self.cat_token else x
|
| 327 |
+
output.append((out_x[:, :, 0], out_x))
|
| 328 |
+
if i in export_feat_layers:
|
| 329 |
+
aux_output.append(x)
|
| 330 |
+
return output, aux_output
|
| 331 |
+
|
| 332 |
+
def process_attention(self, x, block, attn_type="global", pos=None, attn_mask=None):
|
| 333 |
+
b, s, n = x.shape[:3]
|
| 334 |
+
if attn_type == "local":
|
| 335 |
+
x = rearrange(x, "b s n c -> (b s) n c")
|
| 336 |
+
if pos is not None:
|
| 337 |
+
pos = rearrange(pos, "b s n c -> (b s) n c")
|
| 338 |
+
elif attn_type == "global":
|
| 339 |
+
x = rearrange(x, "b s n c -> b (s n) c")
|
| 340 |
+
if pos is not None:
|
| 341 |
+
pos = rearrange(pos, "b s n c -> b (s n) c")
|
| 342 |
+
else:
|
| 343 |
+
raise ValueError(f"Invalid attention type: {attn_type}")
|
| 344 |
+
|
| 345 |
+
x = block(x, pos=pos, attn_mask=attn_mask)
|
| 346 |
+
|
| 347 |
+
if attn_type == "local":
|
| 348 |
+
x = rearrange(x, "(b s) n c -> b s n c", b=b, s=s)
|
| 349 |
+
elif attn_type == "global":
|
| 350 |
+
x = rearrange(x, "b (s n) c -> b s n c", b=b, s=s)
|
| 351 |
+
return x
|
| 352 |
+
|
| 353 |
+
def get_intermediate_layers(
|
| 354 |
+
self,
|
| 355 |
+
x: torch.Tensor,
|
| 356 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 357 |
+
export_feat_layers: List[int] = [],
|
| 358 |
+
**kwargs,
|
| 359 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 360 |
+
outputs, aux_outputs = self._get_intermediate_layers_not_chunked(
|
| 361 |
+
x, n, export_feat_layers=export_feat_layers, **kwargs
|
| 362 |
+
)
|
| 363 |
+
camera_tokens = [out[0] for out in outputs]
|
| 364 |
+
if outputs[0][1].shape[-1] == self.embed_dim:
|
| 365 |
+
outputs = [self.norm(out[1]) for out in outputs]
|
| 366 |
+
elif outputs[0][1].shape[-1] == (self.embed_dim * 2):
|
| 367 |
+
outputs = [
|
| 368 |
+
torch.cat(
|
| 369 |
+
[out[1][..., : self.embed_dim], self.norm(out[1][..., self.embed_dim :])],
|
| 370 |
+
dim=-1,
|
| 371 |
+
)
|
| 372 |
+
for out in outputs
|
| 373 |
+
]
|
| 374 |
+
else:
|
| 375 |
+
raise ValueError(f"Invalid output shape: {outputs[0][1].shape}")
|
| 376 |
+
aux_outputs = [self.norm(out) for out in aux_outputs]
|
| 377 |
+
outputs = [out[..., 1 + self.num_register_tokens :, :] for out in outputs]
|
| 378 |
+
aux_outputs = [out[..., 1 + self.num_register_tokens :, :] for out in aux_outputs]
|
| 379 |
+
return tuple(zip(outputs, camera_tokens)), aux_outputs
|
| 380 |
+
|
| 381 |
+
|
| 382 |
+
def vit_small(patch_size=16, num_register_tokens=0, depth=12, **kwargs):
|
| 383 |
+
model = DinoVisionTransformer(
|
| 384 |
+
patch_size=patch_size,
|
| 385 |
+
embed_dim=384,
|
| 386 |
+
depth=depth,
|
| 387 |
+
num_heads=6,
|
| 388 |
+
mlp_ratio=4,
|
| 389 |
+
# block_fn=partial(Block, attn_class=MemEffAttention),
|
| 390 |
+
num_register_tokens=num_register_tokens,
|
| 391 |
+
**kwargs,
|
| 392 |
+
)
|
| 393 |
+
return model
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
def vit_base(patch_size=16, num_register_tokens=0, depth=12, **kwargs):
|
| 397 |
+
model = DinoVisionTransformer(
|
| 398 |
+
patch_size=patch_size,
|
| 399 |
+
embed_dim=768,
|
| 400 |
+
depth=depth,
|
| 401 |
+
num_heads=12,
|
| 402 |
+
mlp_ratio=4,
|
| 403 |
+
# block_fn=partial(Block, attn_class=MemEffAttention),
|
| 404 |
+
num_register_tokens=num_register_tokens,
|
| 405 |
+
**kwargs,
|
| 406 |
+
)
|
| 407 |
+
return model
|
| 408 |
+
|
| 409 |
+
|
| 410 |
+
def vit_large(patch_size=16, num_register_tokens=0, depth=24, **kwargs):
|
| 411 |
+
model = DinoVisionTransformer(
|
| 412 |
+
patch_size=patch_size,
|
| 413 |
+
embed_dim=1024,
|
| 414 |
+
depth=depth,
|
| 415 |
+
num_heads=16,
|
| 416 |
+
mlp_ratio=4,
|
| 417 |
+
# block_fn=partial(Block, attn_class=MemEffAttention),
|
| 418 |
+
num_register_tokens=num_register_tokens,
|
| 419 |
+
**kwargs,
|
| 420 |
+
)
|
| 421 |
+
return model
|
| 422 |
+
|
| 423 |
+
|
| 424 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, depth=40, **kwargs):
|
| 425 |
+
"""
|
| 426 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 427 |
+
"""
|
| 428 |
+
model = DinoVisionTransformer(
|
| 429 |
+
patch_size=patch_size,
|
| 430 |
+
embed_dim=1536,
|
| 431 |
+
depth=depth,
|
| 432 |
+
num_heads=24,
|
| 433 |
+
mlp_ratio=4,
|
| 434 |
+
num_register_tokens=num_register_tokens,
|
| 435 |
+
**kwargs,
|
| 436 |
+
)
|
| 437 |
+
return model
|
src/depth_anything_3/model/dpt.py
ADDED
|
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa E501
|
| 2 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Dict as TyDict
|
| 17 |
+
from typing import List, Sequence, Tuple
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn as nn
|
| 20 |
+
from addict import Dict
|
| 21 |
+
from einops import rearrange
|
| 22 |
+
|
| 23 |
+
from depth_anything_3.model.utils.head_utils import (
|
| 24 |
+
Permute,
|
| 25 |
+
create_uv_grid,
|
| 26 |
+
custom_interpolate,
|
| 27 |
+
position_grid_to_embed,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class DPT(nn.Module):
|
| 32 |
+
"""
|
| 33 |
+
DPT for dense prediction (main head + optional sky head, sky always 1 channel).
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
- Main head:
|
| 37 |
+
* If output_dim>1: { head_name, f"{head_name}_conf" }
|
| 38 |
+
* If output_dim==1: { head_name }
|
| 39 |
+
- Sky head (if use_sky_head=True): { sky_name } # [B, S, 1, H/down_ratio, W/down_ratio]
|
| 40 |
+
"""
|
| 41 |
+
|
| 42 |
+
def __init__(
|
| 43 |
+
self,
|
| 44 |
+
dim_in: int,
|
| 45 |
+
*,
|
| 46 |
+
patch_size: int = 14,
|
| 47 |
+
output_dim: int = 1,
|
| 48 |
+
activation: str = "exp",
|
| 49 |
+
conf_activation: str = "expp1",
|
| 50 |
+
features: int = 256,
|
| 51 |
+
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
| 52 |
+
pos_embed: bool = False,
|
| 53 |
+
down_ratio: int = 1,
|
| 54 |
+
head_name: str = "depth",
|
| 55 |
+
# ---- sky head (fixed 1 channel) ----
|
| 56 |
+
use_sky_head: bool = True,
|
| 57 |
+
sky_name: str = "sky",
|
| 58 |
+
sky_activation: str = "relu", # 'sigmoid' / 'relu' / 'linear'
|
| 59 |
+
use_ln_for_heads: bool = False, # If needed, apply LayerNorm on intermediate features of both heads
|
| 60 |
+
norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer"
|
| 61 |
+
fusion_block_inplace: bool = False,
|
| 62 |
+
) -> None:
|
| 63 |
+
super().__init__()
|
| 64 |
+
|
| 65 |
+
# -------------------- configuration --------------------
|
| 66 |
+
self.patch_size = patch_size
|
| 67 |
+
self.activation = activation
|
| 68 |
+
self.conf_activation = conf_activation
|
| 69 |
+
self.pos_embed = pos_embed
|
| 70 |
+
self.down_ratio = down_ratio
|
| 71 |
+
|
| 72 |
+
# Names
|
| 73 |
+
self.head_main = head_name
|
| 74 |
+
self.sky_name = sky_name
|
| 75 |
+
|
| 76 |
+
# Main head: output dimension and confidence switch
|
| 77 |
+
self.out_dim = output_dim
|
| 78 |
+
self.has_conf = output_dim > 1
|
| 79 |
+
|
| 80 |
+
# Sky head parameters (always 1 channel)
|
| 81 |
+
self.use_sky_head = use_sky_head
|
| 82 |
+
self.sky_activation = sky_activation
|
| 83 |
+
|
| 84 |
+
# Fixed 4 intermediate outputs
|
| 85 |
+
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
| 86 |
+
|
| 87 |
+
# -------------------- token pre-norm + per-stage projection --------------------
|
| 88 |
+
if norm_type == "layer":
|
| 89 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 90 |
+
elif norm_type == "idt":
|
| 91 |
+
self.norm = nn.Identity()
|
| 92 |
+
else:
|
| 93 |
+
raise Exception(f"Unknown norm_type {norm_type}, should be 'layer' or 'idt'.")
|
| 94 |
+
self.projects = nn.ModuleList(
|
| 95 |
+
[nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
# -------------------- Spatial re-size (align to common scale before fusion) --------------------
|
| 99 |
+
# Design consistent with original: relative to patch grid (x4, x2, x1, /2)
|
| 100 |
+
self.resize_layers = nn.ModuleList(
|
| 101 |
+
[
|
| 102 |
+
nn.ConvTranspose2d(
|
| 103 |
+
out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
|
| 104 |
+
),
|
| 105 |
+
nn.ConvTranspose2d(
|
| 106 |
+
out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
|
| 107 |
+
),
|
| 108 |
+
nn.Identity(),
|
| 109 |
+
nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
|
| 110 |
+
]
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
# -------------------- scratch: stage adapters + main fusion chain --------------------
|
| 114 |
+
self.scratch = _make_scratch(list(out_channels), features, expand=False)
|
| 115 |
+
|
| 116 |
+
# Main fusion chain
|
| 117 |
+
self.scratch.refinenet1 = _make_fusion_block(features, inplace=fusion_block_inplace)
|
| 118 |
+
self.scratch.refinenet2 = _make_fusion_block(features, inplace=fusion_block_inplace)
|
| 119 |
+
self.scratch.refinenet3 = _make_fusion_block(features, inplace=fusion_block_inplace)
|
| 120 |
+
self.scratch.refinenet4 = _make_fusion_block(
|
| 121 |
+
features, has_residual=False, inplace=fusion_block_inplace
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
# Heads (shared neck1; then split into two heads)
|
| 125 |
+
head_features_1 = features
|
| 126 |
+
head_features_2 = 32
|
| 127 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 128 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
ln_seq = (
|
| 132 |
+
[Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
|
| 133 |
+
if use_ln_for_heads
|
| 134 |
+
else []
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
# Main head
|
| 138 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 139 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 140 |
+
*ln_seq,
|
| 141 |
+
nn.ReLU(inplace=True),
|
| 142 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
# Sky head (fixed 1 channel)
|
| 146 |
+
if self.use_sky_head:
|
| 147 |
+
self.scratch.sky_output_conv2 = nn.Sequential(
|
| 148 |
+
nn.Conv2d(
|
| 149 |
+
head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
|
| 150 |
+
),
|
| 151 |
+
*ln_seq,
|
| 152 |
+
nn.ReLU(inplace=True),
|
| 153 |
+
nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
# -------------------------------------------------------------------------
|
| 157 |
+
# Public forward (supports frame chunking to save memory)
|
| 158 |
+
# -------------------------------------------------------------------------
|
| 159 |
+
def forward(
|
| 160 |
+
self,
|
| 161 |
+
feats: List[torch.Tensor],
|
| 162 |
+
H: int,
|
| 163 |
+
W: int,
|
| 164 |
+
patch_start_idx: int,
|
| 165 |
+
chunk_size: int = 8,
|
| 166 |
+
**kwargs,
|
| 167 |
+
) -> Dict:
|
| 168 |
+
"""
|
| 169 |
+
Args:
|
| 170 |
+
feats: List of 4 entries, each entry is a tensor like [B, S, T, C] (or the 0th element of tuple/list is that tensor).
|
| 171 |
+
H, W: Original image dimensions
|
| 172 |
+
patch_start_idx: Starting index of patch tokens in sequence (for cropping non-patch tokens)
|
| 173 |
+
chunk_size: Chunk size along time dimension S
|
| 174 |
+
|
| 175 |
+
Returns:
|
| 176 |
+
Dict[str, Tensor]
|
| 177 |
+
"""
|
| 178 |
+
B, S, N, C = feats[0][0].shape
|
| 179 |
+
feats = [feat[0].reshape(B * S, N, C) for feat in feats]
|
| 180 |
+
|
| 181 |
+
# update image info, used by the GS-DPT head
|
| 182 |
+
extra_kwargs = {}
|
| 183 |
+
if "images" in kwargs:
|
| 184 |
+
extra_kwargs.update({"images": rearrange(kwargs["images"], "B S ... -> (B S) ...")})
|
| 185 |
+
|
| 186 |
+
if chunk_size is None or chunk_size >= S:
|
| 187 |
+
out_dict = self._forward_impl(feats, H, W, patch_start_idx, **extra_kwargs)
|
| 188 |
+
out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
|
| 189 |
+
return Dict(out_dict)
|
| 190 |
+
|
| 191 |
+
out_dicts: List[TyDict[str, torch.Tensor]] = []
|
| 192 |
+
for s0 in range(0, S, chunk_size):
|
| 193 |
+
s1 = min(s0 + chunk_size, S)
|
| 194 |
+
kw = {}
|
| 195 |
+
if "images" in extra_kwargs:
|
| 196 |
+
kw.update({"images": extra_kwargs["images"][s0:s1]})
|
| 197 |
+
out_dicts.append(
|
| 198 |
+
self._forward_impl([f[s0:s1] for f in feats], H, W, patch_start_idx, **kw)
|
| 199 |
+
)
|
| 200 |
+
out_dict = {k: torch.cat([od[k] for od in out_dicts], dim=0) for k in out_dicts[0].keys()}
|
| 201 |
+
out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
|
| 202 |
+
return Dict(out_dict)
|
| 203 |
+
|
| 204 |
+
# -------------------------------------------------------------------------
|
| 205 |
+
# Internal forward (single chunk)
|
| 206 |
+
# -------------------------------------------------------------------------
|
| 207 |
+
def _forward_impl(
|
| 208 |
+
self,
|
| 209 |
+
feats: List[torch.Tensor],
|
| 210 |
+
H: int,
|
| 211 |
+
W: int,
|
| 212 |
+
patch_start_idx: int,
|
| 213 |
+
) -> TyDict[str, torch.Tensor]:
|
| 214 |
+
B, _, C = feats[0].shape
|
| 215 |
+
ph, pw = H // self.patch_size, W // self.patch_size
|
| 216 |
+
resized_feats = []
|
| 217 |
+
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
| 218 |
+
x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C]
|
| 219 |
+
x = self.norm(x)
|
| 220 |
+
x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw]
|
| 221 |
+
|
| 222 |
+
x = self.projects[stage_idx](x)
|
| 223 |
+
if self.pos_embed:
|
| 224 |
+
x = self._add_pos_embed(x, W, H)
|
| 225 |
+
x = self.resize_layers[stage_idx](x) # Align scale
|
| 226 |
+
resized_feats.append(x)
|
| 227 |
+
|
| 228 |
+
# 2) Fusion pyramid (main branch only)
|
| 229 |
+
fused = self._fuse(resized_feats)
|
| 230 |
+
|
| 231 |
+
# 3) Upsample to target resolution, optionally add position encoding again
|
| 232 |
+
h_out = int(ph * self.patch_size / self.down_ratio)
|
| 233 |
+
w_out = int(pw * self.patch_size / self.down_ratio)
|
| 234 |
+
|
| 235 |
+
fused = self.scratch.output_conv1(fused)
|
| 236 |
+
fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
|
| 237 |
+
if self.pos_embed:
|
| 238 |
+
fused = self._add_pos_embed(fused, W, H)
|
| 239 |
+
|
| 240 |
+
# 4) Shared neck1
|
| 241 |
+
feat = fused
|
| 242 |
+
|
| 243 |
+
# 5) Main head: logits -> activation
|
| 244 |
+
main_logits = self.scratch.output_conv2(feat)
|
| 245 |
+
outs: TyDict[str, torch.Tensor] = {}
|
| 246 |
+
if self.has_conf:
|
| 247 |
+
fmap = main_logits.permute(0, 2, 3, 1)
|
| 248 |
+
pred = self._apply_activation_single(fmap[..., :-1], self.activation)
|
| 249 |
+
conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)
|
| 250 |
+
outs[self.head_main] = pred.squeeze(1)
|
| 251 |
+
outs[f"{self.head_main}_conf"] = conf.squeeze(1)
|
| 252 |
+
else:
|
| 253 |
+
outs[self.head_main] = self._apply_activation_single(
|
| 254 |
+
main_logits, self.activation
|
| 255 |
+
).squeeze(1)
|
| 256 |
+
|
| 257 |
+
# 6) Sky head (fixed 1 channel)
|
| 258 |
+
if self.use_sky_head:
|
| 259 |
+
sky_logits = self.scratch.sky_output_conv2(feat)
|
| 260 |
+
outs[self.sky_name] = self._apply_sky_activation(sky_logits).squeeze(1)
|
| 261 |
+
|
| 262 |
+
return outs
|
| 263 |
+
|
| 264 |
+
# -------------------------------------------------------------------------
|
| 265 |
+
# Subroutines
|
| 266 |
+
# -------------------------------------------------------------------------
|
| 267 |
+
def _fuse(self, feats: List[torch.Tensor]) -> torch.Tensor:
|
| 268 |
+
"""
|
| 269 |
+
4-layer top-down fusion, returns finest scale features (after fusion, before neck1).
|
| 270 |
+
"""
|
| 271 |
+
l1, l2, l3, l4 = feats
|
| 272 |
+
|
| 273 |
+
l1_rn = self.scratch.layer1_rn(l1)
|
| 274 |
+
l2_rn = self.scratch.layer2_rn(l2)
|
| 275 |
+
l3_rn = self.scratch.layer3_rn(l3)
|
| 276 |
+
l4_rn = self.scratch.layer4_rn(l4)
|
| 277 |
+
|
| 278 |
+
# 4 -> 3 -> 2 -> 1
|
| 279 |
+
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
| 280 |
+
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
|
| 281 |
+
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
|
| 282 |
+
out = self.scratch.refinenet1(out, l1_rn)
|
| 283 |
+
return out
|
| 284 |
+
|
| 285 |
+
def _apply_activation_single(
|
| 286 |
+
self, x: torch.Tensor, activation: str = "linear"
|
| 287 |
+
) -> torch.Tensor:
|
| 288 |
+
"""
|
| 289 |
+
Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case.
|
| 290 |
+
Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1
|
| 291 |
+
"""
|
| 292 |
+
act = activation.lower() if isinstance(activation, str) else activation
|
| 293 |
+
if act == "exp":
|
| 294 |
+
return torch.exp(x)
|
| 295 |
+
if act == "expp1":
|
| 296 |
+
return torch.exp(x) + 1
|
| 297 |
+
if act == "expm1":
|
| 298 |
+
return torch.expm1(x)
|
| 299 |
+
if act == "relu":
|
| 300 |
+
return torch.relu(x)
|
| 301 |
+
if act == "sigmoid":
|
| 302 |
+
return torch.sigmoid(x)
|
| 303 |
+
if act == "softplus":
|
| 304 |
+
return torch.nn.functional.softplus(x)
|
| 305 |
+
if act == "tanh":
|
| 306 |
+
return torch.tanh(x)
|
| 307 |
+
# Default linear
|
| 308 |
+
return x
|
| 309 |
+
|
| 310 |
+
def _apply_sky_activation(self, x: torch.Tensor) -> torch.Tensor:
|
| 311 |
+
"""
|
| 312 |
+
Sky head activation (fixed 1 channel):
|
| 313 |
+
* 'sigmoid' -> Sigmoid probability map
|
| 314 |
+
* 'relu' -> ReLU positive domain output
|
| 315 |
+
* 'linear' -> Original value (logits)
|
| 316 |
+
"""
|
| 317 |
+
act = (
|
| 318 |
+
self.sky_activation.lower()
|
| 319 |
+
if isinstance(self.sky_activation, str)
|
| 320 |
+
else self.sky_activation
|
| 321 |
+
)
|
| 322 |
+
if act == "sigmoid":
|
| 323 |
+
return torch.sigmoid(x)
|
| 324 |
+
if act == "relu":
|
| 325 |
+
return torch.relu(x)
|
| 326 |
+
# 'linear'
|
| 327 |
+
return x
|
| 328 |
+
|
| 329 |
+
def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 330 |
+
"""Simple UV position encoding directly added to feature map."""
|
| 331 |
+
pw, ph = x.shape[-1], x.shape[-2]
|
| 332 |
+
pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 333 |
+
pe = position_grid_to_embed(pe, x.shape[1]) * ratio
|
| 334 |
+
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 335 |
+
return x + pe
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
# -----------------------------------------------------------------------------
|
| 339 |
+
# Building blocks (preserved, consistent with original)
|
| 340 |
+
# -----------------------------------------------------------------------------
|
| 341 |
+
def _make_fusion_block(
|
| 342 |
+
features: int,
|
| 343 |
+
size: Tuple[int, int] = None,
|
| 344 |
+
has_residual: bool = True,
|
| 345 |
+
groups: int = 1,
|
| 346 |
+
inplace: bool = False,
|
| 347 |
+
) -> nn.Module:
|
| 348 |
+
return FeatureFusionBlock(
|
| 349 |
+
features=features,
|
| 350 |
+
activation=nn.ReLU(inplace=inplace),
|
| 351 |
+
deconv=False,
|
| 352 |
+
bn=False,
|
| 353 |
+
expand=False,
|
| 354 |
+
align_corners=True,
|
| 355 |
+
size=size,
|
| 356 |
+
has_residual=has_residual,
|
| 357 |
+
groups=groups,
|
| 358 |
+
)
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def _make_scratch(
|
| 362 |
+
in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
|
| 363 |
+
) -> nn.Module:
|
| 364 |
+
scratch = nn.Module()
|
| 365 |
+
# Optional expansion by stage
|
| 366 |
+
c1 = out_shape
|
| 367 |
+
c2 = out_shape * (2 if expand else 1)
|
| 368 |
+
c3 = out_shape * (4 if expand else 1)
|
| 369 |
+
c4 = out_shape * (8 if expand else 1)
|
| 370 |
+
|
| 371 |
+
scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups)
|
| 372 |
+
scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups)
|
| 373 |
+
scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups)
|
| 374 |
+
scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups)
|
| 375 |
+
return scratch
|
| 376 |
+
|
| 377 |
+
|
| 378 |
+
class ResidualConvUnit(nn.Module):
|
| 379 |
+
"""Lightweight residual convolution block for fusion"""
|
| 380 |
+
|
| 381 |
+
def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None:
|
| 382 |
+
super().__init__()
|
| 383 |
+
self.bn = bn
|
| 384 |
+
self.groups = groups
|
| 385 |
+
self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
|
| 386 |
+
self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
|
| 387 |
+
self.norm1 = None
|
| 388 |
+
self.norm2 = None
|
| 389 |
+
self.activation = activation
|
| 390 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 391 |
+
|
| 392 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
|
| 393 |
+
out = self.activation(x)
|
| 394 |
+
out = self.conv1(out)
|
| 395 |
+
if self.norm1 is not None:
|
| 396 |
+
out = self.norm1(out)
|
| 397 |
+
|
| 398 |
+
out = self.activation(out)
|
| 399 |
+
out = self.conv2(out)
|
| 400 |
+
if self.norm2 is not None:
|
| 401 |
+
out = self.norm2(out)
|
| 402 |
+
|
| 403 |
+
return self.skip_add.add(out, x)
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
class FeatureFusionBlock(nn.Module):
|
| 407 |
+
"""Top-down fusion block: (optional) residual merge + upsampling + 1x1 contraction"""
|
| 408 |
+
|
| 409 |
+
def __init__(
|
| 410 |
+
self,
|
| 411 |
+
features: int,
|
| 412 |
+
activation: nn.Module,
|
| 413 |
+
deconv: bool = False,
|
| 414 |
+
bn: bool = False,
|
| 415 |
+
expand: bool = False,
|
| 416 |
+
align_corners: bool = True,
|
| 417 |
+
size: Tuple[int, int] = None,
|
| 418 |
+
has_residual: bool = True,
|
| 419 |
+
groups: int = 1,
|
| 420 |
+
) -> None:
|
| 421 |
+
super().__init__()
|
| 422 |
+
self.align_corners = align_corners
|
| 423 |
+
self.size = size
|
| 424 |
+
self.has_residual = has_residual
|
| 425 |
+
|
| 426 |
+
self.resConfUnit1 = (
|
| 427 |
+
ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None
|
| 428 |
+
)
|
| 429 |
+
self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)
|
| 430 |
+
|
| 431 |
+
out_features = (features // 2) if expand else features
|
| 432 |
+
self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups)
|
| 433 |
+
self.skip_add = nn.quantized.FloatFunctional()
|
| 434 |
+
|
| 435 |
+
def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override]
|
| 436 |
+
"""
|
| 437 |
+
xs:
|
| 438 |
+
- xs[0]: Top branch input
|
| 439 |
+
- xs[1]: Lateral input (can do residual addition with top branch)
|
| 440 |
+
"""
|
| 441 |
+
y = xs[0]
|
| 442 |
+
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
|
| 443 |
+
y = self.skip_add.add(y, self.resConfUnit1(xs[1]))
|
| 444 |
+
|
| 445 |
+
y = self.resConfUnit2(y)
|
| 446 |
+
|
| 447 |
+
# Upsampling
|
| 448 |
+
if (size is None) and (self.size is None):
|
| 449 |
+
up_kwargs = {"scale_factor": 2}
|
| 450 |
+
elif size is None:
|
| 451 |
+
up_kwargs = {"size": self.size}
|
| 452 |
+
else:
|
| 453 |
+
up_kwargs = {"size": size}
|
| 454 |
+
|
| 455 |
+
y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
|
| 456 |
+
y = self.out_conv(y)
|
| 457 |
+
return y
|
src/depth_anything_3/model/dualdpt.py
ADDED
|
@@ -0,0 +1,488 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# flake8: noqa E501
|
| 2 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import List, Sequence, Tuple
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from addict import Dict
|
| 20 |
+
|
| 21 |
+
from depth_anything_3.model.dpt import _make_fusion_block, _make_scratch
|
| 22 |
+
from depth_anything_3.model.utils.head_utils import (
|
| 23 |
+
Permute,
|
| 24 |
+
create_uv_grid,
|
| 25 |
+
custom_interpolate,
|
| 26 |
+
position_grid_to_embed,
|
| 27 |
+
)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class DualDPT(nn.Module):
|
| 31 |
+
"""
|
| 32 |
+
Dual-head DPT for dense prediction with an always-on auxiliary head.
|
| 33 |
+
|
| 34 |
+
Architectural notes:
|
| 35 |
+
- Sky/object branches are removed.
|
| 36 |
+
- `intermediate_layer_idx` is fixed to (0, 1, 2, 3).
|
| 37 |
+
- Auxiliary head has its **own** fusion blocks (no fusion_inplace / no sharing).
|
| 38 |
+
- Auxiliary head is internally multi-level; **only the final level** is returned.
|
| 39 |
+
- Returns a **dict** with keys from `head_names`, e.g.:
|
| 40 |
+
{ main_name, f"{main_name}_conf", aux_name, f"{aux_name}_conf" }
|
| 41 |
+
- `feature_only` is fixed to False.
|
| 42 |
+
"""
|
| 43 |
+
|
| 44 |
+
def __init__(
|
| 45 |
+
self,
|
| 46 |
+
dim_in: int,
|
| 47 |
+
*,
|
| 48 |
+
patch_size: int = 14,
|
| 49 |
+
output_dim: int = 2,
|
| 50 |
+
activation: str = "exp",
|
| 51 |
+
conf_activation: str = "expp1",
|
| 52 |
+
features: int = 256,
|
| 53 |
+
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
| 54 |
+
pos_embed: bool = True,
|
| 55 |
+
down_ratio: int = 1,
|
| 56 |
+
aux_pyramid_levels: int = 4,
|
| 57 |
+
aux_out1_conv_num: int = 5,
|
| 58 |
+
head_names: Tuple[str, str] = ("depth", "ray"),
|
| 59 |
+
) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
|
| 62 |
+
# -------------------- configuration --------------------
|
| 63 |
+
self.patch_size = patch_size
|
| 64 |
+
self.activation = activation
|
| 65 |
+
self.conf_activation = conf_activation
|
| 66 |
+
self.pos_embed = pos_embed
|
| 67 |
+
self.down_ratio = down_ratio
|
| 68 |
+
|
| 69 |
+
self.aux_levels = aux_pyramid_levels
|
| 70 |
+
self.aux_out1_conv_num = aux_out1_conv_num
|
| 71 |
+
|
| 72 |
+
# names ONLY come from config (no hard-coded strings elsewhere)
|
| 73 |
+
self.head_main, self.head_aux = head_names
|
| 74 |
+
|
| 75 |
+
# Always expect 4 scales; enforce intermediate idx = (0, 1, 2, 3)
|
| 76 |
+
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
| 77 |
+
|
| 78 |
+
# -------------------- token pre-norm + per-stage projection --------------------
|
| 79 |
+
self.norm = nn.LayerNorm(dim_in)
|
| 80 |
+
self.projects = nn.ModuleList(
|
| 81 |
+
[nn.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0) for oc in out_channels]
|
| 82 |
+
)
|
| 83 |
+
|
| 84 |
+
# -------------------- spatial re-sizers (align to common scale before fusion) --------------------
|
| 85 |
+
# design: stage strides (x4, x2, x1, /2) relative to patch grid to align to a common pivot scale
|
| 86 |
+
self.resize_layers = nn.ModuleList(
|
| 87 |
+
[
|
| 88 |
+
nn.ConvTranspose2d(
|
| 89 |
+
out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0
|
| 90 |
+
),
|
| 91 |
+
nn.ConvTranspose2d(
|
| 92 |
+
out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0
|
| 93 |
+
),
|
| 94 |
+
nn.Identity(),
|
| 95 |
+
nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
|
| 96 |
+
]
|
| 97 |
+
)
|
| 98 |
+
|
| 99 |
+
# -------------------- scratch: stage adapters + fusion (main & aux are separate) --------------------
|
| 100 |
+
self.scratch = _make_scratch(list(out_channels), features, expand=False)
|
| 101 |
+
|
| 102 |
+
# Main fusion chain (independent)
|
| 103 |
+
self.scratch.refinenet1 = _make_fusion_block(features)
|
| 104 |
+
self.scratch.refinenet2 = _make_fusion_block(features)
|
| 105 |
+
self.scratch.refinenet3 = _make_fusion_block(features)
|
| 106 |
+
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False)
|
| 107 |
+
|
| 108 |
+
# Primary head neck + head (independent)
|
| 109 |
+
head_features_1 = features
|
| 110 |
+
head_features_2 = 32
|
| 111 |
+
self.scratch.output_conv1 = nn.Conv2d(
|
| 112 |
+
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1
|
| 113 |
+
)
|
| 114 |
+
self.scratch.output_conv2 = nn.Sequential(
|
| 115 |
+
nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
|
| 116 |
+
nn.ReLU(inplace=True),
|
| 117 |
+
nn.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0),
|
| 118 |
+
)
|
| 119 |
+
|
| 120 |
+
# Auxiliary fusion chain (completely separate; no sharing, i.e., "fusion_inplace=False")
|
| 121 |
+
self.scratch.refinenet1_aux = _make_fusion_block(features)
|
| 122 |
+
self.scratch.refinenet2_aux = _make_fusion_block(features)
|
| 123 |
+
self.scratch.refinenet3_aux = _make_fusion_block(features)
|
| 124 |
+
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False)
|
| 125 |
+
|
| 126 |
+
# Aux pre-head per level (we will only *return final level*)
|
| 127 |
+
self.scratch.output_conv1_aux = nn.ModuleList(
|
| 128 |
+
[self._make_aux_out1_block(head_features_1) for _ in range(self.aux_levels)]
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
# Aux final projection per level
|
| 132 |
+
use_ln = True
|
| 133 |
+
ln_seq = (
|
| 134 |
+
[Permute((0, 2, 3, 1)), nn.LayerNorm(head_features_2), Permute((0, 3, 1, 2))]
|
| 135 |
+
if use_ln
|
| 136 |
+
else []
|
| 137 |
+
)
|
| 138 |
+
self.scratch.output_conv2_aux = nn.ModuleList(
|
| 139 |
+
[
|
| 140 |
+
nn.Sequential(
|
| 141 |
+
nn.Conv2d(
|
| 142 |
+
head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1
|
| 143 |
+
),
|
| 144 |
+
*ln_seq,
|
| 145 |
+
nn.ReLU(inplace=True),
|
| 146 |
+
nn.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0),
|
| 147 |
+
)
|
| 148 |
+
for _ in range(self.aux_levels)
|
| 149 |
+
]
|
| 150 |
+
)
|
| 151 |
+
|
| 152 |
+
# -------------------------------------------------------------------------
|
| 153 |
+
# Public forward (supports frame chunking for memory)
|
| 154 |
+
# -------------------------------------------------------------------------
|
| 155 |
+
|
| 156 |
+
def forward(
|
| 157 |
+
self,
|
| 158 |
+
feats: List[torch.Tensor],
|
| 159 |
+
H: int,
|
| 160 |
+
W: int,
|
| 161 |
+
patch_start_idx: int,
|
| 162 |
+
chunk_size: int = 8,
|
| 163 |
+
) -> Dict[str, torch.Tensor]:
|
| 164 |
+
"""
|
| 165 |
+
Args:
|
| 166 |
+
aggregated_tokens_list: List of 4 tensors [B, S, T, C] from transformer.
|
| 167 |
+
images: [B, S, 3, H, W], in [0, 1].
|
| 168 |
+
patch_start_idx: Patch-token start in the token sequence (to drop non-patch tokens).
|
| 169 |
+
frames_chunk_size: Optional chunking along S for memory.
|
| 170 |
+
|
| 171 |
+
Returns:
|
| 172 |
+
Dict[str, Tensor] with keys based on `head_names`, e.g.:
|
| 173 |
+
self.head_main, f"{self.head_main}_conf",
|
| 174 |
+
self.head_aux, f"{self.head_aux}_conf"
|
| 175 |
+
Shapes:
|
| 176 |
+
main: [B, S, out_dim, H/down_ratio, W/down_ratio]
|
| 177 |
+
main_cf: [B, S, 1, H/down_ratio, W/down_ratio]
|
| 178 |
+
aux: [B, S, 7, H/down_ratio, W/down_ratio]
|
| 179 |
+
aux_cf: [B, S, 1, H/down_ratio, W/down_ratio]
|
| 180 |
+
"""
|
| 181 |
+
B, S, N, C = feats[0][0].shape
|
| 182 |
+
feats = [feat[0].reshape(B * S, N, C) for feat in feats]
|
| 183 |
+
if chunk_size is None or chunk_size >= S:
|
| 184 |
+
out_dict = self._forward_impl(feats, H, W, patch_start_idx)
|
| 185 |
+
out_dict = {k: v.reshape(B, S, *v.shape[1:]) for k, v in out_dict.items()}
|
| 186 |
+
return Dict(out_dict)
|
| 187 |
+
out_dicts = []
|
| 188 |
+
for s0 in range(0, S, chunk_size):
|
| 189 |
+
s1 = min(s0 + chunk_size, S)
|
| 190 |
+
out_dict = self._forward_impl(
|
| 191 |
+
[feat[s0:s1] for feat in feats],
|
| 192 |
+
H,
|
| 193 |
+
W,
|
| 194 |
+
patch_start_idx,
|
| 195 |
+
)
|
| 196 |
+
out_dicts.append(out_dict)
|
| 197 |
+
out_dict = {
|
| 198 |
+
k: torch.cat([out_dict[k] for out_dict in out_dicts], dim=0)
|
| 199 |
+
for k in out_dicts[0].keys()
|
| 200 |
+
}
|
| 201 |
+
out_dict = {k: v.view(B, S, *v.shape[1:]) for k, v in out_dict.items()}
|
| 202 |
+
return Dict(out_dict)
|
| 203 |
+
|
| 204 |
+
# -------------------------------------------------------------------------
|
| 205 |
+
# Internal forward (single chunk)
|
| 206 |
+
# -------------------------------------------------------------------------
|
| 207 |
+
|
| 208 |
+
def _forward_impl(
|
| 209 |
+
self,
|
| 210 |
+
feats: List[torch.Tensor],
|
| 211 |
+
H: int,
|
| 212 |
+
W: int,
|
| 213 |
+
patch_start_idx: int,
|
| 214 |
+
) -> Dict[str, torch.Tensor]:
|
| 215 |
+
B, _, C = feats[0].shape
|
| 216 |
+
ph, pw = H // self.patch_size, W // self.patch_size
|
| 217 |
+
resized_feats = []
|
| 218 |
+
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
| 219 |
+
x = feats[take_idx][:, patch_start_idx:]
|
| 220 |
+
x = self.norm(x)
|
| 221 |
+
x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw]
|
| 222 |
+
|
| 223 |
+
x = self.projects[stage_idx](x)
|
| 224 |
+
if self.pos_embed:
|
| 225 |
+
x = self._add_pos_embed(x, W, H)
|
| 226 |
+
x = self.resize_layers[stage_idx](x) # align scales
|
| 227 |
+
resized_feats.append(x)
|
| 228 |
+
|
| 229 |
+
# 2) Fuse pyramid (main & aux are completely independent)
|
| 230 |
+
fused_main, fused_aux_pyr = self._fuse(resized_feats)
|
| 231 |
+
|
| 232 |
+
# 3) Upsample to target resolution and (optional) add pos-embed again
|
| 233 |
+
h_out = int(ph * self.patch_size / self.down_ratio)
|
| 234 |
+
w_out = int(pw * self.patch_size / self.down_ratio)
|
| 235 |
+
|
| 236 |
+
fused_main = custom_interpolate(
|
| 237 |
+
fused_main, (h_out, w_out), mode="bilinear", align_corners=True
|
| 238 |
+
)
|
| 239 |
+
if self.pos_embed:
|
| 240 |
+
fused_main = self._add_pos_embed(fused_main, W, H)
|
| 241 |
+
|
| 242 |
+
# Primary head: conv1 -> conv2 -> activate
|
| 243 |
+
# fused_main = self.scratch.output_conv1(fused_main)
|
| 244 |
+
main_logits = self.scratch.output_conv2(fused_main)
|
| 245 |
+
fmap = main_logits.permute(0, 2, 3, 1)
|
| 246 |
+
main_pred = self._apply_activation_single(fmap[..., :-1], self.activation)
|
| 247 |
+
main_conf = self._apply_activation_single(fmap[..., -1], self.conf_activation)
|
| 248 |
+
|
| 249 |
+
# Auxiliary head (multi-level inside) -> only last level returned (after activation)
|
| 250 |
+
last_aux = fused_aux_pyr[-1]
|
| 251 |
+
if self.pos_embed:
|
| 252 |
+
last_aux = self._add_pos_embed(last_aux, W, H)
|
| 253 |
+
# neck (per-level pre-conv) then final projection (only for last level)
|
| 254 |
+
# last_aux = self.scratch.output_conv1_aux[-1](last_aux)
|
| 255 |
+
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
|
| 256 |
+
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
|
| 257 |
+
aux_pred = self._apply_activation_single(fmap_last[..., :-1], "linear")
|
| 258 |
+
aux_conf = self._apply_activation_single(fmap_last[..., -1], self.conf_activation)
|
| 259 |
+
return {
|
| 260 |
+
self.head_main: main_pred.squeeze(-1),
|
| 261 |
+
f"{self.head_main}_conf": main_conf,
|
| 262 |
+
self.head_aux: aux_pred,
|
| 263 |
+
f"{self.head_aux}_conf": aux_conf,
|
| 264 |
+
}
|
| 265 |
+
|
| 266 |
+
# -------------------------------------------------------------------------
|
| 267 |
+
# Subroutines
|
| 268 |
+
# -------------------------------------------------------------------------
|
| 269 |
+
|
| 270 |
+
def _fuse(self, feats: List[torch.Tensor]) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
| 271 |
+
"""
|
| 272 |
+
Feature pyramid fusion.
|
| 273 |
+
Returns:
|
| 274 |
+
fused_main: Tensor at finest scale (after refinenet1)
|
| 275 |
+
aux_pyr: List of aux tensors at each level (pre out_conv1_aux)
|
| 276 |
+
"""
|
| 277 |
+
l1, l2, l3, l4 = feats
|
| 278 |
+
|
| 279 |
+
l1_rn = self.scratch.layer1_rn(l1)
|
| 280 |
+
l2_rn = self.scratch.layer2_rn(l2)
|
| 281 |
+
l3_rn = self.scratch.layer3_rn(l3)
|
| 282 |
+
l4_rn = self.scratch.layer4_rn(l4)
|
| 283 |
+
|
| 284 |
+
# level 4 -> 3
|
| 285 |
+
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
| 286 |
+
aux_out = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
|
| 287 |
+
aux_list: List[torch.Tensor] = []
|
| 288 |
+
if self.aux_levels >= 4:
|
| 289 |
+
aux_list.append(aux_out)
|
| 290 |
+
|
| 291 |
+
# level 3 -> 2
|
| 292 |
+
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
|
| 293 |
+
aux_out = self.scratch.refinenet3_aux(aux_out, l3_rn, size=l2_rn.shape[2:])
|
| 294 |
+
if self.aux_levels >= 3:
|
| 295 |
+
aux_list.append(aux_out)
|
| 296 |
+
|
| 297 |
+
# level 2 -> 1
|
| 298 |
+
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
|
| 299 |
+
aux_out = self.scratch.refinenet2_aux(aux_out, l2_rn, size=l1_rn.shape[2:])
|
| 300 |
+
if self.aux_levels >= 2:
|
| 301 |
+
aux_list.append(aux_out)
|
| 302 |
+
|
| 303 |
+
# level 1 (final)
|
| 304 |
+
out = self.scratch.refinenet1(out, l1_rn)
|
| 305 |
+
aux_out = self.scratch.refinenet1_aux(aux_out, l1_rn)
|
| 306 |
+
aux_list.append(aux_out)
|
| 307 |
+
|
| 308 |
+
out = self.scratch.output_conv1(out)
|
| 309 |
+
aux_list = [self.scratch.output_conv1_aux[i](aux) for i, aux in enumerate(aux_list)]
|
| 310 |
+
|
| 311 |
+
return out, aux_list
|
| 312 |
+
|
| 313 |
+
def _add_pos_embed(self, x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
| 314 |
+
"""Simple UV positional embedding added to feature maps."""
|
| 315 |
+
pw, ph = x.shape[-1], x.shape[-2]
|
| 316 |
+
pe = create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
| 317 |
+
pe = position_grid_to_embed(pe, x.shape[1]) * ratio
|
| 318 |
+
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1)
|
| 319 |
+
return x + pe
|
| 320 |
+
|
| 321 |
+
def _make_aux_out1_block(self, in_ch: int) -> nn.Sequential:
|
| 322 |
+
"""Factory for the aux pre-head stack before the final 1x1 projection."""
|
| 323 |
+
if self.aux_out1_conv_num == 5:
|
| 324 |
+
return nn.Sequential(
|
| 325 |
+
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
|
| 326 |
+
nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
|
| 327 |
+
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
|
| 328 |
+
nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
|
| 329 |
+
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
|
| 330 |
+
)
|
| 331 |
+
if self.aux_out1_conv_num == 3:
|
| 332 |
+
return nn.Sequential(
|
| 333 |
+
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
|
| 334 |
+
nn.Conv2d(in_ch // 2, in_ch, 3, 1, 1),
|
| 335 |
+
nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1),
|
| 336 |
+
)
|
| 337 |
+
if self.aux_out1_conv_num == 1:
|
| 338 |
+
return nn.Sequential(nn.Conv2d(in_ch, in_ch // 2, 3, 1, 1))
|
| 339 |
+
raise ValueError(f"aux_out1_conv_num {self.aux_out1_conv_num} not supported")
|
| 340 |
+
|
| 341 |
+
def _apply_activation_single(
|
| 342 |
+
self, x: torch.Tensor, activation: str = "linear"
|
| 343 |
+
) -> torch.Tensor:
|
| 344 |
+
"""
|
| 345 |
+
Apply activation to single channel output, maintaining semantic consistency with value branch in multi-channel case.
|
| 346 |
+
Supports: exp / relu / sigmoid / softplus / tanh / linear / expp1
|
| 347 |
+
"""
|
| 348 |
+
act = activation.lower() if isinstance(activation, str) else activation
|
| 349 |
+
if act == "exp":
|
| 350 |
+
return torch.exp(x)
|
| 351 |
+
if act == "expm1":
|
| 352 |
+
return torch.expm1(x)
|
| 353 |
+
if act == "expp1":
|
| 354 |
+
return torch.exp(x) + 1
|
| 355 |
+
if act == "relu":
|
| 356 |
+
return torch.relu(x)
|
| 357 |
+
if act == "sigmoid":
|
| 358 |
+
return torch.sigmoid(x)
|
| 359 |
+
if act == "softplus":
|
| 360 |
+
return torch.nn.functional.softplus(x)
|
| 361 |
+
if act == "tanh":
|
| 362 |
+
return torch.tanh(x)
|
| 363 |
+
# Default linear
|
| 364 |
+
return x
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
# # -----------------------------------------------------------------------------
|
| 368 |
+
# # Building blocks (tidy)
|
| 369 |
+
# # -----------------------------------------------------------------------------
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
# def _make_fusion_block(
|
| 373 |
+
# features: int,
|
| 374 |
+
# size: Tuple[int, int] = None,
|
| 375 |
+
# has_residual: bool = True,
|
| 376 |
+
# groups: int = 1,
|
| 377 |
+
# inplace: bool = False, # <- activation uses inplace=True by default; not related to "fusion_inplace"
|
| 378 |
+
# ) -> nn.Module:
|
| 379 |
+
# return FeatureFusionBlock(
|
| 380 |
+
# features=features,
|
| 381 |
+
# activation=nn.ReLU(inplace=inplace),
|
| 382 |
+
# deconv=False,
|
| 383 |
+
# bn=False,
|
| 384 |
+
# expand=False,
|
| 385 |
+
# align_corners=True,
|
| 386 |
+
# size=size,
|
| 387 |
+
# has_residual=has_residual,
|
| 388 |
+
# groups=groups,
|
| 389 |
+
# )
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
# def _make_scratch(
|
| 393 |
+
# in_shape: List[int], out_shape: int, groups: int = 1, expand: bool = False
|
| 394 |
+
# ) -> nn.Module:
|
| 395 |
+
# scratch = nn.Module()
|
| 396 |
+
# # optionally expand widths by stage
|
| 397 |
+
# c1 = out_shape
|
| 398 |
+
# c2 = out_shape * (2 if expand else 1)
|
| 399 |
+
# c3 = out_shape * (4 if expand else 1)
|
| 400 |
+
# c4 = out_shape * (8 if expand else 1)
|
| 401 |
+
|
| 402 |
+
# scratch.layer1_rn = nn.Conv2d(in_shape[0], c1, 3, 1, 1, bias=False, groups=groups)
|
| 403 |
+
# scratch.layer2_rn = nn.Conv2d(in_shape[1], c2, 3, 1, 1, bias=False, groups=groups)
|
| 404 |
+
# scratch.layer3_rn = nn.Conv2d(in_shape[2], c3, 3, 1, 1, bias=False, groups=groups)
|
| 405 |
+
# scratch.layer4_rn = nn.Conv2d(in_shape[3], c4, 3, 1, 1, bias=False, groups=groups)
|
| 406 |
+
# return scratch
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
# class ResidualConvUnit(nn.Module):
|
| 410 |
+
# """Lightweight residual conv block used within fusion."""
|
| 411 |
+
|
| 412 |
+
# def __init__(self, features: int, activation: nn.Module, bn: bool, groups: int = 1) -> None:
|
| 413 |
+
# super().__init__()
|
| 414 |
+
# self.bn = bn
|
| 415 |
+
# self.groups = groups
|
| 416 |
+
# self.conv1 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
|
| 417 |
+
# self.conv2 = nn.Conv2d(features, features, 3, 1, 1, bias=True, groups=groups)
|
| 418 |
+
# self.norm1 = None
|
| 419 |
+
# self.norm2 = None
|
| 420 |
+
# self.activation = activation
|
| 421 |
+
# self.skip_add = nn.quantized.FloatFunctional()
|
| 422 |
+
|
| 423 |
+
# def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
|
| 424 |
+
# out = self.activation(x)
|
| 425 |
+
# out = self.conv1(out)
|
| 426 |
+
# if self.norm1 is not None:
|
| 427 |
+
# out = self.norm1(out)
|
| 428 |
+
|
| 429 |
+
# out = self.activation(out)
|
| 430 |
+
# out = self.conv2(out)
|
| 431 |
+
# if self.norm2 is not None:
|
| 432 |
+
# out = self.norm2(out)
|
| 433 |
+
|
| 434 |
+
# return self.skip_add.add(out, x)
|
| 435 |
+
|
| 436 |
+
|
| 437 |
+
# class FeatureFusionBlock(nn.Module):
|
| 438 |
+
# """Top-down fusion block: (optional) residual merge + upsample + 1x1 shrink."""
|
| 439 |
+
|
| 440 |
+
# def __init__(
|
| 441 |
+
# self,
|
| 442 |
+
# features: int,
|
| 443 |
+
# activation: nn.Module,
|
| 444 |
+
# deconv: bool = False,
|
| 445 |
+
# bn: bool = False,
|
| 446 |
+
# expand: bool = False,
|
| 447 |
+
# align_corners: bool = True,
|
| 448 |
+
# size: Tuple[int, int] = None,
|
| 449 |
+
# has_residual: bool = True,
|
| 450 |
+
# groups: int = 1,
|
| 451 |
+
# ) -> None:
|
| 452 |
+
# super().__init__()
|
| 453 |
+
# self.align_corners = align_corners
|
| 454 |
+
# self.size = size
|
| 455 |
+
# self.has_residual = has_residual
|
| 456 |
+
|
| 457 |
+
# self.resConfUnit1 = (
|
| 458 |
+
# ResidualConvUnit(features, activation, bn, groups=groups) if has_residual else None
|
| 459 |
+
# )
|
| 460 |
+
# self.resConfUnit2 = ResidualConvUnit(features, activation, bn, groups=groups)
|
| 461 |
+
|
| 462 |
+
# out_features = (features // 2) if expand else features
|
| 463 |
+
# self.out_conv = nn.Conv2d(features, out_features, 1, 1, 0, bias=True, groups=groups)
|
| 464 |
+
# self.skip_add = nn.quantized.FloatFunctional()
|
| 465 |
+
|
| 466 |
+
# def forward(self, *xs: torch.Tensor, size: Tuple[int, int] = None) -> torch.Tensor: # type: ignore[override]
|
| 467 |
+
# """
|
| 468 |
+
# xs:
|
| 469 |
+
# - xs[0]: top input
|
| 470 |
+
# - xs[1]: (optional) lateral (to be added with residual)
|
| 471 |
+
# """
|
| 472 |
+
# y = xs[0]
|
| 473 |
+
# if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
|
| 474 |
+
# y = self.skip_add.add(y, self.resConfUnit1(xs[1]))
|
| 475 |
+
|
| 476 |
+
# y = self.resConfUnit2(y)
|
| 477 |
+
|
| 478 |
+
# # upsample
|
| 479 |
+
# if (size is None) and (self.size is None):
|
| 480 |
+
# up_kwargs = {"scale_factor": 2}
|
| 481 |
+
# elif size is None:
|
| 482 |
+
# up_kwargs = {"size": self.size}
|
| 483 |
+
# else:
|
| 484 |
+
# up_kwargs = {"size": size}
|
| 485 |
+
|
| 486 |
+
# y = custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
|
| 487 |
+
# y = self.out_conv(y)
|
| 488 |
+
# return y
|
src/depth_anything_3/model/gs_adapter.py
ADDED
|
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Optional
|
| 16 |
+
import torch
|
| 17 |
+
from einops import einsum, rearrange, repeat
|
| 18 |
+
from torch import nn
|
| 19 |
+
|
| 20 |
+
from depth_anything_3.model.utils.transform import cam_quat_xyzw_to_world_quat_wxyz
|
| 21 |
+
from depth_anything_3.specs import Gaussians
|
| 22 |
+
from depth_anything_3.utils.geometry import affine_inverse, get_world_rays, sample_image_grid
|
| 23 |
+
from depth_anything_3.utils.pose_align import batch_align_poses_umeyama
|
| 24 |
+
from depth_anything_3.utils.sh_helpers import rotate_sh
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GaussianAdapter(nn.Module):
|
| 28 |
+
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
sh_degree: int = 0,
|
| 32 |
+
pred_color: bool = False,
|
| 33 |
+
pred_offset_depth: bool = False,
|
| 34 |
+
pred_offset_xy: bool = True,
|
| 35 |
+
gaussian_scale_min: float = 1e-5,
|
| 36 |
+
gaussian_scale_max: float = 30.0,
|
| 37 |
+
):
|
| 38 |
+
super().__init__()
|
| 39 |
+
self.sh_degree = sh_degree
|
| 40 |
+
self.pred_color = pred_color
|
| 41 |
+
self.pred_offset_depth = pred_offset_depth
|
| 42 |
+
self.pred_offset_xy = pred_offset_xy
|
| 43 |
+
self.gaussian_scale_min = gaussian_scale_min
|
| 44 |
+
self.gaussian_scale_max = gaussian_scale_max
|
| 45 |
+
|
| 46 |
+
# Create a mask for the spherical harmonics coefficients. This ensures that at
|
| 47 |
+
# initialization, the coefficients are biased towards having a large DC
|
| 48 |
+
# component and small view-dependent components.
|
| 49 |
+
if not pred_color:
|
| 50 |
+
self.register_buffer(
|
| 51 |
+
"sh_mask",
|
| 52 |
+
torch.ones((self.d_sh,), dtype=torch.float32),
|
| 53 |
+
persistent=False,
|
| 54 |
+
)
|
| 55 |
+
for degree in range(1, sh_degree + 1):
|
| 56 |
+
self.sh_mask[degree**2 : (degree + 1) ** 2] = 0.1 * 0.25**degree
|
| 57 |
+
|
| 58 |
+
def forward(
|
| 59 |
+
self,
|
| 60 |
+
extrinsics: torch.Tensor, # "*#batch 4 4"
|
| 61 |
+
intrinsics: torch.Tensor, # "*#batch 3 3"
|
| 62 |
+
depths: torch.Tensor, # "*#batch"
|
| 63 |
+
opacities: torch.Tensor, # "*#batch" | "*#batch _"
|
| 64 |
+
raw_gaussians: torch.Tensor, # "*#batch _"
|
| 65 |
+
image_shape: tuple[int, int],
|
| 66 |
+
eps: float = 1e-8,
|
| 67 |
+
gt_extrinsics: Optional[torch.Tensor] = None, # "*#batch 4 4"
|
| 68 |
+
**kwargs,
|
| 69 |
+
) -> Gaussians:
|
| 70 |
+
device = extrinsics.device
|
| 71 |
+
dtype = raw_gaussians.dtype
|
| 72 |
+
H, W = image_shape
|
| 73 |
+
b, v = raw_gaussians.shape[:2]
|
| 74 |
+
|
| 75 |
+
# get cam2worlds and intr_normed to adapt to 3DGS codebase
|
| 76 |
+
cam2worlds = affine_inverse(extrinsics)
|
| 77 |
+
intr_normed = intrinsics.clone().detach()
|
| 78 |
+
intr_normed[..., 0, :] /= W
|
| 79 |
+
intr_normed[..., 1, :] /= H
|
| 80 |
+
|
| 81 |
+
# 1. compute 3DGS means
|
| 82 |
+
# 1.1) offset the predicted depth if needed
|
| 83 |
+
if self.pred_offset_depth:
|
| 84 |
+
gs_depths = depths + raw_gaussians[..., -1]
|
| 85 |
+
raw_gaussians = raw_gaussians[..., :-1]
|
| 86 |
+
else:
|
| 87 |
+
gs_depths = depths
|
| 88 |
+
# 1.2) align predicted poses with GT if needed
|
| 89 |
+
if gt_extrinsics is not None and extrinsics != gt_extrinsics:
|
| 90 |
+
try:
|
| 91 |
+
_, _, pose_scales = batch_align_poses_umeyama(
|
| 92 |
+
gt_extrinsics.detach().float(),
|
| 93 |
+
extrinsics.detach().float(),
|
| 94 |
+
)
|
| 95 |
+
except Exception:
|
| 96 |
+
pose_scales = torch.ones_like(extrinsics[:, 0, 0, 0])
|
| 97 |
+
pose_scales = torch.clamp(pose_scales, min=1 / 3.0, max=3.0)
|
| 98 |
+
cam2worlds[:, :, :3, 3] = cam2worlds[:, :, :3, 3] * rearrange(
|
| 99 |
+
pose_scales, "b -> b () ()"
|
| 100 |
+
)
|
| 101 |
+
gs_depths = gs_depths * rearrange(pose_scales, "b -> b () () () ()")
|
| 102 |
+
# 1.3) casting xy in image space
|
| 103 |
+
xy_ray, _ = sample_image_grid((H, W), device)
|
| 104 |
+
xy_ray = xy_ray[None, None, ...].expand(b, v, -1, -1, -1) # b v h w xy
|
| 105 |
+
# offset xy if needed
|
| 106 |
+
if self.pred_offset_xy:
|
| 107 |
+
pixel_size = 1 / torch.tensor((W, H), dtype=xy_ray.dtype, device=device)
|
| 108 |
+
offset_xy = raw_gaussians[..., :2]
|
| 109 |
+
xy_ray = xy_ray + offset_xy * pixel_size
|
| 110 |
+
raw_gaussians = raw_gaussians[..., 2:] # skip the offset_xy
|
| 111 |
+
# 1.4) unproject depth + xy to world ray
|
| 112 |
+
origins, directions = get_world_rays(
|
| 113 |
+
xy_ray,
|
| 114 |
+
repeat(cam2worlds, "b v i j -> b v h w i j", h=H, w=W),
|
| 115 |
+
repeat(intr_normed, "b v i j -> b v h w i j", h=H, w=W),
|
| 116 |
+
)
|
| 117 |
+
gs_means_world = origins + directions * gs_depths[..., None]
|
| 118 |
+
gs_means_world = rearrange(gs_means_world, "b v h w d -> b (v h w) d")
|
| 119 |
+
|
| 120 |
+
# 2. compute other GS attributes
|
| 121 |
+
scales, rotations, sh = raw_gaussians.split((3, 4, 3 * self.d_sh), dim=-1)
|
| 122 |
+
|
| 123 |
+
# 2.1) 3DGS scales
|
| 124 |
+
# make the scale invarient to resolution
|
| 125 |
+
scale_min = self.gaussian_scale_min
|
| 126 |
+
scale_max = self.gaussian_scale_max
|
| 127 |
+
scales = scale_min + (scale_max - scale_min) * scales.sigmoid()
|
| 128 |
+
pixel_size = 1 / torch.tensor((W, H), dtype=dtype, device=device)
|
| 129 |
+
multiplier = self.get_scale_multiplier(intr_normed, pixel_size)
|
| 130 |
+
gs_scales = scales * gs_depths[..., None] * multiplier[..., None, None, None]
|
| 131 |
+
gs_scales = rearrange(gs_scales, "b v h w d -> b (v h w) d")
|
| 132 |
+
|
| 133 |
+
# 2.2) 3DGS quaternion (world space)
|
| 134 |
+
# due to historical issue, assume quaternion in order xyzw, not wxyz
|
| 135 |
+
# Normalize the quaternion features to yield a valid quaternion.
|
| 136 |
+
rotations = rotations / (rotations.norm(dim=-1, keepdim=True) + eps)
|
| 137 |
+
# rotate them to world space
|
| 138 |
+
cam_quat_xyzw = rearrange(rotations, "b v h w c -> b (v h w) c")
|
| 139 |
+
c2w_mat = repeat(
|
| 140 |
+
cam2worlds,
|
| 141 |
+
"b v i j -> b (v h w) i j",
|
| 142 |
+
h=H,
|
| 143 |
+
w=W,
|
| 144 |
+
)
|
| 145 |
+
world_quat_wxyz = cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w_mat)
|
| 146 |
+
gs_rotations_world = world_quat_wxyz # b (v h w) c
|
| 147 |
+
|
| 148 |
+
# 2.3) 3DGS color / SH coefficient (world space)
|
| 149 |
+
sh = rearrange(sh, "... (xyz d_sh) -> ... xyz d_sh", xyz=3)
|
| 150 |
+
if not self.pred_color:
|
| 151 |
+
sh = sh * self.sh_mask
|
| 152 |
+
|
| 153 |
+
if self.pred_color or self.sh_degree == 0:
|
| 154 |
+
# predict pre-computed color or predict only DC band, no need to transform
|
| 155 |
+
gs_sh_world = sh
|
| 156 |
+
else:
|
| 157 |
+
gs_sh_world = rotate_sh(sh, cam2worlds[:, :, None, None, None, :3, :3])
|
| 158 |
+
gs_sh_world = rearrange(gs_sh_world, "b v h w xyz d_sh -> b (v h w) xyz d_sh")
|
| 159 |
+
|
| 160 |
+
# 2.4) 3DGS opacity
|
| 161 |
+
gs_opacities = rearrange(opacities, "b v h w ... -> b (v h w) ...")
|
| 162 |
+
|
| 163 |
+
return Gaussians(
|
| 164 |
+
means=gs_means_world,
|
| 165 |
+
harmonics=gs_sh_world,
|
| 166 |
+
opacities=gs_opacities,
|
| 167 |
+
scales=gs_scales,
|
| 168 |
+
rotations=gs_rotations_world,
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
+
def get_scale_multiplier(
|
| 172 |
+
self,
|
| 173 |
+
intrinsics: torch.Tensor, # "*#batch 3 3"
|
| 174 |
+
pixel_size: torch.Tensor, # "*#batch 2"
|
| 175 |
+
multiplier: float = 0.1,
|
| 176 |
+
) -> torch.Tensor: # " *batch"
|
| 177 |
+
xy_multipliers = multiplier * einsum(
|
| 178 |
+
intrinsics[..., :2, :2].float().inverse().to(intrinsics),
|
| 179 |
+
pixel_size,
|
| 180 |
+
"... i j, j -> ... i",
|
| 181 |
+
)
|
| 182 |
+
return xy_multipliers.sum(dim=-1)
|
| 183 |
+
|
| 184 |
+
@property
|
| 185 |
+
def d_sh(self) -> int:
|
| 186 |
+
return 1 if self.pred_color else (self.sh_degree + 1) ** 2
|
| 187 |
+
|
| 188 |
+
@property
|
| 189 |
+
def d_in(self) -> int:
|
| 190 |
+
# provided as reference to the gs_dpt output dim
|
| 191 |
+
raw_gs_dim = 0
|
| 192 |
+
if self.pred_offset_xy:
|
| 193 |
+
raw_gs_dim += 2
|
| 194 |
+
raw_gs_dim += 3 # scales
|
| 195 |
+
raw_gs_dim += 4 # quaternion
|
| 196 |
+
raw_gs_dim += 3 * self.d_sh # color
|
| 197 |
+
if self.pred_offset_depth:
|
| 198 |
+
raw_gs_dim += 1
|
| 199 |
+
|
| 200 |
+
return raw_gs_dim
|
src/depth_anything_3/model/gsdpt.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Dict as TyDict
|
| 16 |
+
from typing import List, Sequence
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
from depth_anything_3.model.dpt import DPT
|
| 21 |
+
from depth_anything_3.model.utils.head_utils import activate_head_gs, custom_interpolate
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GSDPT(DPT):
|
| 25 |
+
|
| 26 |
+
def __init__(
|
| 27 |
+
self,
|
| 28 |
+
dim_in: int,
|
| 29 |
+
patch_size: int = 14,
|
| 30 |
+
output_dim: int = 4,
|
| 31 |
+
activation: str = "linear",
|
| 32 |
+
conf_activation: str = "sigmoid",
|
| 33 |
+
features: int = 256,
|
| 34 |
+
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
| 35 |
+
pos_embed: bool = True,
|
| 36 |
+
feature_only: bool = False,
|
| 37 |
+
down_ratio: int = 1,
|
| 38 |
+
conf_dim: int = 1,
|
| 39 |
+
norm_type: str = "idt", # use to match legacy GS-DPT head, "idt" / "layer"
|
| 40 |
+
fusion_block_inplace: bool = False,
|
| 41 |
+
) -> None:
|
| 42 |
+
super().__init__(
|
| 43 |
+
dim_in=dim_in,
|
| 44 |
+
patch_size=patch_size,
|
| 45 |
+
output_dim=output_dim,
|
| 46 |
+
activation=activation,
|
| 47 |
+
conf_activation=conf_activation,
|
| 48 |
+
features=features,
|
| 49 |
+
out_channels=out_channels,
|
| 50 |
+
pos_embed=pos_embed,
|
| 51 |
+
down_ratio=down_ratio,
|
| 52 |
+
head_name="raw_gs",
|
| 53 |
+
use_sky_head=False,
|
| 54 |
+
norm_type=norm_type,
|
| 55 |
+
fusion_block_inplace=fusion_block_inplace,
|
| 56 |
+
)
|
| 57 |
+
self.conf_dim = conf_dim
|
| 58 |
+
if conf_dim and conf_dim > 1:
|
| 59 |
+
assert (
|
| 60 |
+
conf_activation == "linear"
|
| 61 |
+
), "use linear prediction when using view-dependent opacity"
|
| 62 |
+
|
| 63 |
+
merger_out_dim = features if feature_only else features // 2
|
| 64 |
+
self.images_merger = nn.Sequential(
|
| 65 |
+
nn.Conv2d(3, merger_out_dim // 4, 3, 1, 1), # fewer channels first
|
| 66 |
+
nn.GELU(),
|
| 67 |
+
nn.Conv2d(merger_out_dim // 4, merger_out_dim // 2, 3, 1, 1),
|
| 68 |
+
nn.GELU(),
|
| 69 |
+
nn.Conv2d(merger_out_dim // 2, merger_out_dim, 3, 1, 1),
|
| 70 |
+
nn.GELU(),
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# -------------------------------------------------------------------------
|
| 74 |
+
# Internal forward (single chunk)
|
| 75 |
+
# -------------------------------------------------------------------------
|
| 76 |
+
def _forward_impl(
|
| 77 |
+
self,
|
| 78 |
+
feats: List[torch.Tensor],
|
| 79 |
+
H: int,
|
| 80 |
+
W: int,
|
| 81 |
+
patch_start_idx: int,
|
| 82 |
+
images: torch.Tensor,
|
| 83 |
+
) -> TyDict[str, torch.Tensor]:
|
| 84 |
+
B, _, C = feats[0].shape
|
| 85 |
+
ph, pw = H // self.patch_size, W // self.patch_size
|
| 86 |
+
resized_feats = []
|
| 87 |
+
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
| 88 |
+
x = feats[take_idx][:, patch_start_idx:] # [B*S, N_patch, C]
|
| 89 |
+
x = self.norm(x)
|
| 90 |
+
x = x.permute(0, 2, 1).reshape(B, C, ph, pw) # [B*S, C, ph, pw]
|
| 91 |
+
|
| 92 |
+
x = self.projects[stage_idx](x)
|
| 93 |
+
if self.pos_embed:
|
| 94 |
+
x = self._add_pos_embed(x, W, H)
|
| 95 |
+
x = self.resize_layers[stage_idx](x) # Align scale
|
| 96 |
+
resized_feats.append(x)
|
| 97 |
+
|
| 98 |
+
# 2) Fusion pyramid (main branch only)
|
| 99 |
+
fused = self._fuse(resized_feats)
|
| 100 |
+
fused = self.scratch.output_conv1(fused)
|
| 101 |
+
|
| 102 |
+
# 3) Upsample to target resolution, optionally add position encoding again
|
| 103 |
+
h_out = int(ph * self.patch_size / self.down_ratio)
|
| 104 |
+
w_out = int(pw * self.patch_size / self.down_ratio)
|
| 105 |
+
|
| 106 |
+
fused = custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
|
| 107 |
+
|
| 108 |
+
# inject the image information here
|
| 109 |
+
fused = fused + self.images_merger(images)
|
| 110 |
+
|
| 111 |
+
if self.pos_embed:
|
| 112 |
+
fused = self._add_pos_embed(fused, W, H)
|
| 113 |
+
|
| 114 |
+
# 4) Shared neck1
|
| 115 |
+
# feat = self.scratch.output_conv1(fused)
|
| 116 |
+
feat = fused
|
| 117 |
+
|
| 118 |
+
# 5) Main head: logits -> activate_head or single channel activation
|
| 119 |
+
main_logits = self.scratch.output_conv2(feat)
|
| 120 |
+
outs: TyDict[str, torch.Tensor] = {}
|
| 121 |
+
if self.has_conf:
|
| 122 |
+
pred, conf = activate_head_gs(
|
| 123 |
+
main_logits,
|
| 124 |
+
activation=self.activation,
|
| 125 |
+
conf_activation=self.conf_activation,
|
| 126 |
+
conf_dim=self.conf_dim,
|
| 127 |
+
)
|
| 128 |
+
outs[self.head_main] = pred.squeeze(1)
|
| 129 |
+
outs[f"{self.head_main}_conf"] = conf.squeeze(1)
|
| 130 |
+
else:
|
| 131 |
+
outs[self.head_main] = self._apply_activation_single(main_logits).squeeze(1)
|
| 132 |
+
|
| 133 |
+
return outs
|
src/depth_anything_3/model/utils/attention.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 # noqa
|
| 16 |
+
|
| 17 |
+
from typing import Callable, Optional, Union
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
from torch import Tensor, nn
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class Attention(nn.Module):
|
| 24 |
+
def __init__(
|
| 25 |
+
self,
|
| 26 |
+
dim: int,
|
| 27 |
+
num_heads: int = 8,
|
| 28 |
+
qkv_bias: bool = True,
|
| 29 |
+
proj_bias: bool = True,
|
| 30 |
+
attn_drop: float = 0.0,
|
| 31 |
+
proj_drop: float = 0.0,
|
| 32 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 33 |
+
qk_norm: bool = False,
|
| 34 |
+
rope=None,
|
| 35 |
+
) -> None:
|
| 36 |
+
super().__init__()
|
| 37 |
+
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
| 38 |
+
self.num_heads = num_heads
|
| 39 |
+
self.head_dim = dim // num_heads
|
| 40 |
+
self.scale = self.head_dim**-0.5
|
| 41 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 42 |
+
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 43 |
+
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
| 44 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 45 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 46 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 47 |
+
self.rope = rope
|
| 48 |
+
|
| 49 |
+
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
|
| 50 |
+
# Debug breakpoint removed for production
|
| 51 |
+
B, N, C = x.shape
|
| 52 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
| 53 |
+
q, k, v = qkv.unbind(0)
|
| 54 |
+
q, k = self.q_norm(q), self.k_norm(k)
|
| 55 |
+
q = self.rope(q, pos) if self.rope is not None else q
|
| 56 |
+
k = self.rope(k, pos) if self.rope is not None else k
|
| 57 |
+
x = F.scaled_dot_product_attention(
|
| 58 |
+
q,
|
| 59 |
+
k,
|
| 60 |
+
v,
|
| 61 |
+
dropout_p=self.attn_drop.p if self.training else 0.0,
|
| 62 |
+
attn_mask=attn_mask,
|
| 63 |
+
)
|
| 64 |
+
x = x.transpose(1, 2).reshape(B, N, C)
|
| 65 |
+
x = self.proj(x)
|
| 66 |
+
x = self.proj_drop(x)
|
| 67 |
+
return x
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class LayerScale(nn.Module):
|
| 71 |
+
def __init__(
|
| 72 |
+
self,
|
| 73 |
+
dim: int,
|
| 74 |
+
init_values: Union[float, Tensor] = 1e-5,
|
| 75 |
+
inplace: bool = False,
|
| 76 |
+
) -> None:
|
| 77 |
+
super().__init__()
|
| 78 |
+
self.inplace = inplace
|
| 79 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 80 |
+
|
| 81 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 82 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Mlp(nn.Module):
|
| 86 |
+
def __init__(
|
| 87 |
+
self,
|
| 88 |
+
in_features: int,
|
| 89 |
+
hidden_features: Optional[int] = None,
|
| 90 |
+
out_features: Optional[int] = None,
|
| 91 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 92 |
+
drop: float = 0.0,
|
| 93 |
+
bias: bool = True,
|
| 94 |
+
) -> None:
|
| 95 |
+
super().__init__()
|
| 96 |
+
out_features = out_features or in_features
|
| 97 |
+
hidden_features = hidden_features or in_features
|
| 98 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 99 |
+
self.act = act_layer()
|
| 100 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 101 |
+
self.drop = nn.Dropout(drop)
|
| 102 |
+
|
| 103 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 104 |
+
x = self.fc1(x)
|
| 105 |
+
x = self.act(x)
|
| 106 |
+
x = self.drop(x)
|
| 107 |
+
x = self.fc2(x)
|
| 108 |
+
x = self.drop(x)
|
| 109 |
+
return x
|
src/depth_anything_3/model/utils/block.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
from typing import Callable
|
| 17 |
+
from torch import Tensor, nn
|
| 18 |
+
|
| 19 |
+
from .attention import Attention, LayerScale, Mlp
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Block(nn.Module):
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
dim: int,
|
| 26 |
+
num_heads: int,
|
| 27 |
+
mlp_ratio: float = 4.0,
|
| 28 |
+
qkv_bias: bool = True,
|
| 29 |
+
proj_bias: bool = True,
|
| 30 |
+
ffn_bias: bool = True,
|
| 31 |
+
drop: float = 0.0,
|
| 32 |
+
attn_drop: float = 0.0,
|
| 33 |
+
init_values=None,
|
| 34 |
+
drop_path: float = 0.0,
|
| 35 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 36 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 37 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 38 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 39 |
+
qk_norm: bool = False,
|
| 40 |
+
rope=None,
|
| 41 |
+
) -> None:
|
| 42 |
+
super().__init__()
|
| 43 |
+
|
| 44 |
+
self.norm1 = norm_layer(dim)
|
| 45 |
+
|
| 46 |
+
self.attn = attn_class(
|
| 47 |
+
dim,
|
| 48 |
+
num_heads=num_heads,
|
| 49 |
+
qkv_bias=qkv_bias,
|
| 50 |
+
proj_bias=proj_bias,
|
| 51 |
+
attn_drop=attn_drop,
|
| 52 |
+
proj_drop=drop,
|
| 53 |
+
qk_norm=qk_norm,
|
| 54 |
+
rope=rope,
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 58 |
+
self.norm2 = norm_layer(dim)
|
| 59 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 60 |
+
self.mlp = ffn_layer(
|
| 61 |
+
in_features=dim,
|
| 62 |
+
hidden_features=mlp_hidden_dim,
|
| 63 |
+
act_layer=act_layer,
|
| 64 |
+
drop=drop,
|
| 65 |
+
bias=ffn_bias,
|
| 66 |
+
)
|
| 67 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 68 |
+
|
| 69 |
+
self.sample_drop_ratio = 0.0 # Equivalent to always having drop_path=0
|
| 70 |
+
|
| 71 |
+
def forward(self, x: Tensor, pos=None, attn_mask=None) -> Tensor:
|
| 72 |
+
def attn_residual_func(x: Tensor, pos=None, attn_mask=None) -> Tensor:
|
| 73 |
+
return self.ls1(self.attn(self.norm1(x), pos=pos, attn_mask=attn_mask))
|
| 74 |
+
|
| 75 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 76 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 77 |
+
|
| 78 |
+
# drop_path is always 0, so always take the else branch
|
| 79 |
+
x = x + attn_residual_func(x, pos=pos, attn_mask=attn_mask)
|
| 80 |
+
x = x + ffn_residual_func(x)
|
| 81 |
+
return x
|
src/depth_anything_3/model/utils/gs_renderer.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
from math import isqrt
|
| 17 |
+
from typing import Literal, Optional
|
| 18 |
+
import torch
|
| 19 |
+
from einops import rearrange, repeat
|
| 20 |
+
from tqdm import tqdm
|
| 21 |
+
|
| 22 |
+
from depth_anything_3.specs import Gaussians
|
| 23 |
+
from depth_anything_3.utils.camera_trj_helpers import (
|
| 24 |
+
interpolate_extrinsics,
|
| 25 |
+
interpolate_intrinsics,
|
| 26 |
+
render_dolly_zoom_path,
|
| 27 |
+
render_stabilization_path,
|
| 28 |
+
render_wander_path,
|
| 29 |
+
render_wobble_inter_path,
|
| 30 |
+
)
|
| 31 |
+
from depth_anything_3.utils.geometry import affine_inverse, as_homogeneous, get_fov
|
| 32 |
+
from depth_anything_3.utils.logger import logger
|
| 33 |
+
|
| 34 |
+
try:
|
| 35 |
+
from gsplat import rasterization
|
| 36 |
+
except ImportError:
|
| 37 |
+
logger.warn(
|
| 38 |
+
"Dependency `gsplat` is required for rendering 3DGS. "
|
| 39 |
+
"Install via: pip install git+https://github.com/nerfstudio-project/"
|
| 40 |
+
"gsplat.git@0b4dddf04cb687367602c01196913cde6a743d70"
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def render_3dgs(
|
| 45 |
+
extrinsics: torch.Tensor, # "batch_views 4 4", w2c
|
| 46 |
+
intrinsics: torch.Tensor, # "batch_views 3 3", normalized
|
| 47 |
+
image_shape: tuple[int, int],
|
| 48 |
+
gaussian: Gaussians,
|
| 49 |
+
background_color: Optional[torch.Tensor] = None, # "batch_views 3"
|
| 50 |
+
use_sh: bool = True,
|
| 51 |
+
num_view: int = 1,
|
| 52 |
+
color_mode: Literal["RGB+D", "RGB+ED"] = "RGB+D",
|
| 53 |
+
**kwargs,
|
| 54 |
+
) -> tuple[
|
| 55 |
+
torch.Tensor, # "batch_views 3 height width"
|
| 56 |
+
torch.Tensor, # "batch_views height width"
|
| 57 |
+
]:
|
| 58 |
+
# extract gaussian params
|
| 59 |
+
gaussian_means = gaussian.means
|
| 60 |
+
gaussian_scales = gaussian.scales
|
| 61 |
+
gaussian_quats = gaussian.rotations
|
| 62 |
+
gaussian_opacities = gaussian.opacities
|
| 63 |
+
gaussian_sh_coefficients = gaussian.harmonics
|
| 64 |
+
b, _, _ = extrinsics.shape
|
| 65 |
+
|
| 66 |
+
if background_color is None:
|
| 67 |
+
background_color = repeat(torch.tensor([0.0, 0.0, 0.0]), "c -> b c", b=b).to(
|
| 68 |
+
gaussian_sh_coefficients
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
if use_sh:
|
| 72 |
+
_, _, _, n = gaussian_sh_coefficients.shape
|
| 73 |
+
degree = isqrt(n) - 1
|
| 74 |
+
shs = rearrange(gaussian_sh_coefficients, "b g xyz n -> b g n xyz").contiguous()
|
| 75 |
+
else: # use color
|
| 76 |
+
shs = (
|
| 77 |
+
gaussian_sh_coefficients.squeeze(-1).sigmoid().contiguous()
|
| 78 |
+
) # (b, g, c), normed to (0, 1)
|
| 79 |
+
|
| 80 |
+
h, w = image_shape
|
| 81 |
+
|
| 82 |
+
fov_x, fov_y = get_fov(intrinsics).unbind(dim=-1)
|
| 83 |
+
tan_fov_x = (0.5 * fov_x).tan()
|
| 84 |
+
tan_fov_y = (0.5 * fov_y).tan()
|
| 85 |
+
focal_length_x = w / (2 * tan_fov_x)
|
| 86 |
+
focal_length_y = h / (2 * tan_fov_y)
|
| 87 |
+
|
| 88 |
+
view_matrix = extrinsics.float()
|
| 89 |
+
|
| 90 |
+
all_images = []
|
| 91 |
+
all_radii = []
|
| 92 |
+
all_depths = []
|
| 93 |
+
# render view in a batch based, each batch contains one scene
|
| 94 |
+
# assume the Gaussian parameters are originally repeated along the view dim
|
| 95 |
+
batch_scene = b // num_view
|
| 96 |
+
|
| 97 |
+
def index_i_gs_attr(full_attr, idx):
|
| 98 |
+
# return rearrange(full_attr, "(b v) ... -> b v ...", v=num_view)[idx, 0]
|
| 99 |
+
return full_attr[idx]
|
| 100 |
+
|
| 101 |
+
for i in range(batch_scene):
|
| 102 |
+
K = repeat(
|
| 103 |
+
torch.tensor(
|
| 104 |
+
[
|
| 105 |
+
[0, 0, w / 2.0],
|
| 106 |
+
[0, 0, h / 2.0],
|
| 107 |
+
[0, 0, 1],
|
| 108 |
+
]
|
| 109 |
+
),
|
| 110 |
+
"i j -> v i j",
|
| 111 |
+
v=num_view,
|
| 112 |
+
).to(gaussian_means)
|
| 113 |
+
K[:, 0, 0] = focal_length_x.reshape(batch_scene, num_view)[i]
|
| 114 |
+
K[:, 1, 1] = focal_length_y.reshape(batch_scene, num_view)[i]
|
| 115 |
+
|
| 116 |
+
i_means = index_i_gs_attr(gaussian_means, i) # [N, 3]
|
| 117 |
+
i_scales = index_i_gs_attr(gaussian_scales, i)
|
| 118 |
+
i_quats = index_i_gs_attr(gaussian_quats, i)
|
| 119 |
+
i_opacities = index_i_gs_attr(gaussian_opacities, i) # [N,]
|
| 120 |
+
i_colors = index_i_gs_attr(shs, i) # [N, K, 3]
|
| 121 |
+
i_viewmats = rearrange(view_matrix, "(b v) ... -> b v ...", v=num_view)[i] # [v, 4, 4]
|
| 122 |
+
i_backgrounds = rearrange(background_color, "(b v) ... -> b v ...", v=num_view)[
|
| 123 |
+
i
|
| 124 |
+
] # [v, 3]
|
| 125 |
+
|
| 126 |
+
render_colors, render_alphas, info = rasterization(
|
| 127 |
+
means=i_means,
|
| 128 |
+
quats=i_quats, # [N, 4]
|
| 129 |
+
scales=i_scales, # [N, 3]
|
| 130 |
+
opacities=i_opacities,
|
| 131 |
+
colors=i_colors,
|
| 132 |
+
viewmats=i_viewmats, # [v, 4, 4]
|
| 133 |
+
Ks=K, # [v, 3, 3]
|
| 134 |
+
backgrounds=i_backgrounds,
|
| 135 |
+
render_mode=color_mode,
|
| 136 |
+
width=w,
|
| 137 |
+
height=h,
|
| 138 |
+
packed=False,
|
| 139 |
+
sh_degree=degree if use_sh else None,
|
| 140 |
+
)
|
| 141 |
+
depth = render_colors[..., -1].unbind(dim=0)
|
| 142 |
+
|
| 143 |
+
image = rearrange(render_colors[..., :3], "v h w c -> v c h w").unbind(dim=0)
|
| 144 |
+
radii = info["radii"].unbind(dim=0)
|
| 145 |
+
try:
|
| 146 |
+
info["means2d"].retain_grad() # [1, N, 2]
|
| 147 |
+
except Exception:
|
| 148 |
+
pass
|
| 149 |
+
all_images.extend(image)
|
| 150 |
+
all_depths.extend(depth)
|
| 151 |
+
all_radii.extend(radii)
|
| 152 |
+
|
| 153 |
+
return torch.stack(all_images), torch.stack(all_depths)
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def run_renderer_in_chunk_w_trj_mode(
|
| 157 |
+
gaussians: Gaussians,
|
| 158 |
+
extrinsics: torch.Tensor, # world2cam, "batch view 4 4" | "batch view 3 4"
|
| 159 |
+
intrinsics: torch.Tensor, # unnormed intrinsics, "batch view 3 3"
|
| 160 |
+
image_shape: tuple[int, int],
|
| 161 |
+
chunk_size: Optional[int] = 8,
|
| 162 |
+
trj_mode: Literal[
|
| 163 |
+
"original",
|
| 164 |
+
"smooth",
|
| 165 |
+
"interpolate",
|
| 166 |
+
"interpolate_smooth",
|
| 167 |
+
"wander",
|
| 168 |
+
"dolly_zoom",
|
| 169 |
+
"extend",
|
| 170 |
+
"wobble_inter",
|
| 171 |
+
] = "smooth",
|
| 172 |
+
input_shape: Optional[tuple[int, int]] = None,
|
| 173 |
+
enable_tqdm: Optional[bool] = False,
|
| 174 |
+
**kwargs,
|
| 175 |
+
) -> tuple[
|
| 176 |
+
torch.Tensor, # color, "batch view 3 height width"
|
| 177 |
+
torch.Tensor, # depth, "batch view height width"
|
| 178 |
+
]:
|
| 179 |
+
cam2world = affine_inverse(as_homogeneous(extrinsics))
|
| 180 |
+
if input_shape is not None:
|
| 181 |
+
in_h, in_w = input_shape
|
| 182 |
+
else:
|
| 183 |
+
in_h, in_w = image_shape
|
| 184 |
+
intr_normed = intrinsics.clone().detach()
|
| 185 |
+
intr_normed[..., 0, :] /= in_w
|
| 186 |
+
intr_normed[..., 1, :] /= in_h
|
| 187 |
+
if extrinsics.shape[1] <= 1:
|
| 188 |
+
assert trj_mode in [
|
| 189 |
+
"wander",
|
| 190 |
+
"dolly_zoom",
|
| 191 |
+
], "Please set trj_mode to 'wander' or 'dolly_zoom' when n_views=1"
|
| 192 |
+
|
| 193 |
+
def _smooth_trj_fn_batch(raw_c2ws, k_size=50):
|
| 194 |
+
try:
|
| 195 |
+
smooth_c2ws = torch.stack(
|
| 196 |
+
[render_stabilization_path(c2w_i, k_size) for c2w_i in raw_c2ws],
|
| 197 |
+
dim=0,
|
| 198 |
+
)
|
| 199 |
+
except Exception as e:
|
| 200 |
+
print(f"[DEBUG] Path smoothing failed with error: {e}.")
|
| 201 |
+
smooth_c2ws = raw_c2ws
|
| 202 |
+
return smooth_c2ws
|
| 203 |
+
|
| 204 |
+
# get rendered trj
|
| 205 |
+
if trj_mode == "original":
|
| 206 |
+
tgt_c2w = cam2world
|
| 207 |
+
tgt_intr = intr_normed
|
| 208 |
+
elif trj_mode == "smooth":
|
| 209 |
+
tgt_c2w = _smooth_trj_fn_batch(cam2world)
|
| 210 |
+
tgt_intr = intr_normed
|
| 211 |
+
elif trj_mode in ["interpolate", "interpolate_smooth", "extend"]:
|
| 212 |
+
inter_len = 8
|
| 213 |
+
total_len = (cam2world.shape[1] - 1) * inter_len
|
| 214 |
+
if total_len > 24 * 18: # no more than 18s
|
| 215 |
+
inter_len = max(1, 24 * 10 // (cam2world.shape[1] - 1))
|
| 216 |
+
if total_len < 24 * 2: # no less than 2s
|
| 217 |
+
inter_len = max(1, 24 * 2 // (cam2world.shape[1] - 1))
|
| 218 |
+
|
| 219 |
+
if inter_len > 2:
|
| 220 |
+
t = torch.linspace(0, 1, inter_len, dtype=torch.float32, device=cam2world.device)
|
| 221 |
+
t = (torch.cos(torch.pi * (t + 1)) + 1) / 2
|
| 222 |
+
tgt_c2w_b = []
|
| 223 |
+
tgt_intr_b = []
|
| 224 |
+
for b_idx in range(cam2world.shape[0]):
|
| 225 |
+
tgt_c2w = []
|
| 226 |
+
tgt_intr = []
|
| 227 |
+
for cur_idx in range(cam2world.shape[1] - 1):
|
| 228 |
+
tgt_c2w.append(
|
| 229 |
+
interpolate_extrinsics(
|
| 230 |
+
cam2world[b_idx, cur_idx], cam2world[b_idx, cur_idx + 1], t
|
| 231 |
+
)[(0 if cur_idx == 0 else 1) :]
|
| 232 |
+
)
|
| 233 |
+
tgt_intr.append(
|
| 234 |
+
interpolate_intrinsics(
|
| 235 |
+
intr_normed[b_idx, cur_idx], intr_normed[b_idx, cur_idx + 1], t
|
| 236 |
+
)[(0 if cur_idx == 0 else 1) :]
|
| 237 |
+
)
|
| 238 |
+
tgt_c2w_b.append(torch.cat(tgt_c2w))
|
| 239 |
+
tgt_intr_b.append(torch.cat(tgt_intr))
|
| 240 |
+
tgt_c2w = torch.stack(tgt_c2w_b) # b v 4 4
|
| 241 |
+
tgt_intr = torch.stack(tgt_intr_b) # b v 3 3
|
| 242 |
+
else:
|
| 243 |
+
tgt_c2w = cam2world
|
| 244 |
+
tgt_intr = intr_normed
|
| 245 |
+
if trj_mode in ["interpolate_smooth", "extend"]:
|
| 246 |
+
tgt_c2w = _smooth_trj_fn_batch(tgt_c2w)
|
| 247 |
+
if trj_mode == "extend":
|
| 248 |
+
# apply dolly_zoom and wander in the middle frame
|
| 249 |
+
assert cam2world.shape[0] == 1, "extend only supports for batch_size=1 currently."
|
| 250 |
+
mid_idx = tgt_c2w.shape[1] // 2
|
| 251 |
+
c2w_wd, intr_wd = render_wander_path(
|
| 252 |
+
tgt_c2w[0, mid_idx],
|
| 253 |
+
tgt_intr[0, mid_idx],
|
| 254 |
+
h=in_h,
|
| 255 |
+
w=in_w,
|
| 256 |
+
num_frames=max(36, min(60, mid_idx // 2)),
|
| 257 |
+
max_disp=24.0,
|
| 258 |
+
)
|
| 259 |
+
c2w_dz, intr_dz = render_dolly_zoom_path(
|
| 260 |
+
tgt_c2w[0, mid_idx],
|
| 261 |
+
tgt_intr[0, mid_idx],
|
| 262 |
+
h=in_h,
|
| 263 |
+
w=in_w,
|
| 264 |
+
num_frames=max(36, min(60, mid_idx // 2)),
|
| 265 |
+
)
|
| 266 |
+
tgt_c2w = torch.cat(
|
| 267 |
+
[
|
| 268 |
+
tgt_c2w[:, :mid_idx],
|
| 269 |
+
c2w_wd.unsqueeze(0),
|
| 270 |
+
c2w_dz.unsqueeze(0),
|
| 271 |
+
tgt_c2w[:, mid_idx:],
|
| 272 |
+
],
|
| 273 |
+
dim=1,
|
| 274 |
+
)
|
| 275 |
+
tgt_intr = torch.cat(
|
| 276 |
+
[
|
| 277 |
+
tgt_intr[:, :mid_idx],
|
| 278 |
+
intr_wd.unsqueeze(0),
|
| 279 |
+
intr_dz.unsqueeze(0),
|
| 280 |
+
tgt_intr[:, mid_idx:],
|
| 281 |
+
],
|
| 282 |
+
dim=1,
|
| 283 |
+
)
|
| 284 |
+
elif trj_mode in ["wander", "dolly_zoom"]:
|
| 285 |
+
if trj_mode == "wander":
|
| 286 |
+
render_fn = render_wander_path
|
| 287 |
+
extra_kwargs = {"max_disp": 24.0}
|
| 288 |
+
else:
|
| 289 |
+
render_fn = render_dolly_zoom_path
|
| 290 |
+
extra_kwargs = {"D_focus": 30.0, "max_disp": 2.0}
|
| 291 |
+
tgt_c2w = []
|
| 292 |
+
tgt_intr = []
|
| 293 |
+
for b_idx in range(cam2world.shape[0]):
|
| 294 |
+
c2w_i, intr_i = render_fn(
|
| 295 |
+
cam2world[b_idx, 0], intr_normed[b_idx, 0], h=in_h, w=in_w, **extra_kwargs
|
| 296 |
+
)
|
| 297 |
+
tgt_c2w.append(c2w_i)
|
| 298 |
+
tgt_intr.append(intr_i)
|
| 299 |
+
tgt_c2w = torch.stack(tgt_c2w)
|
| 300 |
+
tgt_intr = torch.stack(tgt_intr)
|
| 301 |
+
elif trj_mode == "wobble_inter":
|
| 302 |
+
tgt_c2w, tgt_intr = render_wobble_inter_path(
|
| 303 |
+
cam2world=cam2world,
|
| 304 |
+
intr_normed=intr_normed,
|
| 305 |
+
inter_len=10,
|
| 306 |
+
n_skip=3,
|
| 307 |
+
)
|
| 308 |
+
else:
|
| 309 |
+
raise Exception(f"trj mode [{trj_mode}] is not implemented.")
|
| 310 |
+
|
| 311 |
+
_, v = tgt_c2w.shape[:2]
|
| 312 |
+
tgt_extr = affine_inverse(tgt_c2w)
|
| 313 |
+
if chunk_size is None:
|
| 314 |
+
chunk_size = v
|
| 315 |
+
chunk_size = min(v, chunk_size)
|
| 316 |
+
all_colors = []
|
| 317 |
+
all_depths = []
|
| 318 |
+
for chunk_idx in tqdm(
|
| 319 |
+
range(math.ceil(v / chunk_size)),
|
| 320 |
+
desc="Rendering novel views",
|
| 321 |
+
disable=(not enable_tqdm),
|
| 322 |
+
leave=False,
|
| 323 |
+
):
|
| 324 |
+
s = int(chunk_idx * chunk_size)
|
| 325 |
+
e = int((chunk_idx + 1) * chunk_size)
|
| 326 |
+
cur_n_view = tgt_extr[:, s:e].shape[1]
|
| 327 |
+
color, depth = render_3dgs(
|
| 328 |
+
extrinsics=rearrange(tgt_extr[:, s:e], "b v ... -> (b v) ..."), # w2c
|
| 329 |
+
intrinsics=rearrange(tgt_intr[:, s:e], "b v ... -> (b v) ..."), # normed
|
| 330 |
+
image_shape=image_shape,
|
| 331 |
+
gaussian=gaussians,
|
| 332 |
+
num_view=cur_n_view,
|
| 333 |
+
**kwargs,
|
| 334 |
+
)
|
| 335 |
+
all_colors.append(rearrange(color, "(b v) ... -> b v ...", v=cur_n_view))
|
| 336 |
+
all_depths.append(rearrange(depth, "(b v) ... -> b v ...", v=cur_n_view))
|
| 337 |
+
all_colors = torch.cat(all_colors, dim=1)
|
| 338 |
+
all_depths = torch.cat(all_depths, dim=1)
|
| 339 |
+
|
| 340 |
+
return all_colors, all_depths
|
src/depth_anything_3/model/utils/head_utils.py
ADDED
|
@@ -0,0 +1,230 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from typing import Tuple, Union
|
| 16 |
+
import torch
|
| 17 |
+
import torch.nn as nn
|
| 18 |
+
import torch.nn.functional as F
|
| 19 |
+
|
| 20 |
+
# -----------------------------------------------------------------------------
|
| 21 |
+
# Activation functions
|
| 22 |
+
# -----------------------------------------------------------------------------
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def activate_head_gs(out, activation="norm_exp", conf_activation="expp1", conf_dim=None):
|
| 26 |
+
"""
|
| 27 |
+
Process network output to extract GS params and density values.
|
| 28 |
+
Density could be view-dependent as SH coefficient
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
out: Network output tensor (B, C, H, W)
|
| 33 |
+
activation: Activation type for 3D points
|
| 34 |
+
conf_activation: Activation type for confidence values
|
| 35 |
+
|
| 36 |
+
Returns:
|
| 37 |
+
Tuple of (3D points tensor, confidence tensor)
|
| 38 |
+
"""
|
| 39 |
+
# Move channels from last dim to the 4th dimension => (B, H, W, C)
|
| 40 |
+
fmap = out.permute(0, 2, 3, 1) # B,H,W,C expected
|
| 41 |
+
|
| 42 |
+
# Split into xyz (first C-1 channels) and confidence (last channel)
|
| 43 |
+
conf_dim = 1 if conf_dim is None else conf_dim
|
| 44 |
+
xyz = fmap[:, :, :, :-conf_dim]
|
| 45 |
+
conf = fmap[:, :, :, -1] if conf_dim == 1 else fmap[:, :, :, -conf_dim:]
|
| 46 |
+
|
| 47 |
+
if activation == "norm_exp":
|
| 48 |
+
d = xyz.norm(dim=-1, keepdim=True).clamp(min=1e-8)
|
| 49 |
+
xyz_normed = xyz / d
|
| 50 |
+
pts3d = xyz_normed * torch.expm1(d)
|
| 51 |
+
elif activation == "norm":
|
| 52 |
+
pts3d = xyz / xyz.norm(dim=-1, keepdim=True)
|
| 53 |
+
elif activation == "exp":
|
| 54 |
+
pts3d = torch.exp(xyz)
|
| 55 |
+
elif activation == "relu":
|
| 56 |
+
pts3d = F.relu(xyz)
|
| 57 |
+
elif activation == "sigmoid":
|
| 58 |
+
pts3d = torch.sigmoid(xyz)
|
| 59 |
+
elif activation == "linear":
|
| 60 |
+
pts3d = xyz
|
| 61 |
+
else:
|
| 62 |
+
raise ValueError(f"Unknown activation: {activation}")
|
| 63 |
+
|
| 64 |
+
if conf_activation == "expp1":
|
| 65 |
+
conf_out = 1 + conf.exp()
|
| 66 |
+
elif conf_activation == "expp0":
|
| 67 |
+
conf_out = conf.exp()
|
| 68 |
+
elif conf_activation == "sigmoid":
|
| 69 |
+
conf_out = torch.sigmoid(conf)
|
| 70 |
+
elif conf_activation == "linear":
|
| 71 |
+
conf_out = conf
|
| 72 |
+
else:
|
| 73 |
+
raise ValueError(f"Unknown conf_activation: {conf_activation}")
|
| 74 |
+
|
| 75 |
+
return pts3d, conf_out
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# -----------------------------------------------------------------------------
|
| 79 |
+
# Other utilities
|
| 80 |
+
# -----------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
class Permute(nn.Module):
|
| 84 |
+
"""nn.Module wrapper around Tensor.permute for cleaner nn.Sequential usage."""
|
| 85 |
+
|
| 86 |
+
dims: Tuple[int, ...]
|
| 87 |
+
|
| 88 |
+
def __init__(self, dims: Tuple[int, ...]) -> None:
|
| 89 |
+
super().__init__()
|
| 90 |
+
self.dims = dims
|
| 91 |
+
|
| 92 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore[override]
|
| 93 |
+
return x.permute(*self.dims)
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def position_grid_to_embed(
|
| 97 |
+
pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100
|
| 98 |
+
) -> torch.Tensor:
|
| 99 |
+
"""
|
| 100 |
+
Convert 2D position grid (HxWx2) to sinusoidal embeddings (HxWxC)
|
| 101 |
+
|
| 102 |
+
Args:
|
| 103 |
+
pos_grid: Tensor of shape (H, W, 2) containing 2D coordinates
|
| 104 |
+
embed_dim: Output channel dimension for embeddings
|
| 105 |
+
|
| 106 |
+
Returns:
|
| 107 |
+
Tensor of shape (H, W, embed_dim) with positional embeddings
|
| 108 |
+
"""
|
| 109 |
+
H, W, grid_dim = pos_grid.shape
|
| 110 |
+
assert grid_dim == 2
|
| 111 |
+
pos_flat = pos_grid.reshape(-1, grid_dim) # Flatten to (H*W, 2)
|
| 112 |
+
|
| 113 |
+
# Process x and y coordinates separately
|
| 114 |
+
emb_x = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0) # [1, H*W, D/2]
|
| 115 |
+
emb_y = make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0) # [1, H*W, D/2]
|
| 116 |
+
|
| 117 |
+
# Combine and reshape
|
| 118 |
+
emb = torch.cat([emb_x, emb_y], dim=-1) # [1, H*W, D]
|
| 119 |
+
|
| 120 |
+
return emb.view(H, W, embed_dim) # [H, W, D]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100) -> torch.Tensor:
|
| 124 |
+
"""
|
| 125 |
+
This function generates a 1D positional embedding from a given grid using sine and cosine functions. # noqa
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
- embed_dim: The embedding dimension.
|
| 129 |
+
- pos: The position to generate the embedding from.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
- emb: The generated 1D positional embedding.
|
| 133 |
+
"""
|
| 134 |
+
assert embed_dim % 2 == 0
|
| 135 |
+
omega = torch.arange(embed_dim // 2, dtype=torch.double, device=pos.device)
|
| 136 |
+
omega /= embed_dim / 2.0
|
| 137 |
+
omega = 1.0 / omega_0**omega # (D/2,)
|
| 138 |
+
|
| 139 |
+
pos = pos.reshape(-1) # (M,)
|
| 140 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 141 |
+
|
| 142 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 143 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 144 |
+
|
| 145 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 146 |
+
return emb.float()
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
# Inspired by https://github.com/microsoft/moge
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def create_uv_grid(
|
| 153 |
+
width: int,
|
| 154 |
+
height: int,
|
| 155 |
+
aspect_ratio: float = None,
|
| 156 |
+
dtype: torch.dtype = None,
|
| 157 |
+
device: torch.device = None,
|
| 158 |
+
) -> torch.Tensor:
|
| 159 |
+
"""
|
| 160 |
+
Create a normalized UV grid of shape (width, height, 2).
|
| 161 |
+
|
| 162 |
+
The grid spans horizontally and vertically according to an aspect ratio,
|
| 163 |
+
ensuring the top-left corner is at (-x_span, -y_span) and the bottom-right
|
| 164 |
+
corner is at (x_span, y_span), normalized by the diagonal of the plane.
|
| 165 |
+
|
| 166 |
+
Args:
|
| 167 |
+
width (int): Number of points horizontally.
|
| 168 |
+
height (int): Number of points vertically.
|
| 169 |
+
aspect_ratio (float, optional): Width-to-height ratio. Defaults to width/height.
|
| 170 |
+
dtype (torch.dtype, optional): Data type of the resulting tensor.
|
| 171 |
+
device (torch.device, optional): Device on which the tensor is created.
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
torch.Tensor: A (width, height, 2) tensor of UV coordinates.
|
| 175 |
+
"""
|
| 176 |
+
# Derive aspect ratio if not explicitly provided
|
| 177 |
+
if aspect_ratio is None:
|
| 178 |
+
aspect_ratio = float(width) / float(height)
|
| 179 |
+
|
| 180 |
+
# Compute normalized spans for X and Y
|
| 181 |
+
diag_factor = (aspect_ratio**2 + 1.0) ** 0.5
|
| 182 |
+
span_x = aspect_ratio / diag_factor
|
| 183 |
+
span_y = 1.0 / diag_factor
|
| 184 |
+
|
| 185 |
+
# Establish the linspace boundaries
|
| 186 |
+
left_x = -span_x * (width - 1) / width
|
| 187 |
+
right_x = span_x * (width - 1) / width
|
| 188 |
+
top_y = -span_y * (height - 1) / height
|
| 189 |
+
bottom_y = span_y * (height - 1) / height
|
| 190 |
+
|
| 191 |
+
# Generate 1D coordinates
|
| 192 |
+
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
| 193 |
+
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
| 194 |
+
|
| 195 |
+
# Create 2D meshgrid (width x height) and stack into UV
|
| 196 |
+
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
| 197 |
+
uv_grid = torch.stack((uu, vv), dim=-1)
|
| 198 |
+
|
| 199 |
+
return uv_grid
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
# -----------------------------------------------------------------------------
|
| 203 |
+
# Interpolation (safe interpolation, avoid INT_MAX overflow)
|
| 204 |
+
# -----------------------------------------------------------------------------
|
| 205 |
+
def custom_interpolate(
|
| 206 |
+
x: torch.Tensor,
|
| 207 |
+
size: Union[Tuple[int, int], None] = None,
|
| 208 |
+
scale_factor: Union[float, None] = None,
|
| 209 |
+
mode: str = "bilinear",
|
| 210 |
+
align_corners: bool = True,
|
| 211 |
+
) -> torch.Tensor:
|
| 212 |
+
"""
|
| 213 |
+
Safe interpolation implementation to avoid INT_MAX overflow in torch.nn.functional.interpolate.
|
| 214 |
+
"""
|
| 215 |
+
if size is None:
|
| 216 |
+
assert scale_factor is not None, "Either size or scale_factor must be provided."
|
| 217 |
+
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
| 218 |
+
|
| 219 |
+
INT_MAX = 1610612736
|
| 220 |
+
total = size[0] * size[1] * x.shape[0] * x.shape[1]
|
| 221 |
+
|
| 222 |
+
if total > INT_MAX:
|
| 223 |
+
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
|
| 224 |
+
outs = [
|
| 225 |
+
nn.functional.interpolate(c, size=size, mode=mode, align_corners=align_corners)
|
| 226 |
+
for c in chunks
|
| 227 |
+
]
|
| 228 |
+
return torch.cat(outs, dim=0).contiguous()
|
| 229 |
+
|
| 230 |
+
return nn.functional.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
src/depth_anything_3/model/utils/transform.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def extri_intri_to_pose_encoding(
|
| 20 |
+
extrinsics,
|
| 21 |
+
intrinsics,
|
| 22 |
+
image_size_hw=None,
|
| 23 |
+
):
|
| 24 |
+
"""Convert camera extrinsics and intrinsics to a compact pose encoding."""
|
| 25 |
+
|
| 26 |
+
# extrinsics: BxSx3x4
|
| 27 |
+
# intrinsics: BxSx3x3
|
| 28 |
+
R = extrinsics[:, :, :3, :3] # BxSx3x3
|
| 29 |
+
T = extrinsics[:, :, :3, 3] # BxSx3
|
| 30 |
+
|
| 31 |
+
quat = mat_to_quat(R)
|
| 32 |
+
# Note the order of h and w here
|
| 33 |
+
H, W = image_size_hw
|
| 34 |
+
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
|
| 35 |
+
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
|
| 36 |
+
pose_encoding = torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
|
| 37 |
+
|
| 38 |
+
return pose_encoding
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def pose_encoding_to_extri_intri(
|
| 42 |
+
pose_encoding,
|
| 43 |
+
image_size_hw=None,
|
| 44 |
+
):
|
| 45 |
+
"""Convert a pose encoding back to camera extrinsics and intrinsics."""
|
| 46 |
+
|
| 47 |
+
T = pose_encoding[..., :3]
|
| 48 |
+
quat = pose_encoding[..., 3:7]
|
| 49 |
+
fov_h = pose_encoding[..., 7]
|
| 50 |
+
fov_w = pose_encoding[..., 8]
|
| 51 |
+
|
| 52 |
+
R = quat_to_mat(quat)
|
| 53 |
+
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
| 54 |
+
|
| 55 |
+
H, W = image_size_hw
|
| 56 |
+
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
|
| 57 |
+
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
|
| 58 |
+
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device)
|
| 59 |
+
intrinsics[..., 0, 0] = fx
|
| 60 |
+
intrinsics[..., 1, 1] = fy
|
| 61 |
+
intrinsics[..., 0, 2] = W / 2
|
| 62 |
+
intrinsics[..., 1, 2] = H / 2
|
| 63 |
+
intrinsics[..., 2, 2] = 1.0 # Set the homogeneous coordinate to 1
|
| 64 |
+
|
| 65 |
+
return extrinsics, intrinsics
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
| 69 |
+
"""
|
| 70 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 71 |
+
|
| 72 |
+
Convert rotations given as quaternions to rotation matrices.
|
| 73 |
+
Args:
|
| 74 |
+
quaternions: quaternions with real part last,
|
| 75 |
+
as tensor of shape (..., 4).
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
| 79 |
+
"""
|
| 80 |
+
i, j, k, r = torch.unbind(quaternions, -1)
|
| 81 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
| 82 |
+
|
| 83 |
+
o = torch.stack(
|
| 84 |
+
(
|
| 85 |
+
1 - two_s * (j * j + k * k),
|
| 86 |
+
two_s * (i * j - k * r),
|
| 87 |
+
two_s * (i * k + j * r),
|
| 88 |
+
two_s * (i * j + k * r),
|
| 89 |
+
1 - two_s * (i * i + k * k),
|
| 90 |
+
two_s * (j * k - i * r),
|
| 91 |
+
two_s * (i * k - j * r),
|
| 92 |
+
two_s * (j * k + i * r),
|
| 93 |
+
1 - two_s * (i * i + j * j),
|
| 94 |
+
),
|
| 95 |
+
-1,
|
| 96 |
+
)
|
| 97 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
| 101 |
+
"""
|
| 102 |
+
Convert rotations given as rotation matrices to quaternions.
|
| 103 |
+
|
| 104 |
+
Args:
|
| 105 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
| 106 |
+
|
| 107 |
+
Returns:
|
| 108 |
+
quaternions with real part last, as tensor of shape (..., 4).
|
| 109 |
+
Quaternion Order: XYZW or say ijkr, scalar-last
|
| 110 |
+
"""
|
| 111 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
| 112 |
+
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
| 113 |
+
|
| 114 |
+
batch_dim = matrix.shape[:-2]
|
| 115 |
+
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
| 116 |
+
matrix.reshape(batch_dim + (9,)), dim=-1
|
| 117 |
+
)
|
| 118 |
+
|
| 119 |
+
q_abs = _sqrt_positive_part(
|
| 120 |
+
torch.stack(
|
| 121 |
+
[
|
| 122 |
+
1.0 + m00 + m11 + m22,
|
| 123 |
+
1.0 + m00 - m11 - m22,
|
| 124 |
+
1.0 - m00 + m11 - m22,
|
| 125 |
+
1.0 - m00 - m11 + m22,
|
| 126 |
+
],
|
| 127 |
+
dim=-1,
|
| 128 |
+
)
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
quat_by_rijk = torch.stack(
|
| 132 |
+
[
|
| 133 |
+
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
| 134 |
+
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
| 135 |
+
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
| 136 |
+
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
| 137 |
+
],
|
| 138 |
+
dim=-2,
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
| 142 |
+
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
| 143 |
+
|
| 144 |
+
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
|
| 145 |
+
batch_dim + (4,)
|
| 146 |
+
)
|
| 147 |
+
|
| 148 |
+
out = out[..., [1, 2, 3, 0]]
|
| 149 |
+
|
| 150 |
+
out = standardize_quaternion(out)
|
| 151 |
+
|
| 152 |
+
return out
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
| 156 |
+
"""
|
| 157 |
+
Returns torch.sqrt(torch.max(0, x))
|
| 158 |
+
but with a zero subgradient where x is 0.
|
| 159 |
+
"""
|
| 160 |
+
ret = torch.zeros_like(x)
|
| 161 |
+
positive_mask = x > 0
|
| 162 |
+
if torch.is_grad_enabled():
|
| 163 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
| 164 |
+
else:
|
| 165 |
+
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
| 166 |
+
return ret
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
| 170 |
+
"""
|
| 171 |
+
Convert a unit quaternion to a standard form: one in which the real
|
| 172 |
+
part is non negative.
|
| 173 |
+
|
| 174 |
+
Args:
|
| 175 |
+
quaternions: Quaternions with real part last,
|
| 176 |
+
as tensor of shape (..., 4).
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
Standardized quaternions as tensor of shape (..., 4).
|
| 180 |
+
"""
|
| 181 |
+
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def cam_quat_xyzw_to_world_quat_wxyz(cam_quat_xyzw, c2w):
|
| 185 |
+
# cam_quat_xyzw: (b, n, 4) in xyzw
|
| 186 |
+
# c2w: (b, n, 4, 4)
|
| 187 |
+
b, n = cam_quat_xyzw.shape[:2]
|
| 188 |
+
# 1. xyzw -> wxyz
|
| 189 |
+
cam_quat_wxyz = torch.cat(
|
| 190 |
+
[
|
| 191 |
+
cam_quat_xyzw[..., 3:4], # w
|
| 192 |
+
cam_quat_xyzw[..., 0:1], # x
|
| 193 |
+
cam_quat_xyzw[..., 1:2], # y
|
| 194 |
+
cam_quat_xyzw[..., 2:3], # z
|
| 195 |
+
],
|
| 196 |
+
dim=-1,
|
| 197 |
+
)
|
| 198 |
+
# 2. Quaternion to matrix
|
| 199 |
+
cam_quat_wxyz_flat = cam_quat_wxyz.reshape(-1, 4)
|
| 200 |
+
rotmat_cam = quat_to_mat(cam_quat_wxyz_flat).reshape(b, n, 3, 3)
|
| 201 |
+
# 3. Transform to world space
|
| 202 |
+
rotmat_c2w = c2w[..., :3, :3]
|
| 203 |
+
rotmat_world = torch.matmul(rotmat_c2w, rotmat_cam)
|
| 204 |
+
# 4. Matrix to quaternion (wxyz)
|
| 205 |
+
rotmat_world_flat = rotmat_world.reshape(-1, 3, 3)
|
| 206 |
+
world_quat_wxyz_flat = mat_to_quat(rotmat_world_flat)
|
| 207 |
+
world_quat_wxyz = world_quat_wxyz_flat.reshape(b, n, 4)
|
| 208 |
+
return world_quat_wxyz
|
src/depth_anything_3/registry.py
ADDED
|
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from collections import OrderedDict
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def get_all_models() -> OrderedDict:
|
| 20 |
+
"""
|
| 21 |
+
Scans all YAML files in the configs directory and returns a sorted dictionary where:
|
| 22 |
+
- Keys are model names (YAML filenames without the .yaml extension)
|
| 23 |
+
- Values are absolute paths to the corresponding YAML files
|
| 24 |
+
"""
|
| 25 |
+
# Get path to the configs directory within the da3 package
|
| 26 |
+
# Works both in development and after pip installation
|
| 27 |
+
# configs_dir = files("depth_anything_3").joinpath("configs")
|
| 28 |
+
configs_dir = Path(__file__).resolve().parent / "configs"
|
| 29 |
+
|
| 30 |
+
# Ensure path is a Path object for consistent cross-platform handling
|
| 31 |
+
configs_dir = Path(configs_dir)
|
| 32 |
+
|
| 33 |
+
model_entries = []
|
| 34 |
+
# Iterate through all items in the configs directory
|
| 35 |
+
for item in configs_dir.iterdir():
|
| 36 |
+
# Filter for YAML files (excluding directories)
|
| 37 |
+
if item.is_file() and item.suffix == ".yaml":
|
| 38 |
+
# Extract model name (filename without .yaml extension)
|
| 39 |
+
model_name = item.stem
|
| 40 |
+
# Get absolute path (resolve() handles symlinks)
|
| 41 |
+
file_abs_path = str(item.resolve())
|
| 42 |
+
model_entries.append((model_name, file_abs_path))
|
| 43 |
+
|
| 44 |
+
# Sort entries by model name and convert to OrderedDict
|
| 45 |
+
sorted_entries = sorted(model_entries, key=lambda x: x[0])
|
| 46 |
+
return OrderedDict(sorted_entries)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# Global registry for external imports
|
| 50 |
+
MODEL_REGISTRY = get_all_models()
|
src/depth_anything_3/services/__init__.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
"""
|
| 16 |
+
Services module for Depth Anything 3.
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
from depth_anything_3.services.backend import create_app, start_server
|
| 20 |
+
|
| 21 |
+
__all__ = [
|
| 22 |
+
start_server,
|
| 23 |
+
create_app,
|
| 24 |
+
]
|