diff --git a/.github/workflows/huggingface.yml b/.github/workflows/huggingface.yml index e9a1af9cd9cf2195b72826362e4447ce2b92a00a..58a637e2c9121b8e31c68e2f97d03e41d967a1c8 100644 --- a/.github/workflows/huggingface.yml +++ b/.github/workflows/huggingface.yml @@ -17,9 +17,9 @@ jobs: env: HF: ${{secrets.HF_TOKEN }} HFUSER: ${{secrets.HFUSER }} - run: git remote add space https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/Depth-Anything-Compare-demo - - name: Push to hub + run: git remote add space https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/Depth-Estimation-Compare-demo + - name: Push to huggingface hub env: HF: ${{ secrets.HF_TOKEN}} HFUSER: ${{secrets.HFUSER }} - run: git push --force https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/Depth-Anything-Compare-demo main \ No newline at end of file + run: git push --force https://$HFUSER:$HF@huggingface.co/spaces/$HFUSER/Depth-Estimation-Compare-demo main \ No newline at end of file diff --git a/Depth-Anything-V2/depth_anything_v2/__pycache__/__init__.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 331174322208ddb00d44610ff0885804f9ce1a02..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/__pycache__/dinov2.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/__pycache__/dinov2.cpython-311.pyc deleted file mode 100644 index 91760d31acd1fd33a9695dcd05e0a42d810f7430..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/__pycache__/dinov2.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/__pycache__/dpt.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/__pycache__/dpt.cpython-311.pyc deleted file mode 100644 index 40d719744697e8944789eed715802ed47927ae92..0000000000000000000000000000000000000000 Binary files 
a/Depth-Anything-V2/depth_anything_v2/__pycache__/dpt.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index fc5c5cd23fb4b3c637a476a2cea78a6d79645ab1..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-311.pyc deleted file mode 100644 index 976116ab539a70d37341a453840035dd33043920..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-311.pyc deleted file mode 100644 index 409eea18951b8ebd81eb20897dc4f56e188e9ac7..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/block.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-311.pyc deleted file mode 100644 index 1a4cd38ea36adb4f5602a3566dd7ef81d517503a..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-311.pyc deleted file 
mode 100644 index 2239b32a6d70d7e65afd16aa1b6f3dfac0167e18..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-311.pyc deleted file mode 100644 index c0d4cf93204d2d845a080ebfab1460c3f0e4291f..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-311.pyc deleted file mode 100644 index 356b018eeefad4a296d598229b32a7129ce996d4..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-311.pyc deleted file mode 100644 index e5113d95911ac050d9b66c398bb51ef9ea2f3653..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/util/__pycache__/__init__.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/util/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index d7a99479af8c4fc13ecb2233dc4d31f16f499b0e..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/util/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/util/__pycache__/blocks.cpython-311.pyc 
b/Depth-Anything-V2/depth_anything_v2/util/__pycache__/blocks.cpython-311.pyc deleted file mode 100644 index 8e79bc461d1bfefeb8d51764c089d20f27c2c7ab..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/util/__pycache__/blocks.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything-V2/depth_anything_v2/util/__pycache__/transform.cpython-311.pyc b/Depth-Anything-V2/depth_anything_v2/util/__pycache__/transform.cpython-311.pyc deleted file mode 100644 index 620a429c7bb92be1724ea92816321381fa7c2ddd..0000000000000000000000000000000000000000 Binary files a/Depth-Anything-V2/depth_anything_v2/util/__pycache__/transform.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/depth_anything/__pycache__/__init__.cpython-311.pyc b/Depth-Anything/depth_anything/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index deae576a582a344bd10572badaee6dfe61fbfbee..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/depth_anything/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/depth_anything/__pycache__/blocks.cpython-311.pyc b/Depth-Anything/depth_anything/__pycache__/blocks.cpython-311.pyc deleted file mode 100644 index 1068e97e13d0d544dd5ccc32fe09be19ff47d54e..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/depth_anything/__pycache__/blocks.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/depth_anything/__pycache__/dpt.cpython-311.pyc b/Depth-Anything/depth_anything/__pycache__/dpt.cpython-311.pyc deleted file mode 100644 index 60c519e230ae47f667b3c7c2569f251bf28438af..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/depth_anything/__pycache__/dpt.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/depth_anything/util/__pycache__/__init__.cpython-311.pyc b/Depth-Anything/depth_anything/util/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 
2c54a17c8c6b3d8b23b90f5ae59d392edd256568..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/depth_anything/util/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/depth_anything/util/__pycache__/transform.cpython-311.pyc b/Depth-Anything/depth_anything/util/__pycache__/transform.cpython-311.pyc deleted file mode 100644 index 495f2a760ad430fff7b49b669672917e6b4ae3c8..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/depth_anything/util/__pycache__/transform.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/__pycache__/hubconf.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/__pycache__/hubconf.cpython-311.pyc deleted file mode 100644 index cf434d2292f98d68ac70f2bb2144d617dc6e46ca..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/__pycache__/hubconf.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/__pycache__/vision_transformer.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/__pycache__/vision_transformer.cpython-311.pyc deleted file mode 100644 index 766fa76a2bca4fa3bdc0d46be09612f7a58d51a9..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/__pycache__/vision_transformer.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/__pycache__/__init__.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 7e9ae6f5a1173c18e7939175bab001782a4417da..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git 
a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/__init__.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/__init__.cpython-311.pyc deleted file mode 100644 index 441684318e9a2d471dea3f99adcdb0a2d243d3f0..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/__init__.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/attention.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/attention.cpython-311.pyc deleted file mode 100644 index 92ed68a9ec0aa4e14065e45e68694b48173b2b81..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/attention.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/block.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/block.cpython-311.pyc deleted file mode 100644 index 177ba0634a25350b15a78cf85dcde637752e9d14..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/block.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/dino_head.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/dino_head.cpython-311.pyc deleted file mode 100644 index e193d25c6d53a34b24d682502fe76905cd9ad909..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/dino_head.cpython-311.pyc and /dev/null differ diff --git 
a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/drop_path.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/drop_path.cpython-311.pyc deleted file mode 100644 index 4382da314b41e343ee7ca2fa4584c0a50209ff8e..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/drop_path.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/layer_scale.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/layer_scale.cpython-311.pyc deleted file mode 100644 index f74fe83fae00f4019d8a8ef6c6a2b3b050406663..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/layer_scale.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/mlp.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/mlp.cpython-311.pyc deleted file mode 100644 index 7d89c9f8957bb2c0c73eab19f7a109cba6cf9089..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/mlp.cpython-311.pyc and /dev/null differ diff --git a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/patch_embed.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/patch_embed.cpython-311.pyc deleted file mode 100644 index 2aabd6b53d4284580d40257c310a052d0f88472f..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/patch_embed.cpython-311.pyc and /dev/null differ diff --git 
a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/swiglu_ffn.cpython-311.pyc b/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/swiglu_ffn.cpython-311.pyc deleted file mode 100644 index 3a668349b99cacc5a93b286fe634818eabfa55b3..0000000000000000000000000000000000000000 Binary files a/Depth-Anything/torchhub/facebookresearch_dinov2_main/dinov2/layers/__pycache__/swiglu_ffn.cpython-311.pyc and /dev/null differ diff --git a/Pixel-Perfect-Depth/.gitattributes b/Pixel-Perfect-Depth/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..6635078645335e1d89526cd35a3ecf2cb69b903c --- /dev/null +++ b/Pixel-Perfect-Depth/.gitattributes @@ -0,0 +1,54 @@ +*.7z filter=lfs diff=lfs merge=lfs -text +*.arrow filter=lfs diff=lfs merge=lfs -text +*.bin filter=lfs diff=lfs merge=lfs -text +*.bz2 filter=lfs diff=lfs merge=lfs -text +*.ckpt filter=lfs diff=lfs merge=lfs -text +*.ftz filter=lfs diff=lfs merge=lfs -text +*.gz filter=lfs diff=lfs merge=lfs -text +*.h5 filter=lfs diff=lfs merge=lfs -text +*.joblib filter=lfs diff=lfs merge=lfs -text +*.lfs.* filter=lfs diff=lfs merge=lfs -text +*.mlmodel filter=lfs diff=lfs merge=lfs -text +*.model filter=lfs diff=lfs merge=lfs -text +*.msgpack filter=lfs diff=lfs merge=lfs -text +*.npy filter=lfs diff=lfs merge=lfs -text +*.npz filter=lfs diff=lfs merge=lfs -text +*.onnx filter=lfs diff=lfs merge=lfs -text +*.ot filter=lfs diff=lfs merge=lfs -text +*.parquet filter=lfs diff=lfs merge=lfs -text +*.pb filter=lfs diff=lfs merge=lfs -text +*.pickle filter=lfs diff=lfs merge=lfs -text +*.pkl filter=lfs diff=lfs merge=lfs -text +*.pt filter=lfs diff=lfs merge=lfs -text +*.pth filter=lfs diff=lfs merge=lfs -text +*.rar filter=lfs diff=lfs merge=lfs -text +*.safetensors filter=lfs diff=lfs merge=lfs -text +saved_model/**/* filter=lfs diff=lfs merge=lfs -text +*.tar.* filter=lfs diff=lfs merge=lfs -text +*.tar filter=lfs diff=lfs merge=lfs -text +*.tflite 
filter=lfs diff=lfs merge=lfs -text +*.tgz filter=lfs diff=lfs merge=lfs -text +*.wasm filter=lfs diff=lfs merge=lfs -text +*.xz filter=lfs diff=lfs merge=lfs -text +*.zip filter=lfs diff=lfs merge=lfs -text +*.zst filter=lfs diff=lfs merge=lfs -text +*tfevents* filter=lfs diff=lfs merge=lfs -text +assets/examples/0001.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0003.png filter=lfs diff=lfs merge=lfs -text +assets/examples/0004.png filter=lfs diff=lfs merge=lfs -text +assets/examples/0005.png filter=lfs diff=lfs merge=lfs -text +assets/examples/0006.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0007.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0008.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0009.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0010.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0004.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0005.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0011.jpg filter=lfs diff=lfs merge=lfs -text +assets/examples/0001.png filter=lfs diff=lfs merge=lfs -text +assets/examples/0002.png filter=lfs diff=lfs merge=lfs -text +assets/examples/0003.JPG filter=lfs diff=lfs merge=lfs -text +assets/examples/0006.PNG filter=lfs diff=lfs merge=lfs -text +assets/examples/0007.PNG filter=lfs diff=lfs merge=lfs -text +assets/examples/0008.PNG filter=lfs diff=lfs merge=lfs -text +assets/examples/0009.PNG filter=lfs diff=lfs merge=lfs -text diff --git a/Pixel-Perfect-Depth/app.py b/Pixel-Perfect-Depth/app.py new file mode 100644 index 0000000000000000000000000000000000000000..6b89c3f9e17f99da974b433054248989611d309f --- /dev/null +++ b/Pixel-Perfect-Depth/app.py @@ -0,0 +1,209 @@ +import gradio as gr +import cv2 +import matplotlib +import numpy as np +import os +import time +from PIL import Image +import torch +import torch.nn.functional as F +import open3d as o3d +import trimesh +import tempfile +import shutil +from pathlib import Path +from 
concurrent.futures import ThreadPoolExecutor +from gradio_imageslider import ImageSlider +from huggingface_hub import hf_hub_download + +from ppd.utils.set_seed import set_seed +from ppd.utils.align_depth_func import recover_metric_depth_ransac +from ppd.utils.depth2pcd import depth2pcd +from moge.model.v2 import MoGeModel +from ppd.models.ppd import PixelPerfectDepth + +try: + import spaces + HUGGINFACE_SPACES_INSTALLED = True +except ImportError: + HUGGINFACE_SPACES_INSTALLED = False + +css = """ +#img-display-container { + max-height: 100vh; +} +#img-display-input { + max-height: 100vh; +} +#img-display-output { + max-height: 100vh; +} +#download { + height: 62px; +} + +#img-display-output .image-slider-image { + object-fit: contain !important; + width: 100% !important; + height: 100% !important; +} +""" + +set_seed(666) + +DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") + +default_steps = 20 +model = PixelPerfectDepth(sampling_steps=default_steps) +ckpt_path = hf_hub_download( + repo_id="gangweix/Pixel-Perfect-Depth", + filename="ppd.pth", + repo_type="model" +) +state_dict = torch.load(ckpt_path, map_location="cpu") +model.load_state_dict(state_dict, strict=False) +model = model.eval() +model = model.to(DEVICE) + +moge_model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").eval() +moge_model = moge_model.to(DEVICE) + + +def main(share=True): + print("Initializing Pixel-Perfect Depth Demo...") + + cmap = matplotlib.colormaps.get_cmap('Spectral') + + title = "# Pixel-Perfect Depth" + description = """Official demo for **Pixel-Perfect Depth**. 
+ Please refer to our [paper](https://arxiv.org/pdf/2510.07316), [project page](https://pixel-perfect-depth.github.io), and [github](https://github.com/gangweix/pixel-perfect-depth) for more details.""" + + @(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else (lambda x: x)) + def predict_depth(image, denoise_steps): + depth, resize_image = model.infer_image(image, sampling_steps=denoise_steps) + return depth, resize_image + + @(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else (lambda x: x)) + def predict_moge_depth(image): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = torch.tensor(image / 255, dtype=torch.float32, device=DEVICE).permute(2, 0, 1) + metric_depth, mask, intrinsics = moge_model.infer(image) + metric_depth[~mask] = metric_depth[mask].max() + return metric_depth, mask, intrinsics + + def on_submit(image, denoise_steps, apply_filter, request: gr.Request = None): + + H, W = image.shape[:2] + ppd_depth, resize_image = predict_depth(image[:, :, ::-1], denoise_steps) + resize_H, resize_W = resize_image.shape[:2] + + # moge provide metric depth and intrinsics + moge_depth, mask, intrinsics = predict_moge_depth(resize_image) + + # relative depth -> metric depth + metric_depth = recover_metric_depth_ransac(ppd_depth, moge_depth, mask) + intrinsics[0, 0] *= resize_W + intrinsics[1, 1] *= resize_H + intrinsics[0, 2] *= resize_W + intrinsics[1, 2] *= resize_H + + # metric depth -> point cloud + pcd = depth2pcd(metric_depth, intrinsics, color=cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB), input_mask=mask, ret_pcd=True) + if apply_filter: + cl, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0) + pcd = pcd.select_by_index(ind) + + tempdir = Path(tempfile.gettempdir(), 'ppd') + tempdir.mkdir(exist_ok=True) + output_path = Path(tempdir, request.session_hash) + shutil.rmtree(output_path, ignore_errors=True) + output_path.mkdir(exist_ok=True, parents=True) + + ply_path = os.path.join(output_path, 'pointcloud.ply') + + # save pcd to temporary .ply 
+ pcd.points = o3d.utility.Vector3dVector( + np.asarray(pcd.points) * np.array([1, -1, -1], dtype=np.float32) + ) + o3d.io.write_point_cloud(ply_path, pcd) + vertices = np.asarray(pcd.points) + vertex_colors = (np.asarray(pcd.colors) * 255).astype(np.uint8) + mesh = trimesh.PointCloud(vertices=vertices, colors=vertex_colors) + glb_path = os.path.join(output_path, 'pointcloud.glb') + mesh.export(glb_path) + + + # save raw depth (npy) + depth = cv2.resize(ppd_depth, (W, H), interpolation=cv2.INTER_LINEAR) + raw_depth_path = os.path.join(output_path, 'raw_depth.npy') + np.save(raw_depth_path, depth) + + depth_vis = (depth - depth.min()) / (depth.max() - depth.min() + 1e-5) * 255.0 + depth_vis = depth_vis.astype(np.uint8) + colored_depth = (cmap(depth_vis)[:, :, :3] * 255).astype(np.uint8) + + split_region = np.ones((image.shape[0], 50, 3), dtype=np.uint8) * 255 + combined_result = cv2.hconcat([image[:, :, ::-1], split_region, colored_depth[:, :, ::-1]]) + + vis_path = os.path.join(output_path, 'image_depth_vis.png') + cv2.imwrite(vis_path, combined_result) + + file_names = ["image_depth_vis.png", "raw_depth.npy", "pointcloud.ply"] + + download_files = [ + (output_path / name).as_posix() + for name in file_names + if (output_path / name).exists() + ] + + return [(image, colored_depth), glb_path, download_files] + + + with gr.Blocks(theme=gr.themes.Soft()) as demo: + gr.Markdown(title) + gr.Markdown(description) + gr.Markdown("### Point Cloud & Depth Prediction demo") + + with gr.Row(): + # Left: input image + settings + with gr.Column(): + input_image = gr.Image(label="Input Image", image_mode="RGB", type='numpy', elem_id='img-display-input') + with gr.Accordion(label="Settings", open=False): + denoise_steps = gr.Slider(label="Denoising Steps", minimum=1, maximum=100, value=20, step=1) + apply_filter = gr.Checkbox(label="Apply filter points", value=True) + submit_btn = gr.Button(value="Predict") + + # Right: 3D point cloud + depth + with gr.Column(): + with gr.Tabs(): 
+ with gr.Tab("3D View"): + model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1,1,1,1], height="60vh") + with gr.Tab("Depth"): + depth_map = ImageSlider(label="Depth Map with Slider View", elem_id='img-display-output', position=0.5) + with gr.Tab("Download"): + download_files = gr.File(type='filepath', label="Download Files") + + submit_btn.click( + fn=lambda: [None, None, None, "", "", ""], + outputs=[depth_map, model_3d, download_files] + ).then( + fn=on_submit, + inputs=[input_image, denoise_steps, apply_filter], + outputs=[depth_map, model_3d, download_files] + ) + + example_files = os.listdir('assets/examples') + example_files.sort() + example_files = [os.path.join('assets/examples', filename) for filename in example_files] + examples = gr.Examples( + examples=example_files, + inputs=input_image, + outputs=[depth_map, model_3d, download_files], + fn=on_submit, + cache_examples=False + ) + + demo.queue().launch(share=share) + +if __name__ == '__main__': + main(share=True) \ No newline at end of file diff --git a/Pixel-Perfect-Depth/assets/examples/0001.jpg b/Pixel-Perfect-Depth/assets/examples/0001.jpg new file mode 100644 index 0000000000000000000000000000000000000000..380f3404dafe29390cb62937af57a77d8484a293 --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0001.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4205842dfc133d8e76625ded7c31d3a2a2b8f9500919a0f4ecedc32a9bac87be +size 249132 diff --git a/Pixel-Perfect-Depth/assets/examples/0002.png b/Pixel-Perfect-Depth/assets/examples/0002.png new file mode 100644 index 0000000000000000000000000000000000000000..41e5939ffe3ab53c50606024d62ac3cf407288dd --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0002.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d97f69094e48a27cb72ff7be5f7ddcde1eb4da31aee237867cabf1ea2abd5310 +size 1235741 diff --git a/Pixel-Perfect-Depth/assets/examples/0003.JPG 
b/Pixel-Perfect-Depth/assets/examples/0003.JPG new file mode 100644 index 0000000000000000000000000000000000000000..2bbe802a564c368e92d35ba1808c572b2a73c76d --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0003.JPG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3a194a4e6d6ca4ff07b51baba841ee775095a1d034dd67d24ed45e6da5928fb3 +size 9676202 diff --git a/Pixel-Perfect-Depth/assets/examples/0004.png b/Pixel-Perfect-Depth/assets/examples/0004.png new file mode 100644 index 0000000000000000000000000000000000000000..8af4fa0a2d01a0d6cbf47f23cbe7071402de9981 --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0004.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0d8cf934034b05e01d612452c5b4ae4381baf3929ae1944df09f614e5cbdb0d4 +size 489263 diff --git a/Pixel-Perfect-Depth/assets/examples/0005.jpg b/Pixel-Perfect-Depth/assets/examples/0005.jpg new file mode 100644 index 0000000000000000000000000000000000000000..a237edded6e4c4516e249ca7addd1c73467c2849 --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0005.jpg @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:eeebeb413ec78384dd5942992b5699ff281c6cc50a157e522a3b289e30d0b567 +size 102649 diff --git a/Pixel-Perfect-Depth/assets/examples/0006.PNG b/Pixel-Perfect-Depth/assets/examples/0006.PNG new file mode 100644 index 0000000000000000000000000000000000000000..464a945f479b0d947f955e3adada2e0d411025a5 --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0006.PNG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b8967fdf774e36ba0c27deb104d0afd72d19accf590a5d21fb5ccc2726912eea +size 2594227 diff --git a/Pixel-Perfect-Depth/assets/examples/0007.PNG b/Pixel-Perfect-Depth/assets/examples/0007.PNG new file mode 100644 index 0000000000000000000000000000000000000000..c715c785a868ce520979d8aa1343a4a9b50d5e22 --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0007.PNG @@ -0,0 +1,3 @@ +version 
https://git-lfs.github.com/spec/v1 +oid sha256:74c959aabb376bf2541f9bd4b88028e5ff7321a61cff03b2223c802c950537f6 +size 2601977 diff --git a/Pixel-Perfect-Depth/assets/examples/0008.PNG b/Pixel-Perfect-Depth/assets/examples/0008.PNG new file mode 100644 index 0000000000000000000000000000000000000000..f5e29e39bf5d88b4300d6bdf906f226e5111ed04 --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0008.PNG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:668d2323220211f7c318b47668865689c4ff7ff8f9daec0dde0ca923dee47095 +size 2947558 diff --git a/Pixel-Perfect-Depth/assets/examples/0009.PNG b/Pixel-Perfect-Depth/assets/examples/0009.PNG new file mode 100644 index 0000000000000000000000000000000000000000..c1b6dac1ddfdd7b4f3c8faf4113922220ac4ce6c --- /dev/null +++ b/Pixel-Perfect-Depth/assets/examples/0009.PNG @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:5489c9f5b1eb3d856ffd3225ebb8756eb935c92b2ea0f87f4bf00f4dc45c0336 +size 2383574 diff --git a/Pixel-Perfect-Depth/moge/__init__.py b/Pixel-Perfect-Depth/moge/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Pixel-Perfect-Depth/moge/model/__init__.py b/Pixel-Perfect-Depth/moge/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c919e3be42c0005752e8c800129bd5f724b47ff9 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/__init__.py @@ -0,0 +1,18 @@ +import importlib +from typing import * + +if TYPE_CHECKING: + from .v1 import MoGeModel as MoGeModelV1 + from .v2 import MoGeModel as MoGeModelV2 + + +def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]: + assert version in ['v1', 'v2'], f'Unsupported model version: {version}' + + try: + module = importlib.import_module(f'.{version}', __package__) + except ModuleNotFoundError: + raise ValueError(f'Model version "{version}" not found.') + + cls = getattr(module, 'MoGeModel') 
+ return cls diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/__init__.py b/Pixel-Perfect-Depth/moge/model/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/hub/__init__.py b/Pixel-Perfect-Depth/moge/model/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/hub/backbones.py b/Pixel-Perfect-Depth/moge/model/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
class Weights(Enum):
    """Pretraining datasets for which official DINOv2 checkpoints exist."""

    LVD142M = "LVD142M"


# Keyword arguments shared by every "*_reg" (register-token) model variant.
_REGISTER_VARIANT_KWARGS = dict(
    num_register_tokens=4,
    interpolate_antialias=True,
    interpolate_offset=0.0,
)


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Build a DINOv2 ViT backbone and optionally load pretrained weights.

    Args:
        arch_name: key into ``vision_transformer.__dict__`` (e.g. ``"vit_large"``).
        pretrained: when True, download the matching checkpoint and load it strictly.
        weights: dataset tag; a string is resolved through the ``Weights`` enum.
        **kwargs: forwarded verbatim to the vision-transformer constructor.

    Raises:
        AssertionError: if ``weights`` is a string with no ``Weights`` member.
    """
    from ..models import vision_transformer as vits

    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    # Checkpoint directory is named after the register-free model name,
    # while the file itself carries the register suffix.
    base_name = _make_dinov2_model_name(arch_name, patch_size)
    model_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    model_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**model_kwargs)

    if pretrained:
        full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        checkpoint_url = _DINOV2_BASE_URL + f"/{base_name}/{full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(checkpoint_url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",  # giant variant uses the fused SwiGLU FFN
        pretrained=pretrained,
        weights=weights,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14 model with registers (optionally) pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        **_REGISTER_VARIANT_KWARGS,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14 model with registers (optionally) pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        **_REGISTER_VARIANT_KWARGS,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14 model with registers (optionally) pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        **_REGISTER_VARIANT_KWARGS,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 model with registers (optionally) pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        pretrained=pretrained,
        weights=weights,
        **_REGISTER_VARIANT_KWARGS,
        **kwargs,
    )
+ +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/layers/__init__.py b/Pixel-Perfect-Depth/moge/model/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
logger = logging.getLogger("dinov2")


# xFormers is optional; set XFORMERS_DISABLED in the environment to force the
# pure-PyTorch path even when the package is installed.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import memory_efficient_attention, unbind

        XFORMERS_AVAILABLE = True
    else:
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    """Multi-head self-attention backed by ``F.scaled_dot_product_attention``.

    Projects the input to fused QKV, runs SDPA per head, then applies the
    output projection and dropout. ``attn_bias`` is forwarded as the SDPA
    attention mask.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        batch, tokens, channels = x.shape
        head_dim = channels // self.num_heads
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, head_dim)
        # (3, B, H, N, C//H) -> three (B, H, N, C//H) tensors
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)
        out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_bias)
        out = out.transpose(1, 2).reshape(batch, tokens, channels)
        return self.proj_drop(self.proj(out))


class MemEffAttention(Attention):
    """Attention that uses xFormers' memory-efficient kernel when available.

    Falls back to the SDPA base class when xFormers is missing; in that case a
    non-None ``attn_bias`` (nested-tensor path) is a hard error.
    """

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        batch, tokens, channels = x.shape
        qkv = self.qkv(x).reshape(batch, tokens, 3, self.num_heads, channels // self.num_heads)
        q, k, v = unbind(qkv, 2)  # xFormers layout: (B, N, H, C//H)
        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, tokens, channels])
        return self.proj_drop(self.proj(out))
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


# xFormers is optional: when importable (and not disabled via XFORMERS_DISABLED)
# the list/nested-tensor fast paths below become usable.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import fmha, scaled_index_add, index_select_cat

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (Block)")
    else:
        # warnings.warn("xFormers is disabled (Block)")
        raise ImportError
except ImportError:
    XFORMERS_AVAILABLE = False
    # warnings.warn("xFormers is not available (Block)")


class Block(nn.Module):
    """Pre-norm transformer block: x + ls1(attn(norm1(x))), then x + ls2(mlp(norm2(x))).

    Supports LayerScale (enabled when ``init_values`` is truthy), per-sample
    DropPath, and a batched stochastic-depth fast path that subsamples the
    batch when training with a drop-path rate above 0.1.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # init_values=None or 0 disables LayerScale (identity).
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # Threshold used by forward() to pick the batched stochastic-depth path.
        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            # Inference / no drop-path: plain residual additions.
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Stochastic depth that runs ``residual_func`` on a random batch subset only.

    Instead of masking outputs per sample, it computes the residual for a
    random subset of the batch and adds it back (scaled by b/subset) with
    ``index_add`` — cheaper than a full forward at high drop rates.
    """
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    # Rescale so the expected residual magnitude matches the full-batch case.
    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return (batch indices to keep, residual rescale factor) for one tensor."""
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add ``residual`` into rows ``brange`` of ``x``.

    With a LayerScale ``scaling_vector`` the fused xFormers ``scaled_index_add``
    kernel is used; otherwise a plain ``index_add`` on the flattened tensors.
    """
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


# Cache of BlockDiagonalMask keyed by the tuple of (batch, seqlen) shapes.
# NOTE(review): module-level and unbounded — grows with every distinct shape
# combination seen; presumably fine for a fixed training setup.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Select the kept rows from each tensor and concatenate into one
        # (1, total_tokens, d) sequence matching the block-diagonal mask.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """List variant of stochastic depth: one fused forward over all tensors."""
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    """Block that can also process a list of tensors as one fused sequence.

    Requires ``MemEffAttention`` (xFormers) so the concatenated sequence can be
    attended with a block-diagonal bias.
    """

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            # LayerScale is folded into add_residual via scaling_vector here,
            # which is why the residual funcs above skip ls1/ls2.
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # NOTE(review): condition checks ls1 but reads ls2.gamma —
                # harmless when both are LayerScale or both Identity (the only
                # configurations __init__ produces), but confirm upstream intent.
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # Dispatch: plain tensor -> standard Block path; list -> nested path.
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
+ +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/layers/drop_path.py b/Pixel-Perfect-Depth/moge/model/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero entire samples of ``x`` with probability ``drop_prob``.

    Surviving samples are rescaled by ``1/keep_prob`` so the expectation is
    unchanged. Identity when ``drop_prob == 0`` or not training.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # Mask has shape (batch, 1, 1, ...) so it broadcasts over all non-batch dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        random_tensor.div_(keep_prob)
    output = x * random_tensor
    return output


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).

    Fix: ``drop_prob`` defaults to ``0.0`` (identity) instead of ``None``;
    the previous ``None`` default raised ``TypeError`` (``1 - None``) the
    first time the module ran in training mode.
    """

    def __init__(self, drop_prob: float = 0.0):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
class LayerScale(nn.Module):
    """Learnable per-channel scaling (LayerScale, from CaiT).

    Multiplies the input by a trainable vector ``gamma`` initialized to
    ``init_values``; with ``inplace=True`` the input tensor is scaled in place.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class Mlp(nn.Module):
    """Standard transformer feed-forward block: fc1 -> act -> drop -> fc2 -> drop.

    ``hidden_features`` and ``out_features`` default to ``in_features`` when
    not given; one shared Dropout is applied after both linear layers.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
def make_2tuple(x):
    """Promote an int to an (x, x) pair; pass 2-tuples through unchanged."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x
    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: when False, output stays (B, H', W', D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()
        image_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid_hw = (image_hw[0] // patch_hw[0], image_hw[1] // patch_hw[1])

        self.img_size = image_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping patches: conv kernel == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_h, patch_w = self.patch_size
        assert H % patch_h == 0, f"Input image height {H} is not a multiple of patch height {patch_h}"
        assert W % patch_w == 0, f"Input image width {W} is not a multiple of patch width: {patch_w}"

        x = self.proj(x)  # B C H' W'
        H, W = x.size(2), x.size(3)
        x = self.norm(x.flatten(2).transpose(1, 2))  # B H'W' C
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H' W' C
        return x

    def flops(self) -> float:
        """Floating-point operations for one forward pass (projection + norm)."""
        grid_h, grid_w = self.patches_resolution
        flops = grid_h * grid_w * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += grid_h * grid_w * self.embed_dim
        return flops


class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: w3(silu(x1) * x2), with x1/x2 from one fused projection."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # Single projection producing both the gate and the value halves.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN with the conventional hidden-size adjustment.

    The requested hidden width is shrunk to 2/3 (so the gated FFN's parameter
    count matches a plain MLP of the same nominal width) and rounded up to a
    multiple of 8 for hardware-friendly shapes. Delegates to ``SwiGLU``
    (xFormers' fused op when available, else the pure-PyTorch ``SwiGLUFFN``).
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        shrunk = int(hidden_features * 2 / 3)
        hidden_features = ((shrunk + 7) // 8) * 8  # round up to a multiple of 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/models/vision_transformer.py b/Pixel-Perfect-Depth/moge/model/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..f0bed9d0b7cdcff2b5e129121251c58e41c4c61d --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/models/vision_transformer.py @@ -0,0 +1,407 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable, Optional, List + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp 
hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # 
stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + @property + def onnx_compatible_mode(self): + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, h, w): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + batch_size = x.shape[0] + N = self.pos_embed.shape[1] - 1 + if not 
self.onnx_compatible_mode and npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0, :] + patch_pos_embed = pos_embed[:, 1:, :] + dim = x.shape[-1] + h0, w0 = h // self.patch_size, w // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if not self.onnx_compatible_mode and self.interpolate_offset > 0: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sy, sx) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (h0, w0) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + + assert (h0, w0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2) + return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, h, w = x.shape + x = self.patch_embed(x) + + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, h, w) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, 
masks, ar in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. 
If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + 
**kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/utils/__init__.py b/Pixel-Perfect-Depth/moge/model/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/utils/cluster.py b/Pixel-Perfect-Depth/moge/model/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, 
num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/utils/config.py b/Pixel-Perfect-Depth/moge/model/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
+ +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. 
+ """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/utils/dtype.py b/Pixel-Perfect-Depth/moge/model/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/Pixel-Perfect-Depth/moge/model/dinov2/utils/param_groups.py b/Pixel-Perfect-Depth/moge/model/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. 
+# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." 
not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git 
a/Pixel-Perfect-Depth/moge/model/dinov2/utils/utils.py b/Pixel-Perfect-Depth/moge/model/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. 
+ """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/Pixel-Perfect-Depth/moge/model/modules.py b/Pixel-Perfect-Depth/moge/model/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b36ad48d40a8715da375eb15c74416f34f4f9c04 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/modules.py @@ -0,0 +1,254 @@ +from 
typing import * +from numbers import Number +import importlib +import itertools +import functools +import sys + +import torch +from torch import Tensor +import torch.nn as nn +import torch.nn.functional as F + +from .dinov2.models.vision_transformer import DinoVisionTransformer +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from ..utils.geometry_torch import normalized_view_plane_uv + + +class ResidualConvBlock(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int = None, + hidden_channels: int = None, + kernel_size: int = 3, + padding_mode: str = 'replicate', + activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', + in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm', + hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm', + ): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation =='relu': + activation_cls = nn.ReLU + elif activation == 'leaky_relu': + activation_cls = functools.partial(nn.LeakyReLU, negative_slope=0.2) + elif activation =='silu': + activation_cls = nn.SiLU + elif activation == 'elu': + activation_cls = nn.ELU + else: + raise ValueError(f'Unsupported activation function: {activation}') + + self.layers = nn.Sequential( + nn.GroupNorm(in_channels // 32, in_channels) if in_norm == 'group_norm' else \ + nn.GroupNorm(1, in_channels) if in_norm == 'layer_norm' else \ + nn.InstanceNorm2d(in_channels) if in_norm == 'instance_norm' else \ + nn.Identity(), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32, hidden_channels) if hidden_norm == 'group_norm' else \ + nn.GroupNorm(1, hidden_channels) if hidden_norm == 'layer_norm' 
class DINOv2Encoder(nn.Module):
    """Wrapped DINOv2 encoder supporting gradient checkpointing.

    Input is an RGB image with values in [0, 1]; it is resized to an exact
    multiple of the ViT patch size (14) and normalized with ImageNet
    statistics before being fed to the backbone.
    """
    backbone: DinoVisionTransformer
    image_mean: torch.Tensor
    image_std: torch.Tensor
    dim_features: int

    def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs):
        super(DINOv2Encoder, self).__init__()

        self.intermediate_layers = intermediate_layers

        # Load the backbone constructor by name from the vendored DINOv2 hub module.
        self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone)
        self.backbone_name = backbone
        self.backbone = self.hub_loader(pretrained=False)

        # Feature dimensionality is inferred from the first attention block's qkv projection.
        self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
        self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)

        # One 1x1 projection per intermediate feature map; projected maps are summed in forward().
        self.output_projections = nn.ModuleList([
            nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0)
            for _ in range(self.num_features)
        ])

        # ImageNet normalization statistics as buffers so they follow the module's device/dtype.
        self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @property
    def onnx_compatible_mode(self):
        # Default False when the flag was never set.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value
        self.backbone.onnx_compatible_mode = value

    def init_weights(self):
        "Load pretrained DINOv2 weights into the backbone."
        pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
        self.backbone.load_state_dict(pretrained_backbone_state_dict)

    def enable_gradient_checkpointing(self):
        for i in range(len(self.backbone.blocks)):
            wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])

    def enable_pytorch_native_sdpa(self):
        for i in range(len(self.backbone.blocks)):
            wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)

    def forward(self, image: torch.Tensor, token_rows: Union[int, torch.LongTensor], token_cols: Union[int, torch.LongTensor], return_class_token: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """Encode `image` into a (B, dim_out, token_rows, token_cols) feature map.

        Returns the feature map alone, or `(features, class_token)` when
        `return_class_token` is True.
        FIX: the previous return annotation unconditionally claimed a tuple,
        which was wrong for the default `return_class_token=False` path.
        """
        # Resize so the spatial size is an exact multiple of the 14px patch.
        # Antialiasing is disabled in ONNX-compatible mode (op not exportable).
        image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
        image_14 = (image_14 - self.image_mean) / self.image_std

        # Get intermediate layers from the backbone
        features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True)

        # Project each intermediate feature map to dim_out and sum them.
        x = torch.stack([
            proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
            for proj, (feat, clstoken) in zip(self.output_projections, features)
        ], dim=1).sum(dim=1)

        if return_class_token:
            return x, features[-1][1]
        else:
            return x
= self[0].weight.data[0::scale_factor ** 2] + self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2] + elif type_ in ['nearest', 'bilinear']: + nn.Sequential.__init__(self, + nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None), + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + elif type_ == 'conv_transpose': + nn.Sequential.__init__(self, + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1] + elif type_ == 'pixel_unshuffle': + nn.Sequential.__init__(self, + nn.PixelUnshuffle(scale_factor), + nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + elif type_ == 'avg_pool': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + elif type_ == 'max_pool': + nn.Sequential.__init__(self, + nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor), + ) + else: + raise ValueError(f'Unsupported resampler type: {type_}') + +class MLP(nn.Sequential): + def __init__(self, dims: Sequence[int]): + nn.Sequential.__init__(self, + *itertools.chain(*[ + (nn.Linear(dim_in, dim_out), nn.ReLU(inplace=True)) + for dim_in, dim_out in zip(dims[:-2], dims[1:-1]) + ]), + nn.Linear(dims[-2], dims[-1]), + ) + + +class ConvStack(nn.Module): + def __init__(self, + dim_in: List[Optional[int]], + dim_res_blocks: List[int], + dim_out: List[Optional[int]], + resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 
class ConvStack(nn.Module):
    """Multi-resolution convolutional stack.

    Processes a pyramid of feature maps coarse-to-fine: at each level the
    (optionally projected) input is added to the running state, passed through
    a chain of residual blocks, projected to the level's output, and then
    resampled (x2) into the next level. Scalar config values (`dim_in`,
    `dim_out`, `resamplers`, `num_res_blocks`) are broadcast to all levels.
    """

    def __init__(self,
        dim_in: List[Optional[int]],
        dim_res_blocks: List[int],
        dim_out: List[Optional[int]],
        resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_in_norm: Literal['layer_norm', 'group_norm', 'instance_norm', 'none'] = 'layer_norm',
        res_block_hidden_norm: Literal['layer_norm', 'group_norm', 'instance_norm', 'none'] = 'group_norm',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
    ):
        super().__init__()

        def _per_level(value):
            # Broadcast a scalar config value over all resolution levels.
            return value if isinstance(value, Sequence) else itertools.repeat(value)

        # 1x1 input projection per level; Identity when no input dim is declared.
        self.input_blocks = nn.ModuleList([
            nn.Identity() if d_in is None else nn.Conv2d(d_in, d_mid, kernel_size=1, stride=1, padding=0)
            for d_in, d_mid in zip(_per_level(dim_in), dim_res_blocks)
        ])
        # One resampler between each pair of consecutive levels.
        self.resamplers = nn.ModuleList([
            Resampler(d_prev, d_next, scale_factor=2, type_=kind)
            for d_prev, d_next, kind in zip(dim_res_blocks[:-1], dim_res_blocks[1:], _per_level(resamplers))
        ])
        # A chain of residual blocks per level; count may vary per level.
        self.res_blocks = nn.ModuleList([
            nn.Sequential(*(
                ResidualConvBlock(
                    d_mid, d_mid, dim_times_res_block_hidden * d_mid,
                    activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
                )
                for _ in range(num_res_blocks[level] if isinstance(num_res_blocks, list) else num_res_blocks)
            ))
            for level, d_mid in enumerate(dim_res_blocks)
        ])
        # 1x1 output projection per level; Identity when no output dim is declared.
        self.output_blocks = nn.ModuleList([
            nn.Identity() if d_out is None else nn.Conv2d(d_mid, d_out, kernel_size=1, stride=1, padding=0)
            for d_out, d_mid in zip(_per_level(dim_out), dim_res_blocks)
        ])

    def enable_gradient_checkpointing(self):
        for i in range(len(self.resamplers)):
            self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
        for i in range(len(self.res_blocks)):
            for j in range(len(self.res_blocks[i])):
                self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])

    def forward(self, in_features: List[torch.Tensor]):
        out_features = []
        x = None
        for level, (in_block, blocks, out_block) in enumerate(zip(self.input_blocks, self.res_blocks, self.output_blocks)):
            feature = in_block(in_features[level])
            if level == 0:
                x = feature
            elif feature is not None:
                # Fuse this level's (projected) input into the upsampled state.
                x = x + feature
            x = blocks(x)
            out_features.append(out_block(x))
            if level + 1 < len(self.res_blocks):
                x = self.resamplers[level](x)
        return out_features
+def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]: + group_to_use = torch.distributed.group.WORLD + world_size = group_to_use.size() + grad = bucket.buffer() + grad.div_(world_size) + torch.distributed.all_reduce(grad, group=group_to_use) + fut = torch.futures.Future() + fut.set_result(grad) + return fut diff --git a/Pixel-Perfect-Depth/moge/model/v1.py b/Pixel-Perfect-Depth/moge/model/v1.py new file mode 100644 index 0000000000000000000000000000000000000000..1c14cc7ab3e03e9eed310fd547fc85d9e2a6ad9e --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/v1.py @@ -0,0 +1,392 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import importlib +import warnings +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +import utils3d +from huggingface_hub import hf_hub_download + + +from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from ..utils.tools import timeit + + +class ResidualConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation =='relu': + activation_cls = lambda: nn.ReLU(inplace=True) + elif activation == 'leaky_relu': + activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) + elif activation =='silu': + activation_cls = lambda: 
nn.SiLU(inplace=True) + elif activation == 'elu': + activation_cls = lambda: nn.ELU(inplace=True) + else: + raise ValueError(f'Unsupported activation function: {activation}') + + self.layers = nn.Sequential( + nn.GroupNorm(1, in_channels), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels), + activation_cls(), + nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode) + ) + + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class Head(nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1 + ): + super().__init__() + + self.projects = nn.ModuleList([ + nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features) + ]) + + self.upsample_blocks = nn.ModuleList([ + nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks)) + ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ]) + + self.output_block = nn.ModuleList([ + self._make_output_block( + dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm, + ) for dim_out_ in dim_out + ]) + + def _make_upsampler(self, 
class Head(nn.Module):
    """Dense prediction head for MoGe v1.

    Fuses the backbone's intermediate feature maps via 1x1 projections, runs
    three x2 upsampling stages (each fed UV coordinates for aspect-ratio
    awareness), then one output branch per entry of `dim_out`.
    """

    def __init__(
        self,
        num_features: int,
        dim_in: int,
        dim_out: List[int],
        dim_proj: int = 512,
        dim_upsample: Sequence[int] = (256, 128, 128),  # FIX: tuple instead of a mutable list default
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1
    ):
        super().__init__()
        dim_upsample = list(dim_upsample)

        # One 1x1 projection per intermediate feature; outputs are summed in forward().
        self.projects = nn.ModuleList([
            nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0) for _ in range(num_features)
        ])

        # Each stage: x2 upsampler followed by residual blocks. "+ 2" input
        # channels account for the concatenated UV coordinate map.
        self.upsample_blocks = nn.ModuleList([
            nn.Sequential(
                self._make_upsampler(in_ch + 2, out_ch),
                *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
            ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
        ])

        # One output branch per requested output dimensionality (e.g. points, mask).
        self.output_block = nn.ModuleList([
            self._make_output_block(
                dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
            ) for dim_out_ in dim_out
        ])

    def _make_upsampler(self, in_channels: int, out_channels: int):
        "x2 transposed-conv upsampler initialized to behave like nearest-neighbor upsampling."
        upsampler = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
        )
        # Broadcast the top-left kernel entry across the 2x2 kernel.
        upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
        return upsampler

    def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
        "Final branch: 3x3 conv, optional residual blocks, ReLU, then the output conv."
        return nn.Sequential(
            nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
            *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
            nn.ReLU(inplace=True),
            nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
        )

    def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
        img_h, img_w = image.shape[-2:]
        patch_h, patch_w = img_h // 14, img_w // 14

        # Project and sum the intermediate hidden states into one feature map.
        x = torch.stack([
            proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
            for proj, (feat, clstoken) in zip(self.projects, hidden_states)
        ], dim=1).sum(dim=1)

        # Upsample stage: (patch_h, patch_w) -> x2 -> x4 -> x8.
        for i, block in enumerate(self.upsample_blocks):
            # UV coordinates make the head aware of the image aspect ratio.
            uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
            x = torch.cat([x, uv], dim=1)
            # Checkpoint layer-by-layer to bound activation memory.
            for layer in block:
                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)

        # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
        x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
        uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
        uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
        x = torch.cat([x, uv], dim=1)

        if isinstance(self.output_block, nn.ModuleList):
            output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
        else:
            output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)

        return output
+ # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues. + hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) + self.backbone = hub_loader(pretrained=False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + self.head = Head( + num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers), + dim_in=dim_feature, + dim_out=[3, 1], + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size + ) + + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel': + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. 
+ """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True) + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True) + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model']) + return model + + def init_weights(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == 'linear': + pass + elif self.remap_output =='sinh': + points = torch.sinh(points) + elif self.remap_output == 'exp': + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output =='sinh_exp': + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + original_height, original_width = image.shape[-2:] + + # Resize to expected resolution defined by num_tokens + resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5 + resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor) + image = F.interpolate(image, 
(resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True) + + # Apply image transformation for DINOv2 + image = (image - self.image_mean) / self.image_std + image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True) + + # Get intermediate layers from the backbone + features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) + + # Predict points (and mask) + output = self.head(features, image) + points, mask = output + + # Make sure fp32 precision for output + with torch.autocast(device_type=image.device.type, dtype=torch.float32): + # Resize to original resolution + points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False) + mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False) + + # Post-process points and mask + points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1) + points = self._remap_points(points) # slightly improves the performance in case of very large output values + + return_dict = {'points': points, 'mask': mask} + return return_dict + + @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + fov_x: Union[Number, torch.Tensor] = None, + resolution_level: int = 9, + num_tokens: int = None, + apply_mask: bool = True, + force_projection: bool = True, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)\ + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `resolution_level`: An integer [0-9] for the resolution level for inference. + The higher, the finer details will be captured, but slower. Defaults to 9. 
Note that it is irrelevant to the output size, which is always the same as the input size. + `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details. + - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. + `resolution_level` will be ignored if `num_tokens` is provided. Default: None + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. + """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + aspect_ratio = original_width / original_height + + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16): + output = self.forward(image, num_tokens) + points, mask = output['points'], output['mask'] + + # Always process the output in fp32 precision + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x]) + + mask_binary = mask > self.mask_threshold + + # Get camera-space point map. 
(Focal here is the focal length relative to half the image diagonal) + if fov_x is None: + focal, shift = recover_focal_shift(points, mask_binary) + else: + focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2)) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio + fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 + intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + depth = points[..., 2] + shift[..., None, None] + + # If projection constraint is forced, recompute the point map using the actual depth map + if force_projection: + points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics) + else: + points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :] + + # Apply mask if needed + if apply_mask: + points = torch.where(mask_binary[..., None], points, torch.inf) + depth = torch.where(mask_binary, depth, torch.inf) + + return_dict = { + 'points': points, + 'intrinsics': intrinsics, + 'depth': depth, + 'mask': mask_binary, + } + + if omit_batch_dim: + return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + return return_dict \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/model/v2.py b/Pixel-Perfect-Depth/moge/model/v2.py new file mode 100644 index 0000000000000000000000000000000000000000..6b69e47951f1b349df6f72fa401ca2eecfdbe81b --- /dev/null +++ b/Pixel-Perfect-Depth/moge/model/v2.py @@ -0,0 +1,303 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import warnings + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.amp +import torch.version +import utils3d +from 
huggingface_hub import hf_hub_download + +from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3 +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from .modules import DINOv2Encoder, MLP, ConvStack + + +class MoGeModel(nn.Module): + encoder: DINOv2Encoder + neck: ConvStack + points_head: ConvStack + mask_head: ConvStack + scale_head: MLP + onnx_compatible_mode: bool + + def __init__(self, + encoder: Dict[str, Any], + neck: Dict[str, Any], + points_head: Dict[str, Any] = None, + mask_head: Dict[str, Any] = None, + normal_head: Dict[str, Any] = None, + scale_head: Dict[str, Any] = None, + remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear', + num_tokens_range: List[int] = [1200, 3600], + **deprecated_kwargs + ): + super(MoGeModel, self).__init__() + if deprecated_kwargs: + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.remap_output = remap_output + self.num_tokens_range = num_tokens_range + + self.encoder = DINOv2Encoder(**encoder) + self.neck = ConvStack(**neck) + if points_head is not None: + self.points_head = ConvStack(**points_head) + if mask_head is not None: + self.mask_head = ConvStack(**mask_head) + if normal_head is not None: + self.normal_head = ConvStack(**normal_head) + if scale_head is not None: + self.scale_head = MLP(**scale_head) + + @property + def device(self) -> torch.device: + return next(self.parameters()).device + + @property + def dtype(self) -> torch.dtype: + return next(self.parameters()).dtype + + @property + def onnx_compatible_mode(self) -> bool: + return getattr(self, "_onnx_compatible_mode", False) + + @onnx_compatible_mode.setter + def onnx_compatible_mode(self, value: bool): + self._onnx_compatible_mode = value + self.encoder.onnx_compatible_mode = value + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: 
Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel': + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `compiled` + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint_path = pretrained_model_name_or_path + else: + checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model'], strict=False) + + return model + + def init_weights(self): + self.encoder.init_weights() + + def enable_gradient_checkpointing(self): + self.encoder.enable_gradient_checkpointing() + self.neck.enable_gradient_checkpointing() + for head in ['points_head', 'normal_head', 'mask_head']: + if hasattr(self, head): + getattr(self, head).enable_gradient_checkpointing() + + def enable_pytorch_native_sdpa(self): + self.encoder.enable_pytorch_native_sdpa() + + def _remap_points(self, points: torch.Tensor) -> torch.Tensor: + if self.remap_output == 'linear': + pass + elif self.remap_output =='sinh': + points = torch.sinh(points) + elif self.remap_output == 'exp': + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output =='sinh_exp': + xy, z = points.split([2, 1], dim=-1) + points = 
torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + return points + + def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]: + batch_size, _, img_h, img_w = image.shape + device, dtype = image.device, image.dtype + + aspect_ratio = img_w / img_h + base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5) + num_tokens = base_h * base_w + + # Backbones encoding + features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True) + features = [features, None, None, None, None] + + # Concat UVs for aspect ratio input + for level in range(5): + uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1) + if features[level] is None: + features[level] = uv + else: + features[level] = torch.concat([features[level], uv], dim=1) + + # Shared neck + features = self.neck(features) + + # Heads decoding + + points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head']) + metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None + + # Resize + points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask]) + + # Remap output + if points is not None: + points = points.permute(0, 2, 3, 1) + points = self._remap_points(points) # slightly improves the performance in case of very large output values + if normal is not None: + normal = normal.permute(0, 2, 3, 1) + normal = F.normalize(normal, dim=-1) + if mask is not None: + mask = mask.squeeze(1).sigmoid() + if metric_scale is not None: + metric_scale = metric_scale.squeeze(1).exp() + + return_dict = { + 
'points': points, + 'normal': normal, + 'mask': mask, + 'metric_scale': metric_scale + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + return return_dict + + @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + num_tokens: int = None, + resolution_level: int = 9, + force_projection: bool = True, + apply_mask: Literal[False, True, 'blend'] = True, + fov_x: Optional[Union[Number, torch.Tensor]] = None, + use_fp16: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W) + - `num_tokens`: the number of base ViT tokens to use for inference, `'least'` or `'most'` or an integer. Suggested range: 1200 ~ 2500. + More tokens will result in significantly higher accuracy and finer details, but slower inference time. Default: `'most'`. + - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True + - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True + - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None + - `use_fp16`: if True, use mixed precision to speed up inference. Default: True + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. 
+ """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + image = image.to(dtype=self.dtype, device=self.device) + + original_height, original_width = image.shape[-2:] + area = original_height * original_width + aspect_ratio = original_width / original_height + + # Determine the number of base tokens to use + if num_tokens is None: + min_tokens, max_tokens = self.num_tokens_range + num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens)) + + # Forward pass + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16): + output = self.forward(image, num_tokens=num_tokens) + points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale']) + + # Always process the output in fp32 precision + points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x]) + with torch.autocast(device_type=self.device.type, dtype=torch.float32): + if mask is not None: + mask_binary = mask > 0.5 + else: + mask_binary = None + + if points is not None: + # Convert affine point map to camera-space. Recover depth and intrinsics from point map. 
+ # NOTE: Focal here is the focal length relative to half the image diagonal + if fov_x is None: + # Recover focal and shift from predicted point map + focal, shift = recover_focal_shift(points, mask_binary) + else: + # Focal is known, recover shift only + focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2)) + if focal.ndim == 0: + focal = focal[None].expand(points.shape[0]) + _, shift = recover_focal_shift(points, mask_binary, focal=focal) + fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 + intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5) + points[..., 2] += shift[..., None, None] + if mask_binary is not None: + mask_binary &= points[..., 2] > 0 # in case depth is contains negative values (which should never happen in practice) + depth = points[..., 2].clone() + else: + depth, intrinsics = None, None + + # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics + if force_projection and depth is not None: + points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics) + + # Apply metric scale + if metric_scale is not None: + if points is not None: + points *= metric_scale[:, None, None, None] + if depth is not None: + depth *= metric_scale[:, None, None] + + # Apply mask + if apply_mask and mask_binary is not None: + points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None + depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None + normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None + + return depth.squeeze().cpu().numpy(), mask_binary.squeeze().cpu().numpy(), intrinsics.squeeze().cpu().numpy() + + # return_dict = { + # 'points': points, + # 'intrinsics': intrinsics, + # 'depth': depth, + # 'mask': 
mask_binary, + # 'normal': normal + # } + # return_dict = {k: v for k, v in return_dict.items() if v is not None} + + # if omit_batch_dim: + # return_dict = {k: v.squeeze(0) for k, v in return_dict.items()} + + # return return_dict diff --git a/Pixel-Perfect-Depth/moge/scripts/__init__.py b/Pixel-Perfect-Depth/moge/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Pixel-Perfect-Depth/moge/scripts/app.py b/Pixel-Perfect-Depth/moge/scripts/app.py new file mode 100644 index 0000000000000000000000000000000000000000..ba660247f03f5a63bfdbeaaf1a1d48eb8e53777a --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/app.py @@ -0,0 +1,301 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +import sys +from pathlib import Path +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +import time +import uuid +import tempfile +import itertools +from typing import * +import atexit +from concurrent.futures import ThreadPoolExecutor +import shutil + +import click + + +@click.command(help='Web demo') +@click.option('--share', is_flag=True, help='Whether to run the app in shared mode.') +@click.option('--pretrained', 'pretrained_model_name_or_path', default=None, help='The name or path of the pre-trained model.') +@click.option('--version', 'model_version', default='v2', help='The version of the model.') +@click.option('--fp16', 'use_fp16', is_flag=True, help='Whether to use fp16 inference.') +def main(share: bool, pretrained_model_name_or_path: str, model_version: str, use_fp16: bool): + print("Import modules...") + # Lazy import + import cv2 + import torch + import numpy as np + import trimesh + import trimesh.visual + from PIL import Image + import gradio as gr + try: + import spaces # This is for deployment at huggingface.co/spaces + HUGGINFACE_SPACES_INSTALLED = True + except ImportError: + HUGGINFACE_SPACES_INSTALLED = 
False + + import utils3d + from moge.utils.io import write_normal + from moge.utils.vis import colorize_depth, colorize_normal + from moge.model import import_model_class_by_version + from moge.utils.geometry_numpy import depth_occlusion_edge_numpy + from moge.utils.tools import timeit + + print("Load model...") + if pretrained_model_name_or_path is None: + DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = { + "v1": "Ruicheng/moge-vitl", + "v2": "Ruicheng/moge-2-vitl-normal", + } + pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version] + model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).cuda().eval() + if use_fp16: + model.half() + thread_pool_executor = ThreadPoolExecutor(max_workers=1) + + def delete_later(path: Union[str, os.PathLike], delay: int = 300): + def _delete(): + try: + os.remove(path) + except FileNotFoundError: + pass + def _wait_and_delete(): + time.sleep(delay) + _delete(path) + thread_pool_executor.submit(_wait_and_delete) + atexit.register(_delete) + + # Inference on GPU. 
+ @(spaces.GPU if HUGGINFACE_SPACES_INSTALLED else lambda x: x) + def run_with_gpu(image: np.ndarray, resolution_level: int, apply_mask: bool) -> Dict[str, np.ndarray]: + image_tensor = torch.tensor(image, dtype=torch.float32 if not use_fp16 else torch.float16, device=torch.device('cuda')).permute(2, 0, 1) / 255 + output = model.infer(image_tensor, apply_mask=apply_mask, resolution_level=resolution_level, use_fp16=use_fp16) + output = {k: v.cpu().numpy() for k, v in output.items()} + return output + + # Full inference pipeline + def run(image: np.ndarray, max_size: int = 800, resolution_level: str = 'High', apply_mask: bool = True, remove_edge: bool = True, request: gr.Request = None): + larger_size = max(image.shape[:2]) + if larger_size > max_size: + scale = max_size / larger_size + image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + + height, width = image.shape[:2] + + resolution_level_int = {'Low': 0, 'Medium': 5, 'High': 9, 'Ultra': 30}.get(resolution_level, 9) + output = run_with_gpu(image, resolution_level_int, apply_mask) + + points, depth, mask, normal = output['points'], output['depth'], output['mask'], output.get('normal', None) + + if remove_edge: + mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=0.04) + else: + mask_cleaned = mask + + results = { + **output, + 'mask_cleaned': mask_cleaned, + 'image': image + } + + # depth & normal visualization + depth_vis = colorize_depth(depth) + if normal is not None: + normal_vis = colorize_normal(normal) + else: + normal_vis = gr.update(label="Normal map (not avalable for this model)") + + # mesh & pointcloud + if normal is None: + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask_cleaned, + tri=True + ) + vertex_normals = None + else: + faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh( + points, + 
image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + normal, + mask=mask_cleaned, + tri=True + ) + vertices = vertices * np.array([1, -1, -1], dtype=np.float32) + vertex_uvs = vertex_uvs * np.array([1, -1], dtype=np.float32) + np.array([0, 1], dtype=np.float32) + if vertex_normals is not None: + vertex_normals = vertex_normals * np.array([1, -1, -1], dtype=np.float32) + + tempdir = Path(tempfile.gettempdir(), 'moge') + tempdir.mkdir(exist_ok=True) + output_path = Path(tempdir, request.session_hash) + shutil.rmtree(output_path, ignore_errors=True) + output_path.mkdir(exist_ok=True, parents=True) + trimesh.Trimesh( + vertices=vertices, + faces=faces, + visual = trimesh.visual.texture.TextureVisuals( + uv=vertex_uvs, + material=trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(image), + metallicFactor=0.5, + roughnessFactor=1.0 + ) + ), + vertex_normals=vertex_normals, + process=False + ).export(output_path / 'mesh.glb') + pointcloud = trimesh.PointCloud( + vertices=vertices, + colors=vertex_colors, + ) + pointcloud.vertex_normals = vertex_normals + pointcloud.export(output_path / 'pointcloud.ply', vertex_normal=True) + trimesh.PointCloud( + vertices=vertices, + colors=vertex_colors, + ).export(output_path / 'pointcloud.glb', include_normals=True) + cv2.imwrite(str(output_path /'mask.png'), mask.astype(np.uint8) * 255) + cv2.imwrite(str(output_path / 'depth.exr'), depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(output_path / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + if normal is not None: + cv2.imwrite(str(output_path / 'normal.exr'), cv2.cvtColor(normal.astype(np.float32) * np.array([1, -1, -1], dtype=np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF]) + + files = ['mesh.glb', 'pointcloud.ply', 'depth.exr', 'points.exr', 'mask.png'] + if normal is 
not None: + files.append('normal.exr') + + for f in files: + delete_later(output_path / f) + + # FOV + intrinsics = results['intrinsics'] + fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) + fov_x, fov_y = np.rad2deg([fov_x, fov_y]) + + # messages + viewer_message = f'**Note:** Inference has been completed. It may take a few seconds to download the 3D model.' + if resolution_level != 'Ultra': + depth_message = f'**Note:** Want sharper depth map? Try increasing the `maximum image size` and setting the `inference resolution level` to `Ultra` in the settings.' + else: + depth_message = "" + + return ( + results, + depth_vis, + normal_vis, + output_path / 'pointcloud.glb', + [(output_path / f).as_posix() for f in files if (output_path / f).exists()], + f'- **Horizontal FOV: {fov_x:.1f}Β°**. \n - **Vertical FOV: {fov_y:.1f}Β°**', + viewer_message, + depth_message + ) + + def reset_measure(results: Dict[str, np.ndarray]): + return [results['image'], [], ""] + + + def measure(results: Dict[str, np.ndarray], measure_points: List[Tuple[int, int]], event: gr.SelectData): + point2d = event.index[0], event.index[1] + measure_points.append(point2d) + + image = results['image'].copy() + for p in measure_points: + image = cv2.circle(image, p, radius=5, color=(255, 0, 0), thickness=2) + + depth_text = "" + for i, p in enumerate(measure_points): + d = results['depth'][p[1], p[0]] + depth_text += f"- **P{i + 1} depth: {d:.2f}m.**\n" + + if len(measure_points) == 2: + point1, point2 = measure_points + image = cv2.line(image, point1, point2, color=(255, 0, 0), thickness=2) + distance = np.linalg.norm(results['points'][point1[1], point1[0]] - results['points'][point2[1], point2[0]]) + measure_points = [] + + distance_text = f"- **Distance: {distance:.2f}m**" + + text = depth_text + distance_text + return [image, measure_points, text] + else: + return [image, measure_points, depth_text] + + print("Create Gradio app...") + with gr.Blocks(theme=gr.themes.Soft()) as demo: + 
gr.Markdown( +f''' +
+

Turn a 2D image into 3D with MoGe badge-github-stars

+
+''') + results = gr.State(value=None) + measure_points = gr.State(value=[]) + + with gr.Row(): + with gr.Column(): + input_image = gr.Image(type="numpy", image_mode="RGB", label="Input Image") + with gr.Accordion(label="Settings", open=False): + max_size_input = gr.Number(value=800, label="Maximum Image Size", precision=0, minimum=256, maximum=2048) + resolution_level = gr.Dropdown(['Low', 'Medium', 'High', 'Ultra'], label="Inference Resolution Level", value='High') + apply_mask = gr.Checkbox(value=True, label="Apply mask") + remove_edges = gr.Checkbox(value=True, label="Remove edges") + submit_btn = gr.Button("Submit", variant='primary') + + with gr.Column(): + with gr.Tabs(): + with gr.Tab("3D View"): + viewer_message = gr.Markdown("") + model_3d = gr.Model3D(display_mode="solid", label="3D Point Map", clear_color=[1.0, 1.0, 1.0, 1.0], height="60vh") + fov = gr.Markdown() + with gr.Tab("Depth"): + depth_message = gr.Markdown("") + depth_map = gr.Image(type="numpy", label="Colorized Depth Map", format='png', interactive=False) + with gr.Tab("Normal", interactive=hasattr(model, 'normal_head')): + normal_map = gr.Image(type="numpy", label="Normal Map", format='png', interactive=False) + with gr.Tab("Measure", interactive=hasattr(model, 'scale_head')): + gr.Markdown("### Click on the image to measure the distance between two points. 
\n" + "**Note:** Metric scale is most reliable for typical indoor or street scenes, and may degrade for contents unfamiliar to the model (e.g., stylized or close-up images).") + measure_image = gr.Image(type="numpy", show_label=False, format='webp', interactive=False, sources=[]) + measure_text = gr.Markdown("") + with gr.Tab("Download"): + files = gr.File(type='filepath', label="Output Files") + + if Path('example_images').exists(): + example_image_paths = sorted(list(itertools.chain(*[Path('example_images').glob(f'*.{ext}') for ext in ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG']]))) + examples = gr.Examples( + examples = example_image_paths, + inputs=input_image, + label="Examples" + ) + + submit_btn.click( + fn=lambda: [None, None, None, None, None, "", "", ""], + outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message] + ).then( + fn=run, + inputs=[input_image, max_size_input, resolution_level, apply_mask, remove_edges], + outputs=[results, depth_map, normal_map, model_3d, files, fov, viewer_message, depth_message] + ).then( + fn=reset_measure, + inputs=[results], + outputs=[measure_image, measure_points, measure_text] + ) + + measure_image.select( + fn=measure, + inputs=[results, measure_points], + outputs=[measure_image, measure_points, measure_text] + ) + + demo.launch(share=share) + + +if __name__ == '__main__': + main() diff --git a/Pixel-Perfect-Depth/moge/scripts/cli.py b/Pixel-Perfect-Depth/moge/scripts/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..45c3b9006bf56306e403f8da5b6d5068215221ee --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/cli.py @@ -0,0 +1,27 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) + +import click + + +@click.group(help='MoGe command line interface.') +def cli(): + pass + +def main(): + from 
moge.scripts import app, infer, infer_baseline, infer_panorama, eval_baseline, vis_data + cli.add_command(app.main, name='app') + cli.add_command(infer.main, name='infer') + cli.add_command(infer_baseline.main, name='infer_baseline') + cli.add_command(infer_panorama.main, name='infer_panorama') + cli.add_command(eval_baseline.main, name='eval_baseline') + cli.add_command(vis_data.main, name='vis_data') + cli() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/scripts/eval_baseline.py b/Pixel-Perfect-Depth/moge/scripts/eval_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..8217d9e6500b1d72a00e1a0a225ba4c2134b892e --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/eval_baseline.py @@ -0,0 +1,165 @@ +import os +import sys +from pathlib import Path +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +import json +from typing import * +import importlib +import importlib.util + +import click + + +@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Evaluation script.') +@click.option('--baseline', 'baseline_code_path', type=click.Path(), required=True, help='Path to the baseline model python code.') +@click.option('--config', 'config_path', type=click.Path(), default='configs/eval/all_benchmarks.json', help='Path to the evaluation configurations. 
' + 'Defaults to "configs/eval/all_benchmarks.json".') +@click.option('--output', '-o', 'output_path', type=click.Path(), required=True, help='Path to the output json file.') +@click.option('--oracle', 'oracle_mode', is_flag=True, help='Use oracle mode for evaluation, i.e., use the GT intrinsics input.') +@click.option('--dump_pred', is_flag=True, help='Dump predition results.') +@click.option('--dump_gt', is_flag=True, help='Dump ground truth.') +@click.pass_context +def main(ctx: click.Context, baseline_code_path: str, config_path: str, oracle_mode: bool, output_path: Union[str, Path], dump_pred: bool, dump_gt: bool): + # Lazy import + import cv2 + import numpy as np + from tqdm import tqdm + import torch + import torch.nn.functional as F + import utils3d + + from moge.test.baseline import MGEBaselineInterface + from moge.test.dataloader import EvalDataLoaderPipeline + from moge.test.metrics import compute_metrics + from moge.utils.geometry_torch import intrinsics_to_fov + from moge.utils.vis import colorize_depth, colorize_normal + from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module + + # Load the baseline model + module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) + baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline') + baseline : MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False) + + # Load the evaluation configurations + with open(config_path, 'r') as f: + config = json.load(f) + + Path(output_path).parent.mkdir(parents=True, exist_ok=True) + all_metrics = {} + # Iterate over the dataset + for benchmark_name, benchmark_config in tqdm(list(config.items()), desc='Benchmarks'): + filenames, metrics_list = [], [] + with ( + EvalDataLoaderPipeline(**benchmark_config) as eval_data_pipe, + tqdm(total=len(eval_data_pipe), desc=benchmark_name, leave=False) as pbar + ): + # Iterate over the samples in the dataset + for i in range(len(eval_data_pipe)): + 
sample = eval_data_pipe.get() + sample = {k: v.to(baseline.device) if isinstance(v, torch.Tensor) else v for k, v in sample.items()} + image = sample['image'] + gt_intrinsics = sample['intrinsics'] + + # Inference + torch.cuda.synchronize() + with torch.inference_mode(), timeit('_inference_timer', verbose=False) as timer: + if oracle_mode: + pred = baseline.infer_for_evaluation(image, gt_intrinsics) + else: + pred = baseline.infer_for_evaluation(image) + torch.cuda.synchronize() + + # Compute metrics + metrics, misc = compute_metrics(pred, sample, vis=dump_pred or dump_gt) + metrics['inference_time'] = timer.time + metrics_list.append(metrics) + + # Dump results + dump_path = Path(output_path.replace(".json", f"_dump"), f'{benchmark_name}', sample['filename'].replace('.zip', '')) + if dump_pred: + dump_path.joinpath('pred').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(dump_path / 'pred' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + with Path(dump_path, 'pred', 'metrics.json').open('w') as f: + json.dump(metrics, f, indent=4) + + if 'pred_points' in misc: + points = misc['pred_points'].cpu().numpy() + cv2.imwrite(str(dump_path / 'pred' / 'points.exr'), cv2.cvtColor(points.astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + + if 'pred_depth' in misc: + depth = misc['pred_depth'].cpu().numpy() + if 'mask' in pred: + mask = pred['mask'].cpu().numpy() + depth = np.where(mask, depth, np.inf) + cv2.imwrite(str(dump_path / 'pred' / 'depth.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR)) + + if 'mask' in pred: + mask = pred['mask'].cpu().numpy() + cv2.imwrite(str(dump_path / 'pred' / 'mask.png'), (mask * 255).astype(np.uint8)) + + if 'normal' in pred: + normal = pred['normal'].cpu().numpy() + cv2.imwrite(str(dump_path / 'pred' / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR)) + + if 'intrinsics' in pred: + 
intrinsics = pred['intrinsics'] + fov_x, fov_y = intrinsics_to_fov(intrinsics) + with open(dump_path / 'pred' / 'fov.json', 'w') as f: + json.dump({ + 'fov_x': np.rad2deg(fov_x.item()), + 'fov_y': np.rad2deg(fov_y.item()), + 'intrinsics': intrinsics.cpu().numpy().tolist(), + }, f) + + if dump_gt: + dump_path.joinpath('gt').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(dump_path / 'gt' / 'image.jpg'), cv2.cvtColor((image.cpu().numpy().transpose(1, 2, 0) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + + if 'points' in sample: + points = sample['points'] + cv2.imwrite(str(dump_path / 'gt' / 'points.exr'), cv2.cvtColor(points.cpu().numpy().astype(np.float32), cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + + if 'depth' in sample: + depth = sample['depth'] + mask = sample['depth_mask'] + cv2.imwrite(str(dump_path / 'gt' / 'depth.png'), cv2.cvtColor(colorize_depth(depth.cpu().numpy(), mask=mask.cpu().numpy()), cv2.COLOR_RGB2BGR)) + + if 'normal' in sample: + normal = sample['normal'] + cv2.imwrite(str(dump_path / 'gt' / 'normal.png'), cv2.cvtColor(colorize_normal(normal.cpu().numpy()), cv2.COLOR_RGB2BGR)) + + if 'depth_mask' in sample: + mask = sample['depth_mask'] + cv2.imwrite(str(dump_path / 'gt' /'mask.png'), (mask.cpu().numpy() * 255).astype(np.uint8)) + + if 'intrinsics' in sample: + intrinsics = sample['intrinsics'] + fov_x, fov_y = intrinsics_to_fov(intrinsics) + with open(dump_path / 'gt' / 'info.json', 'w') as f: + json.dump({ + 'fov_x': np.rad2deg(fov_x.item()), + 'fov_y': np.rad2deg(fov_y.item()), + 'intrinsics': intrinsics.cpu().numpy().tolist(), + }, f) + + # Save intermediate results + if i % 100 == 0 or i == len(eval_data_pipe) - 1: + Path(output_path).write_text( + json.dumps({ + **all_metrics, + benchmark_name: key_average(metrics_list) + }, indent=4) + ) + pbar.update(1) + + all_metrics[benchmark_name] = key_average(metrics_list) + + # Save final results + all_metrics['mean'] = key_average(list(all_metrics.values())) + 
Path(output_path).write_text(json.dumps(all_metrics, indent=4)) + + +if __name__ == '__main__': + main() diff --git a/Pixel-Perfect-Depth/moge/scripts/infer.py b/Pixel-Perfect-Depth/moge/scripts/infer.py new file mode 100644 index 0000000000000000000000000000000000000000..fd92b29845a5079fdfef94ddb4b2168e70cd4ee6 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/infer.py @@ -0,0 +1,170 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +from typing import * +import itertools +import json +import warnings + + +import click + + +@click.command(help='Inference script') +@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), help='Input image or folder path. "jpg" and "png" are supported.') +@click.option('--fov_x', 'fov_x_', type=float, default=None, help='If camera parameters are known, set the horizontal field of view in degrees. Otherwise, MoGe will estimate it.') +@click.option('--output', '-o', 'output_path', default='./output', type=click.Path(), help='Output folder path') +@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default=None, help='Pretrained model name or path. If not provided, the corresponding default model will be chosen.') +@click.option('--version', 'model_version', type=click.Choice(['v1', 'v2']), default='v2', help='Model version. Defaults to "v2"') +@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"') +@click.option('--fp16', 'use_fp16', is_flag=True, help='Use fp16 precision for much faster inference.') +@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. 
Defaults to None (no resizing).') +@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level for inference. \ +Higher value means more tokens and the finer details will be captured, but inference can be slower. \ +Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size. \ +`resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.') +@click.option('--num_tokens', type=int, default=None, help='number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`. \ +`resolution_level` will be ignored if `num_tokens` is provided. Default: None') +@click.option('--threshold', type=float, default=0.04, help='Threshold for removing edges. Defaults to 0.01. Smaller value removes more edges. "inf" means no thresholding.') +@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps (image, point map, depth map, normal map, mask) and fov.') +@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') +@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') +@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. 
Note that this requires pyglet<2 installed as required by trimesh.') +def main( + input_path: str, + fov_x_: float, + output_path: str, + pretrained_model_name_or_path: str, + model_version: str, + device_name: str, + use_fp16: bool, + resize_to: int, + resolution_level: int, + num_tokens: int, + threshold: float, + save_maps_: bool, + save_glb_: bool, + save_ply_: bool, + show: bool, +): + import cv2 + import numpy as np + import torch + from PIL import Image + from tqdm import tqdm + import trimesh + import trimesh.visual + import click + + from moge.model import import_model_class_by_version + from moge.utils.io import save_glb, save_ply + from moge.utils.vis import colorize_depth, colorize_normal + from moge.utils.geometry_numpy import depth_occlusion_edge_numpy + import utils3d + + device = torch.device(device_name) + + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if len(image_paths) == 0: + raise FileNotFoundError(f'No image files found in {input_path}') + + if pretrained_model_name_or_path is None: + DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION = { + "v1": "Ruicheng/moge-vitl", + "v2": "Ruicheng/moge-2-vitl-normal", + } + pretrained_model_name_or_path = DEFAULT_PRETRAINED_MODEL_FOR_EACH_VERSION[model_version] + model = import_model_class_by_version(model_version).from_pretrained(pretrained_model_name_or_path).to(device).eval() + if use_fp16: + model.half() + + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Defaults to saving all. 
Please use "--maps", "--glb", or "--ply" to specify the output.') + save_maps_ = save_glb_ = save_ply_ = True + + for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)): + image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image.shape[:2] + if resize_to is not None: + height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) + image = cv2.resize(image, (width, height), cv2.INTER_AREA) + image_tensor = torch.tensor(image / 255, dtype=torch.float32, device=device).permute(2, 0, 1) + + # Inference + output = model.infer(image_tensor, fov_x=fov_x_, resolution_level=resolution_level, num_tokens=num_tokens, use_fp16=use_fp16) + points, depth, mask, intrinsics = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy(), output['intrinsics'].cpu().numpy() + normal = output['normal'].cpu().numpy() if 'normal' in output else None + + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + save_path.mkdir(exist_ok=True, parents=True) + + # Save images / maps + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(depth), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path / 'mask.png'), (mask * 255).astype(np.uint8)) + cv2.imwrite(str(save_path / 'points.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + if normal is not None: + cv2.imwrite(str(save_path / 'normal.png'), cv2.cvtColor(colorize_normal(normal), cv2.COLOR_RGB2BGR)) + fov_x, fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) + with open(save_path / 'fov.json', 'w') as f: + json.dump({ + 'fov_x': round(float(np.rad2deg(fov_x)), 2), + 'fov_y': 
round(float(np.rad2deg(fov_y)), 2), + }, f) + + # Export mesh & visulization + if save_glb_ or save_ply_ or show: + mask_cleaned = mask & ~utils3d.numpy.depth_edge(depth, rtol=threshold) + if normal is None: + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask_cleaned, + tri=True + ) + vertex_normals = None + else: + faces, vertices, vertex_colors, vertex_uvs, vertex_normals = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + normal, + mask=mask_cleaned, + tri=True + ) + # When exporting the model, follow the OpenGL coordinate conventions: + # - world coordinate system: x right, y up, z backward. + # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top. + vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] + if normal is not None: + vertex_normals = vertex_normals * [1, -1, -1] + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image, vertex_normals) + + if save_ply_: + save_ply(save_path / 'pointcloud.ply', vertices, np.zeros((0, 3), dtype=np.int32), vertex_colors, vertex_normals) + + if show: + trimesh.Trimesh( + vertices=vertices, + vertex_colors=vertex_colors, + vertex_normals=vertex_normals, + faces=faces, + process=False + ).show() + + +if __name__ == '__main__': + main() diff --git a/Pixel-Perfect-Depth/moge/scripts/infer_baseline.py b/Pixel-Perfect-Depth/moge/scripts/infer_baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..5409674f7cd5ce21de9200fd9038cb7d71c99e0f --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/infer_baseline.py @@ -0,0 +1,140 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, 
_package_root) +import json +from pathlib import Path +from typing import * +import itertools +import warnings + +import click + + +@click.command(context_settings={"allow_extra_args": True, "ignore_unknown_options": True}, help='Inference script for wrapped baselines methods') +@click.option('--baseline', 'baseline_code_path', required=True, type=click.Path(), help='Path to the baseline model python code.') +@click.option('--input', '-i', 'input_path', type=str, required=True, help='Input image or folder') +@click.option('--output', '-o', 'output_path', type=str, default='./output', help='Output folder') +@click.option('--size', 'image_size', type=int, default=None, help='Resize input image') +@click.option('--skip', is_flag=True, help='Skip existing output') +@click.option('--maps', 'save_maps_', is_flag=True, help='Save output point / depth maps') +@click.option('--ply', 'save_ply_', is_flag=True, help='Save mesh in PLY format') +@click.option('--glb', 'save_glb_', is_flag=True, help='Save mesh in GLB format') +@click.option('--threshold', type=float, default=0.03, help='Depth edge detection threshold for saving mesh') +@click.pass_context +def main(ctx: click.Context, baseline_code_path: str, input_path: str, output_path: str, image_size: int, skip: bool, save_maps_, save_ply_: bool, save_glb_: bool, threshold: float): + # Lazy import + import cv2 + import numpy as np + from tqdm import tqdm + import torch + import utils3d + + from moge.utils.io import save_ply, save_glb + from moge.utils.geometry_numpy import intrinsics_to_fov_numpy + from moge.utils.vis import colorize_depth, colorize_depth_affine, colorize_disparity + from moge.utils.tools import key_average, flatten_nested_dict, timeit, import_file_as_module + from moge.test.baseline import MGEBaselineInterface + + # Load the baseline model + module = import_file_as_module(baseline_code_path, Path(baseline_code_path).stem) + baseline_cls: Type[MGEBaselineInterface] = getattr(module, 'Baseline') + baseline : 
MGEBaselineInterface = baseline_cls.load.main(ctx.args, standalone_mode=False) + + # Input images list + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Defaults to saving maps only. Please use "--maps", "--glb", or "--ply" to specify the output.') + save_maps_ = True + + for image_path in (pbar := tqdm(image_paths, desc='Inference', disable=len(image_paths) <= 1)): + # Load one image at a time + image_np = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image_np.shape[:2] + if image_size is not None and max(image_np.shape[:2]) > image_size: + height, width = min(image_size, int(image_size * height / width)), min(image_size, int(image_size * width / height)) + image_np = cv2.resize(image_np, (width, height), cv2.INTER_AREA) + image = torch.from_numpy(image_np.astype(np.float32) / 255.0).permute(2, 0, 1).to(baseline.device) + + # Inference + torch.cuda.synchronize() + with torch.inference_mode(), (timer := timeit('Inference', verbose=False, average=True)): + output = baseline.infer(image) + torch.cuda.synchronize() + + inference_time = timer.average_time + pbar.set_postfix({'average inference time': f'{inference_time:.3f}s'}) + + # Save the output + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + if skip and save_path.exists(): + continue + save_path.mkdir(parents=True, exist_ok=True) + + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image_np, cv2.COLOR_RGB2BGR)) + + if 'mask' in output: + mask = output['mask'].cpu().numpy() + cv2.imwrite(str(save_path /'mask.png'), (mask * 255).astype(np.uint8)) + + for k in ['points_metric', 'points_scale_invariant', 
'points_affine_invariant']: + if k in output: + points = output[k].cpu().numpy() + cv2.imwrite(str(save_path / f'{k}.exr'), cv2.cvtColor(points, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + + for k in ['depth_metric', 'depth_scale_invariant', 'depth_affine_invariant', 'disparity_affine_invariant']: + if k in output: + depth = output[k].cpu().numpy() + cv2.imwrite(str(save_path / f'{k}.exr'), depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + if k in ['depth_metric', 'depth_scale_invariant']: + depth_vis = colorize_depth(depth) + elif k == 'depth_affine_invariant': + depth_vis = colorize_depth_affine(depth) + elif k == 'disparity_affine_invariant': + depth_vis = colorize_disparity(depth) + cv2.imwrite(str(save_path / f'{k}_vis.png'), cv2.cvtColor(depth_vis, cv2.COLOR_RGB2BGR)) + + if 'intrinsics' in output: + intrinsics = output['intrinsics'].cpu().numpy() + fov_x, fov_y = intrinsics_to_fov_numpy(intrinsics) + with open(save_path / 'fov.json', 'w') as f: + json.dump({ + 'fov_x': float(np.rad2deg(fov_x)), + 'fov_y': float(np.rad2deg(fov_y)), + 'intrinsics': intrinsics.tolist() + }, f, indent=4) + + # Export mesh & visulization + if save_ply_ or save_glb_: + assert any(k in output for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant']), 'No point map found in output' + points = next(output[k] for k in ['points_metric', 'points_scale_invariant', 'points_affine_invariant'] if k in output).cpu().numpy() + mask = output['mask'] if 'mask' in output else np.ones_like(points[..., 0], dtype=bool) + normals, normals_mask = utils3d.numpy.points_to_normals(points, mask=mask) + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image_np.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=mask & ~(utils3d.numpy.depth_edge(depth, rtol=threshold, mask=mask) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), + tri=True + ) + # When 
exporting the model, follow the OpenGL coordinate conventions: + # - world coordinate system: x right, y up, z backward. + # - texture coordinate system: (0, 0) for left-bottom, (1, 1) for right-top. + vertices, vertex_uvs = vertices * [1, -1, -1], vertex_uvs * [1, -1] + [0, 1] + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image_np) + + if save_ply_: + save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) + +if __name__ == '__main__': + main() diff --git a/Pixel-Perfect-Depth/moge/scripts/infer_panorama.py b/Pixel-Perfect-Depth/moge/scripts/infer_panorama.py new file mode 100644 index 0000000000000000000000000000000000000000..cce65cb90cd1c6750d42cdda4e72d4ce3a2c0549 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/infer_panorama.py @@ -0,0 +1,162 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +from typing import * +import itertools +import json +import warnings + +import click + + +@click.command(help='Inference script for panorama images') +@click.option('--input', '-i', 'input_path', type=click.Path(exists=True), required=True, help='Input image or folder path. "jpg" and "png" are supported.') +@click.option('--output', '-o', 'output_path', type=click.Path(), default='./output', help='Output folder path') +@click.option('--pretrained', 'pretrained_model_name_or_path', type=str, default='Ruicheng/moge-vitl', help='Pretrained model name or path. Defaults to "Ruicheng/moge-vitl"') +@click.option('--device', 'device_name', type=str, default='cuda', help='Device name (e.g. "cuda", "cuda:0", "cpu"). Defaults to "cuda"') +@click.option('--resize', 'resize_to', type=int, default=None, help='Resize the image(s) & output maps to a specific size. 
Defaults to None (no resizing).') +@click.option('--resolution_level', type=int, default=9, help='An integer [0-9] for the resolution level of inference. The higher, the better but slower. Defaults to 9. Note that it is irrelevant to the output resolution.') +@click.option('--threshold', type=float, default=0.03, help='Threshold for removing edges. Defaults to 0.03. Smaller value removes more edges. "inf" means no thresholding.') +@click.option('--batch_size', type=int, default=4, help='Batch size for inference. Defaults to 4.') +@click.option('--splitted', 'save_splitted', is_flag=True, help='Whether to save the splitted images. Defaults to False.') +@click.option('--maps', 'save_maps_', is_flag=True, help='Whether to save the output maps and fov(image, depth, mask, points, fov).') +@click.option('--glb', 'save_glb_', is_flag=True, help='Whether to save the output as a.glb file. The color will be saved as a texture.') +@click.option('--ply', 'save_ply_', is_flag=True, help='Whether to save the output as a.ply file. The color will be saved as vertex colors.') +@click.option('--show', 'show', is_flag=True, help='Whether show the output in a window. 
Note that this requires pyglet<2 installed as required by trimesh.') +def main( + input_path: str, + output_path: str, + pretrained_model_name_or_path: str, + device_name: str, + resize_to: int, + resolution_level: int, + threshold: float, + batch_size: int, + save_splitted: bool, + save_maps_: bool, + save_glb_: bool, + save_ply_: bool, + show: bool, +): + # Lazy import + import cv2 + import numpy as np + from numpy import ndarray + import torch + from PIL import Image + from tqdm import tqdm, trange + import trimesh + import trimesh.visual + from scipy.sparse import csr_array, hstack, vstack + from scipy.ndimage import convolve + from scipy.sparse.linalg import lsmr + + import utils3d + from moge.model.v1 import MoGeModel + from moge.utils.io import save_glb, save_ply + from moge.utils.vis import colorize_depth + from moge.utils.panorama import spherical_uv_to_directions, get_panorama_cameras, split_panorama_image, merge_panorama_depth + + + device = torch.device(device_name) + + include_suffices = ['jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG'] + if Path(input_path).is_dir(): + image_paths = sorted(itertools.chain(*(Path(input_path).rglob(f'*.{suffix}') for suffix in include_suffices))) + else: + image_paths = [Path(input_path)] + + if len(image_paths) == 0: + raise FileNotFoundError(f'No image files found in {input_path}') + + # Write outputs + if not any([save_maps_, save_glb_, save_ply_]): + warnings.warn('No output format specified. Defaults to saving all. 
Please use "--maps", "--glb", or "--ply" to specify the output.') + save_maps_ = save_glb_ = save_ply_ = True + + model = MoGeModel.from_pretrained(pretrained_model_name_or_path).to(device).eval() + + for image_path in (pbar := tqdm(image_paths, desc='Total images', disable=len(image_paths) <= 1)): + image = cv2.cvtColor(cv2.imread(str(image_path)), cv2.COLOR_BGR2RGB) + height, width = image.shape[:2] + if resize_to is not None: + height, width = min(resize_to, int(resize_to * height / width)), min(resize_to, int(resize_to * width / height)) + image = cv2.resize(image, (width, height), cv2.INTER_AREA) + + splitted_extrinsics, splitted_intriniscs = get_panorama_cameras() + splitted_resolution = 512 + splitted_images = split_panorama_image(image, splitted_extrinsics, splitted_intriniscs, splitted_resolution) + + # Infer each view + print('Inferring...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + + splitted_distance_maps, splitted_masks = [], [] + for i in trange(0, len(splitted_images), batch_size, desc='Inferring splitted views', disable=len(splitted_images) <= batch_size, leave=False): + image_tensor = torch.tensor(np.stack(splitted_images[i:i + batch_size]) / 255, dtype=torch.float32, device=device).permute(0, 3, 1, 2) + fov_x, fov_y = np.rad2deg(utils3d.numpy.intrinsics_to_fov(np.array(splitted_intriniscs[i:i + batch_size]))) + fov_x = torch.tensor(fov_x, dtype=torch.float32, device=device) + output = model.infer(image_tensor, fov_x=fov_x, apply_mask=False) + distance_map, mask = output['points'].norm(dim=-1).cpu().numpy(), output['mask'].cpu().numpy() + splitted_distance_maps.extend(list(distance_map)) + splitted_masks.extend(list(mask)) + + # Save splitted + if save_splitted: + splitted_save_path = Path(output_path, image_path.stem, 'splitted') + splitted_save_path.mkdir(exist_ok=True, parents=True) + for i in range(len(splitted_images)): + cv2.imwrite(str(splitted_save_path / f'{i:02d}.jpg'), cv2.cvtColor(splitted_images[i], cv2.COLOR_RGB2BGR)) 
+ cv2.imwrite(str(splitted_save_path / f'{i:02d}_distance_vis.png'), cv2.cvtColor(colorize_depth(splitted_distance_maps[i], splitted_masks[i]), cv2.COLOR_RGB2BGR)) + + # Merge + print('Merging...') if pbar.disable else pbar.set_postfix_str(f'Merging') + + merging_width, merging_height = min(1920, width), min(960, height) + panorama_depth, panorama_mask = merge_panorama_depth(merging_width, merging_height, splitted_distance_maps, splitted_masks, splitted_extrinsics, splitted_intriniscs) + panorama_depth = panorama_depth.astype(np.float32) + panorama_depth = cv2.resize(panorama_depth, (width, height), cv2.INTER_LINEAR) + panorama_mask = cv2.resize(panorama_mask.astype(np.uint8), (width, height), cv2.INTER_NEAREST) > 0 + points = panorama_depth[:, :, None] * spherical_uv_to_directions(utils3d.numpy.image_uv(width=width, height=height)) + + # Write outputs + print('Writing outputs...') if pbar.disable else pbar.set_postfix_str(f'Inferring') + save_path = Path(output_path, image_path.relative_to(input_path).parent, image_path.stem) + save_path.mkdir(exist_ok=True, parents=True) + if save_maps_: + cv2.imwrite(str(save_path / 'image.jpg'), cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth_vis.png'), cv2.cvtColor(colorize_depth(panorama_depth, mask=panorama_mask), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_path / 'depth.exr'), panorama_depth, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path / 'points.exr'), points, [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_path /'mask.png'), (panorama_mask * 255).astype(np.uint8)) + + # Export mesh & visulization + if save_glb_ or save_ply_ or show: + normals, normals_mask = utils3d.numpy.points_to_normals(points, panorama_mask) + faces, vertices, vertex_colors, vertex_uvs = utils3d.numpy.image_mesh( + points, + image.astype(np.float32) / 255, + utils3d.numpy.image_uv(width=width, height=height), + mask=panorama_mask & 
~(utils3d.numpy.depth_edge(panorama_depth, rtol=threshold) & utils3d.numpy.normals_edge(normals, tol=5, mask=normals_mask)), + tri=True + ) + + if save_glb_: + save_glb(save_path / 'mesh.glb', vertices, faces, vertex_uvs, image) + + if save_ply_: + save_ply(save_path / 'mesh.ply', vertices, faces, vertex_colors) + + if show: + trimesh.Trimesh( + vertices=vertices, + vertex_colors=vertex_colors, + faces=faces, + process=False + ).show() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/scripts/train.py b/Pixel-Perfect-Depth/moge/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..d96d3ad4d31ed0b6c30bbbbcd83033b227e90829 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/train.py @@ -0,0 +1,452 @@ +import os +from pathlib import Path +import sys +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) +import json +import time +import random +from typing import * +import itertools +from contextlib import nullcontext +from concurrent.futures import ThreadPoolExecutor +import io + +import numpy as np +import cv2 +from PIL import Image +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.version +import accelerate +from accelerate import Accelerator, DistributedDataParallelKwargs +from accelerate.utils import set_seed +import utils3d +import click +from tqdm import tqdm, trange +import mlflow +torch.backends.cudnn.benchmark = False # Varying input size, make sure cudnn benchmark is disabled + +from moge.train.dataloader import TrainDataLoaderPipeline +from moge.train.losses import ( + affine_invariant_global_loss, + affine_invariant_local_loss, + edge_loss, + normal_loss, + mask_l2_loss, + mask_bce_loss, + monitoring, +) +from moge.train.utils import build_optimizer, build_lr_scheduler +from moge.utils.geometry_torch import intrinsics_to_fov +from moge.utils.vis import colorize_depth, 
colorize_normal +from moge.utils.tools import key_average, recursive_replace, CallbackOnException, flatten_nested_dict +from moge.test.metrics import compute_metrics + + +@click.command() +@click.option('--config', 'config_path', type=str, default='configs/debug.json') +@click.option('--workspace', type=str, default='workspace/debug', help='Path to the workspace') +@click.option('--checkpoint', 'checkpoint_path', type=str, default=None, help='Path to the checkpoint to load') +@click.option('--batch_size_forward', type=int, default=8, help='Batch size for each forward pass on each device') +@click.option('--gradient_accumulation_steps', type=int, default=1, help='Number of steps to accumulate gradients') +@click.option('--enable_gradient_checkpointing', type=bool, default=True, help='Use gradient checkpointing in backbone') +@click.option('--enable_mixed_precision', type=bool, default=False, help='Use mixed precision training. Backbone is converted to FP16') +@click.option('--enable_ema', type=bool, default=True, help='Maintain an exponential moving average of the model weights') +@click.option('--num_iterations', type=int, default=1000000, help='Number of iterations to train the model') +@click.option('--save_every', type=int, default=10000, help='Save checkpoint every n iterations') +@click.option('--log_every', type=int, default=1000, help='Log metrics every n iterations') +@click.option('--vis_every', type=int, default=0, help='Visualize every n iterations') +@click.option('--num_vis_images', type=int, default=32, help='Number of images to visualize, must be a multiple of divided batch size') +@click.option('--enable_mlflow', type=bool, default=True, help='Log metrics to MLFlow') +@click.option('--seed', type=int, default=0, help='Random seed') +def main( + config_path: str, + workspace: str, + checkpoint_path: str, + batch_size_forward: int, + gradient_accumulation_steps: int, + enable_gradient_checkpointing: bool, + enable_mixed_precision: bool, + enable_ema: 
bool, + num_iterations: int, + save_every: int, + log_every: int, + vis_every: int, + num_vis_images: int, + enable_mlflow: bool, + seed: Optional[int], +): + # Load config + with open(config_path, 'r') as f: + config = json.load(f) + + accelerator = Accelerator( + gradient_accumulation_steps=gradient_accumulation_steps, + mixed_precision='fp16' if enable_mixed_precision else None, + kwargs_handlers=[ + DistributedDataParallelKwargs(find_unused_parameters=True) + ] + ) + device = accelerator.device + batch_size_total = batch_size_forward * gradient_accumulation_steps * accelerator.num_processes + + # Log config + if accelerator.is_main_process: + if enable_mlflow: + try: + mlflow.log_params({ + **click.get_current_context().params, + 'batch_size_total': batch_size_total, + }) + except: + print('Failed to log config to MLFlow') + Path(workspace).mkdir(parents=True, exist_ok=True) + with Path(workspace).joinpath('config.json').open('w') as f: + json.dump(config, f, indent=4) + + # Set seed + if seed is not None: + set_seed(seed, device_specific=True) + + # Initialize model + print('Initialize model') + with accelerator.local_main_process_first(): + from moge.model import import_model_class_by_version + MoGeModel = import_model_class_by_version(config['model_version']) + model = MoGeModel(**config['model']) + count_total_parameters = sum(p.numel() for p in model.parameters()) + print(f'Total parameters: {count_total_parameters}') + + # Set up EMA model + if enable_ema and accelerator.is_main_process: + ema_avg_fn = lambda averaged_model_parameter, model_parameter, num_averaged: 0.999 * averaged_model_parameter + 0.001 * model_parameter + ema_model = torch.optim.swa_utils.AveragedModel(model, device=accelerator.device, avg_fn=ema_avg_fn) + + # Set gradient checkpointing + if enable_gradient_checkpointing: + model.enable_gradient_checkpointing() + import warnings + warnings.filterwarnings("ignore", category=FutureWarning, module="torch.utils.checkpoint") + + # Initalize 
optimizer & lr scheduler + optimizer = build_optimizer(model, config['optimizer']) + lr_scheduler = build_lr_scheduler(optimizer, config['lr_scheduler']) + + count_grouped_parameters = [sum(p.numel() for p in param_group['params'] if p.requires_grad) for param_group in optimizer.param_groups] + for i, count in enumerate(count_grouped_parameters): + print(f'- Group {i}: {count} parameters') + + # Attempt to load checkpoint + checkpoint: Dict[str, Any] + with accelerator.local_main_process_first(): + if checkpoint_path.endswith('.pt'): + # - Load specific checkpoint file + print(f'Load checkpoint: {checkpoint_path}') + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + elif checkpoint_path == "latest": + # - Load latest + checkpoint_path = Path(workspace, 'checkpoint', 'latest.pt') + if checkpoint_path.exists(): + print(f'Load checkpoint: {checkpoint_path}') + checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True) + i_step = checkpoint['step'] + if 'model' not in checkpoint and (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists(): + print(f'Load model checkpoint: {checkpoint_model_path}') + checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model'] + if 'optimizer' not in checkpoint and (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists(): + print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}') + checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True)) + if enable_ema and accelerator.is_main_process: + if 'ema_model' not in checkpoint and (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists(): + print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}') + checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model'] + else: + checkpoint = None + 
elif checkpoint_path is not None: + # - Load by step number + i_step = int(checkpoint_path) + checkpoint = {'step': i_step} + if (checkpoint_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}.pt')).exists(): + print(f'Load model checkpoint: {checkpoint_model_path}') + checkpoint['model'] = torch.load(checkpoint_model_path, map_location='cpu', weights_only=True)['model'] + if (checkpoint_optimizer_path := Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt')).exists(): + print(f'Load optimizer checkpoint: {checkpoint_optimizer_path}') + checkpoint.update(torch.load(checkpoint_optimizer_path, map_location='cpu', weights_only=True)) + if enable_ema and accelerator.is_main_process: + if (checkpoint_ema_model_path := Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt')).exists(): + print(f'Load EMA model checkpoint: {checkpoint_ema_model_path}') + checkpoint['ema_model'] = torch.load(checkpoint_ema_model_path, map_location='cpu', weights_only=True)['model'] + else: + checkpoint = None + + if checkpoint is None: + # Initialize model weights + print('Initialize model weights') + with accelerator.local_main_process_first(): + model.init_weights() + initial_step = 0 + else: + model.load_state_dict(checkpoint['model'], strict=False) + if 'step' in checkpoint: + initial_step = checkpoint['step'] + 1 + else: + initial_step = 0 + if 'optimizer' in checkpoint: + optimizer.load_state_dict(checkpoint['optimizer']) + if enable_ema and accelerator.is_main_process and 'ema_model' in checkpoint: + ema_model.module.load_state_dict(checkpoint['ema_model'], strict=False) + if 'lr_scheduler' in checkpoint: + lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) + + del checkpoint + + model, optimizer = accelerator.prepare(model, optimizer) + if torch.version.hip and isinstance(model, torch.nn.parallel.DistributedDataParallel): + # Hacking potential gradient synchronization issue in ROCm backend + from moge.model.utils import sync_ddp_hook + model.register_comm_hook(None, 
sync_ddp_hook) + + # Initialize training data pipeline + with accelerator.local_main_process_first(): + train_data_pipe = TrainDataLoaderPipeline(config['data'], batch_size_forward) + + def _write_bytes_retry_loop(save_path: Path, data: bytes): + while True: + try: + save_path.write_bytes(data) + break + except Exception as e: + print('Error while saving checkpoint, retrying in 1 minute: ', e) + time.sleep(60) + + # Ready to train + records = [] + model.train() + with ( + train_data_pipe, + tqdm(initial=initial_step, total=num_iterations, desc='Training', disable=not accelerator.is_main_process) as pbar, + ThreadPoolExecutor(max_workers=1) as save_checkpoint_executor, + ): + # Get some batches for visualization + if accelerator.is_main_process: + batches_for_vis: List[Dict[str, torch.Tensor]] = [] + num_vis_images = num_vis_images // batch_size_forward * batch_size_forward + for _ in range(num_vis_images // batch_size_forward): + batch = train_data_pipe.get() + batches_for_vis.append(batch) + + # Visualize GT + if vis_every > 0 and accelerator.is_main_process and initial_step == 0: + save_dir = Path(workspace).joinpath('vis/gt') + for i_batch, batch in enumerate(tqdm(batches_for_vis, desc='Visualize GT', leave=False)): + image, gt_depth, gt_mask, gt_mask_inf, gt_intrinsics, info = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_inf'], batch['intrinsics'], batch['info'] + gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics) + gt_normal, gt_normal_mask = utils3d.torch.points_to_normals(gt_points, gt_mask) + for i_instance in range(batch['image'].shape[0]): + idx = i_batch * batch_size_forward + i_instance + image_i = (image[i_instance].numpy().transpose(1, 2, 0) * 255).astype(np.uint8) + gt_depth_i = gt_depth[i_instance].numpy() + gt_mask_i = gt_mask[i_instance].numpy() + gt_mask_inf_i = gt_mask_inf[i_instance].numpy() + gt_points_i = gt_points[i_instance].numpy() + gt_normal_i = gt_normal[i_instance].numpy() + 
save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(gt_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), gt_mask_i * 255) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(gt_depth_i, gt_mask_i), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/normal.png')), cv2.cvtColor(colorize_normal(gt_normal_i), cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask_inf.png')), gt_mask_inf_i * 255) + with save_dir.joinpath(f'{idx:04d}/info.json').open('w') as f: + json.dump(info[i_instance], f) + + # Reset seed to avoid training on the same data when resuming training + if seed is not None: + set_seed(seed + initial_step, device_specific=True) + + # Training loop + for i_step in range(initial_step, num_iterations): + + i_accumulate, weight_accumulate = 0, 0 + while i_accumulate < gradient_accumulation_steps: + # Load batch + batch = train_data_pipe.get() + image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics, label_type, is_metric = batch['image'], batch['depth'], batch['depth_mask'], batch['depth_mask_fin'], batch['depth_mask_inf'], batch['intrinsics'], batch['label_type'], batch['is_metric'] + image, gt_depth, gt_mask, gt_mask_fin, gt_mask_inf, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), gt_mask_fin.to(device), gt_mask_inf.to(device), gt_intrinsics.to(device) + current_batch_size = image.shape[0] + if all(label == 'invalid' for label in label_type): + continue # NOTE: Skip all-invalid batches to avoid messing up the optimizer. 
+ + gt_points = utils3d.torch.depth_to_points(gt_depth, intrinsics=gt_intrinsics) + gt_focal = 1 / (1 / gt_intrinsics[..., 0, 0] ** 2 + 1 / gt_intrinsics[..., 1, 1] ** 2) ** 0.5 + + with accelerator.accumulate(model): + # Forward + if i_step <= config.get('low_resolution_training_steps', 0): + num_tokens = config['model']['num_tokens_range'][0] + else: + num_tokens = accelerate.utils.broadcast_object_list([random.randint(*config['model']['num_tokens_range'])])[0] + with torch.autocast(device_type=accelerator.device.type, dtype=torch.float16, enabled=enable_mixed_precision): + output = model(image, num_tokens=num_tokens) + pred_points, pred_mask, pred_metric_scale = output['points'], output['mask'], output.get('metric_scale', None) + + # Compute loss (per instance) + loss_list, weight_list = [], [] + for i in range(current_batch_size): + gt_metric_scale = None + loss_dict, weight_dict, misc_dict = {}, {}, {} + misc_dict['monitoring'] = monitoring(pred_points[i]) + for k, v in config['loss'][label_type[i]].items(): + weight_dict[k] = v['weight'] + if v['function'] == 'affine_invariant_global_loss': + loss_dict[k], misc_dict[k], gt_metric_scale = affine_invariant_global_loss(pred_points[i], gt_points[i], gt_mask[i], **v['params']) + elif v['function'] == 'affine_invariant_local_loss': + loss_dict[k], misc_dict[k] = affine_invariant_local_loss(pred_points[i], gt_points[i], gt_mask[i], gt_focal[i], gt_metric_scale, **v['params']) + elif v['function'] == 'normal_loss': + loss_dict[k], misc_dict[k] = normal_loss(pred_points[i], gt_points[i], gt_mask[i]) + elif v['function'] == 'edge_loss': + loss_dict[k], misc_dict[k] = edge_loss(pred_points[i], gt_points[i], gt_mask[i]) + elif v['function'] == 'mask_bce_loss': + loss_dict[k], misc_dict[k] = mask_bce_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i]) + elif v['function'] == 'mask_l2_loss': + loss_dict[k], misc_dict[k] = mask_l2_loss(pred_mask[i], gt_mask_fin[i], gt_mask_inf[i]) + else: + raise ValueError(f'Undefined loss 
function: {v["function"]}') + weight_dict = {'.'.join(k): v for k, v in flatten_nested_dict(weight_dict).items()} + loss_dict = {'.'.join(k): v for k, v in flatten_nested_dict(loss_dict).items()} + loss_ = sum([weight_dict[k] * loss_dict[k] for k in loss_dict], start=torch.tensor(0.0, device=device)) + loss_list.append(loss_) + + if torch.isnan(loss_).item(): + pbar.write(f'NaN loss in process {accelerator.process_index}') + pbar.write(str(loss_dict)) + + misc_dict = {'.'.join(k): v for k, v in flatten_nested_dict(misc_dict).items()} + records.append({ + **{k: v.item() for k, v in loss_dict.items()}, + **misc_dict, + }) + + loss = sum(loss_list) / len(loss_list) + + # Backward & update + accelerator.backward(loss) + if accelerator.sync_gradients: + if not enable_mixed_precision and any(torch.isnan(p.grad).any() for p in model.parameters() if p.grad is not None): + if accelerator.is_main_process: + pbar.write(f'NaN gradients, skip update') + optimizer.zero_grad() + continue + accelerator.clip_grad_norm_(model.parameters(), 1.0) + + optimizer.step() + optimizer.zero_grad() + + i_accumulate += 1 + + lr_scheduler.step() + + # EMA update + if enable_ema and accelerator.is_main_process and accelerator.sync_gradients: + ema_model.update_parameters(model) + + # Log metrics + if i_step == initial_step or i_step % log_every == 0: + records = [key_average(records)] + records = accelerator.gather_for_metrics(records, use_gather_object=True) + if accelerator.is_main_process: + records = key_average(records) + if enable_mlflow: + try: + mlflow.log_metrics(records, step=i_step) + except Exception as e: + print(f'Error while logging metrics to mlflow: {e}') + records = [] + + # Save model weight checkpoint + if accelerator.is_main_process and (i_step % save_every == 0): + # NOTE: Writing checkpoint is done in a separate thread to avoid blocking the main process + pbar.write(f'Save checkpoint: {i_step:08d}') + Path(workspace, 'checkpoint').mkdir(parents=True, exist_ok=True) + + # 
Model checkpoint + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'model': accelerator.unwrap_model(model).state_dict(), + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}.pt'), checkpoint_bytes + ) + + # Optimizer checkpoint + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'step': i_step, + 'optimizer': optimizer.state_dict(), + 'lr_scheduler': lr_scheduler.state_dict(), + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_optimizer.pt'), checkpoint_bytes + ) + + # EMA model checkpoint + if enable_ema: + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'model': ema_model.module.state_dict(), + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', f'{i_step:08d}_ema.pt'), checkpoint_bytes + ) + + # Latest checkpoint + with io.BytesIO() as f: + torch.save({ + 'model_config': config['model'], + 'step': i_step, + }, f) + checkpoint_bytes = f.getvalue() + save_checkpoint_executor.submit( + _write_bytes_retry_loop, Path(workspace, 'checkpoint', 'latest.pt'), checkpoint_bytes + ) + + # Visualize + if vis_every > 0 and accelerator.is_main_process and (i_step == initial_step or i_step % vis_every == 0): + unwrapped_model = accelerator.unwrap_model(model) + save_dir = Path(workspace).joinpath(f'vis/step_{i_step:08d}') + save_dir.mkdir(parents=True, exist_ok=True) + with torch.inference_mode(): + for i_batch, batch in enumerate(tqdm(batches_for_vis, desc=f'Visualize: {i_step:08d}', leave=False)): + image, gt_depth, gt_mask, gt_intrinsics = batch['image'], batch['depth'], batch['depth_mask'], batch['intrinsics'] + image, gt_depth, gt_mask, gt_intrinsics = image.to(device), gt_depth.to(device), gt_mask.to(device), 
gt_intrinsics.to(device) + + output = unwrapped_model.infer(image) + pred_points, pred_depth, pred_mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy() + image = image.cpu().numpy() + + for i_instance in range(image.shape[0]): + idx = i_batch * batch_size_forward + i_instance + image_i = (image[i_instance].transpose(1, 2, 0) * 255).astype(np.uint8) + pred_points_i = pred_points[i_instance] + pred_mask_i = pred_mask[i_instance] + pred_depth_i = pred_depth[i_instance] + save_dir.joinpath(f'{idx:04d}').mkdir(parents=True, exist_ok=True) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/image.jpg')), cv2.cvtColor(image_i, cv2.COLOR_RGB2BGR)) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/points.exr')), cv2.cvtColor(pred_points_i, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT]) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/mask.png')), pred_mask_i * 255) + cv2.imwrite(str(save_dir.joinpath(f'{idx:04d}/depth_vis.png')), cv2.cvtColor(colorize_depth(pred_depth_i, pred_mask_i), cv2.COLOR_RGB2BGR)) + + pbar.set_postfix({'loss': loss.item()}, refresh=False) + pbar.update(1) + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/scripts/vis_data.py b/Pixel-Perfect-Depth/moge/scripts/vis_data.py new file mode 100644 index 0000000000000000000000000000000000000000..bdb21766a67e4370578acbdf7bd17d1feb46b937 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/scripts/vis_data.py @@ -0,0 +1,84 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +import sys +from pathlib import Path +if (_package_root := str(Path(__file__).absolute().parents[2])) not in sys.path: + sys.path.insert(0, _package_root) + +import click + + +@click.command() +@click.argument('folder_or_path', type=click.Path(exists=True)) +@click.option('--output', '-o', 'output_folder', type=click.Path(), help='Path to output folder') +@click.option('--max_depth', '-m', type=float, 
default=float('inf'), help='max depth') +@click.option('--fov', type=float, default=None, help='field of view in degrees') +@click.option('--show', 'show', is_flag=True, help='show point cloud') +@click.option('--depth', 'depth_filename', type=str, default='depth.png', help='depth image file name') +@click.option('--ply', 'save_ply', is_flag=True, help='save point cloud as PLY file') +@click.option('--depth_vis', 'save_depth_vis', is_flag=True, help='save depth image') +@click.option('--inf', 'inf_mask', is_flag=True, help='use infinity mask') +@click.option('--version', 'version', type=str, default='v3', help='version of rgbd data') +def main( + folder_or_path: str, + output_folder: str, + max_depth: float, + fov: float, + depth_filename: str, + show: bool, + save_ply: bool, + save_depth_vis: bool, + inf_mask: bool, + version: str +): + # Lazy import + import cv2 + import numpy as np + import utils3d + from tqdm import tqdm + import trimesh + + from moge.utils.io import read_image, read_depth, read_meta + from moge.utils.vis import colorize_depth, colorize_normal + + filepaths = sorted(p.parent for p in Path(folder_or_path).rglob('meta.json')) + + for filepath in tqdm(filepaths): + image = read_image(Path(filepath, 'image.jpg')) + depth, unit = read_depth(Path(filepath, depth_filename)) + meta = read_meta(Path(filepath,'meta.json')) + depth_mask = np.isfinite(depth) + depth_mask_inf = (depth == np.inf) + intrinsics = np.array(meta['intrinsics']) + + extrinsics = np.array([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]], dtype=float) # OpenGL's identity camera + verts = utils3d.numpy.unproject_cv(utils3d.numpy.image_uv(*image.shape[:2]), depth, extrinsics=extrinsics, intrinsics=intrinsics) + + depth_mask_ply = depth_mask & (depth < depth[depth_mask].min() * max_depth) + point_cloud = trimesh.PointCloud(verts[depth_mask_ply], image[depth_mask_ply] / 255) + + if show: + point_cloud.show() + + if output_folder is None: + output_path = filepath + else: + 
output_path = Path(output_folder, filepath.name) + output_path.mkdir(exist_ok=True, parents=True) + + if inf_mask: + depth = np.where(depth_mask_inf, np.inf, depth) + depth_mask = depth_mask | depth_mask_inf + + if save_depth_vis: + p = output_path.joinpath('depth_vis.png') + cv2.imwrite(str(p), cv2.cvtColor(colorize_depth(depth, depth_mask), cv2.COLOR_RGB2BGR)) + print(f"{p}") + + if save_ply: + p = output_path.joinpath('pointcloud.ply') + point_cloud.export(p) + print(f"{p}") + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/test/__init__.py b/Pixel-Perfect-Depth/moge/test/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Pixel-Perfect-Depth/moge/test/baseline.py b/Pixel-Perfect-Depth/moge/test/baseline.py new file mode 100644 index 0000000000000000000000000000000000000000..05980aaf96870304534fcec6532225e870351a66 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/test/baseline.py @@ -0,0 +1,43 @@ +from typing import * + +import click +import torch + + +class MGEBaselineInterface: + """ + Abstract class for model wrapper to uniformize the interface of loading and inference across different models. + """ + device: torch.device + + @click.command() + @staticmethod + def load(*args, **kwargs) -> "MGEBaselineInterface": + """ + Customized static method to create an instance of the model wrapper from command line arguments. Decorated by `click.command()` + """ + raise NotImplementedError(f"{type(self).__name__} has not implemented the load method.") + + def infer(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + """ + ### Parameters + `image`: [B, 3, H, W] or [3, H, W], RGB values in range [0, 1] + `intrinsics`: [B, 3, 3] or [3, 3], camera intrinsics. Optional. + + ### Returns + A dictionary containing: + - `points_*`. point map output in OpenCV identity camera space. 
+ Supported suffixes: `metric`, `scale_invariant`, `affine_invariant`. + - `depth_*`. depth map output + Supported suffixes: `metric` (in meters), `scale_invariant`, `affine_invariant`. + - `disparity_affine_invariant`. affine disparity map output + """ + raise NotImplementedError(f"{type(self).__name__} has not implemented the infer method.") + + def infer_for_evaluation(self, image: torch.FloatTensor, intrinsics: Optional[torch.Tensor] = None) -> Dict[str, torch.Tensor]: + """ + If the model has a special evaluation mode, override this method to provide the evaluation mode inference. + + By default, this method simply calls `infer()`. + """ + return self.infer(image, intrinsics) \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/test/dataloader.py b/Pixel-Perfect-Depth/moge/test/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..76679829afdf385938b604fa8bb5ef07b2560e7b --- /dev/null +++ b/Pixel-Perfect-Depth/moge/test/dataloader.py @@ -0,0 +1,221 @@ +import os +from typing import * +from pathlib import Path +import math + +import numpy as np +import torch +from PIL import Image +import cv2 +import utils3d + +from ..utils import pipeline +from ..utils.geometry_numpy import focal_to_fov_numpy, mask_aware_nearest_resize_numpy, norm3d +from ..utils.io import * +from ..utils.tools import timeit + + +class EvalDataLoaderPipeline: + + def __init__( + self, + path: str, + width: int, + height: int, + split: int = '.index.txt', + drop_max_depth: float = 1000., + num_load_workers: int = 4, + num_process_workers: int = 8, + include_segmentation: bool = False, + include_normal: bool = False, + depth_to_normal: bool = False, + max_segments: int = 100, + min_seg_area: int = 1000, + depth_unit: str = None, + has_sharp_boundary = False, + subset: int = None, + ): + filenames = Path(path).joinpath(split).read_text(encoding='utf-8').splitlines() + filenames = filenames[::subset] + self.width = width + self.height = height + 
self.drop_max_depth = drop_max_depth + self.path = Path(path) + self.filenames = filenames + self.include_segmentation = include_segmentation + self.include_normal = include_normal + self.max_segments = max_segments + self.min_seg_area = min_seg_area + self.depth_to_normal = depth_to_normal + self.depth_unit = depth_unit + self.has_sharp_boundary = has_sharp_boundary + + self.rng = np.random.default_rng(seed=0) + + self.pipeline = pipeline.Sequential([ + self._generator, + pipeline.Parallel([self._load_instance] * num_load_workers), + pipeline.Parallel([self._process_instance] * num_process_workers), + pipeline.Buffer(4) + ]) + + def __len__(self): + return math.ceil(len(self.filenames)) + + def _generator(self): + for idx in range(len(self)): + yield idx + + def _load_instance(self, idx): + if idx >= len(self.filenames): + return None + + path = self.path.joinpath(self.filenames[idx]) + + instance = { + 'filename': self.filenames[idx], + 'width': self.width, + 'height': self.height, + } + instance['image'] = read_image(Path(path, 'image.jpg')) + + depth, _ = read_depth(Path(path, 'depth.png')) # ignore depth unit from depth file, use config instead + instance.update({ + 'depth': np.nan_to_num(depth, nan=1, posinf=1, neginf=1), + 'depth_mask': np.isfinite(depth), + 'depth_mask_inf': np.isinf(depth), + }) + + if self.include_segmentation: + segmentation_mask, segmentation_labels = read_segmentation(Path(path,'segmentation.png')) + instance.update({ + 'segmentation_mask': segmentation_mask, + 'segmentation_labels': segmentation_labels, + }) + + meta = read_meta(Path(path, 'meta.json')) + instance['intrinsics'] = np.array(meta['intrinsics'], dtype=np.float32) + + return instance + + def _process_instance(self, instance: dict): + if instance is None: + return None + + image, depth, depth_mask, intrinsics = instance['image'], instance['depth'], instance['depth_mask'], instance['intrinsics'] + segmentation_mask, segmentation_labels = instance.get('segmentation_mask', 
None), instance.get('segmentation_labels', None) + + raw_height, raw_width = image.shape[:2] + raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) + raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height + tgt_width, tgt_height = instance['width'], instance['height'] + tgt_aspect = tgt_width / tgt_height + + # set expected target view field + tgt_horizontal = min(raw_horizontal, raw_vertical * tgt_aspect) + tgt_vertical = tgt_horizontal / tgt_aspect + + # set target view direction + cu, cv = 0.5, 0.5 + direction = utils3d.numpy.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0] + R = utils3d.numpy.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32)) + + # restrict target view field within the raw view + corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32) + corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane + corners = corners[:, :2] / corners[:, 2:3] + + warp_horizontal, warp_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) + for i in range(4): + intersection, _ = utils3d.numpy.ray_intersection( + np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]), + corners[i - 1], corners[i] - corners[i - 1], + ) + warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min()) + tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical) + + # get target view intrinsics + fx, fy = 1.0 / tgt_horizontal, 1.0 / tgt_vertical + tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32) + + # do homogeneous transformation with the rotation and intrinsics + # 4.1 The image and depth is resized first to approximately 
the same pixel size as the target image with PIL's antialiasing resampling + tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes) + rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h) + image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS)) + + depth, depth_mask = mask_aware_nearest_resize_numpy(depth, depth_mask, (rescaled_w, rescaled_h)) + distance = norm3d(utils3d.numpy.depth_to_points(depth, intrinsics=intrinsics)) + segmentation_mask = cv2.resize(segmentation_mask, (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) if segmentation_mask is not None else None + + # 4.2 calculate homography warping + transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics) + uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height) + pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T + uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12) + pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32) + + tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) + tgt_distance = cv2.remap(distance, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) + tgt_ray_length = utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics) + tgt_ray_length = (tgt_ray_length[:, :, 0] ** 2 + tgt_ray_length[:, :, 1] ** 2 + tgt_ray_length[:, :, 2] ** 2) ** 0.5 + tgt_depth = tgt_distance / (tgt_ray_length + 1e-12) + tgt_depth_mask = cv2.remap(depth_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + tgt_segmentation_mask = cv2.remap(segmentation_mask, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) if segmentation_mask is not None else None + + # drop depth 
greater than drop_max_depth + max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.drop_max_depth + tgt_depth_mask &= tgt_depth <= max_depth + tgt_depth = np.nan_to_num(tgt_depth, nan=0.0) + + if self.depth_unit is not None: + tgt_depth *= self.depth_unit + + if not np.any(tgt_depth_mask): + # always make sure that mask is not empty, otherwise the loss calculation will crash + tgt_depth_mask = np.ones_like(tgt_depth_mask) + tgt_depth = np.ones_like(tgt_depth) + instance['label_type'] = 'invalid' + + tgt_pts = utils3d.numpy.unproject_cv(uv_tgt, tgt_depth, intrinsics=tgt_intrinsics) + + # Process segmentation labels + if self.include_segmentation and segmentation_mask is not None: + for k in ['undefined', 'unannotated', 'background', 'sky']: + if k in segmentation_labels: + del segmentation_labels[k] + seg_id2count = dict(zip(*np.unique(tgt_segmentation_mask, return_counts=True))) + sorted_labels = sorted(segmentation_labels.keys(), key=lambda x: seg_id2count.get(segmentation_labels[x], 0), reverse=True) + segmentation_labels = {k: segmentation_labels[k] for k in sorted_labels[:self.max_segments] if seg_id2count.get(segmentation_labels[k], 0) >= self.min_seg_area} + + instance.update({ + 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1), + 'depth': torch.from_numpy(tgt_depth).float(), + 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(), + 'intrinsics': torch.from_numpy(tgt_intrinsics).float(), + 'points': torch.from_numpy(tgt_pts).float(), + 'segmentation_mask': torch.from_numpy(tgt_segmentation_mask).long() if tgt_segmentation_mask is not None else None, + 'segmentation_labels': segmentation_labels, + 'is_metric': self.depth_unit is not None, + 'has_sharp_boundary': self.has_sharp_boundary, + }) + + instance = {k: v for k, v in instance.items() if v is not None} + + return instance + + def start(self): + self.pipeline.start() + + def stop(self): + self.pipeline.stop() + + def __enter__(self): + 
self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.stop() + + def get(self): + return self.pipeline.get() \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/test/metrics.py b/Pixel-Perfect-Depth/moge/test/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..904064f2a30d05dca3a53db7ecc076a0c2aaa0ad --- /dev/null +++ b/Pixel-Perfect-Depth/moge/test/metrics.py @@ -0,0 +1,343 @@ +from typing import * +from numbers import Number + +import torch +import torch.nn.functional as F +import numpy as np +import utils3d + +from ..utils.geometry_torch import ( + weighted_mean, + mask_aware_nearest_resize, + intrinsics_to_fov +) +from ..utils.alignment import ( + align_points_scale_z_shift, + align_points_scale_xyz_shift, + align_points_xyz_shift, + align_affine_lstsq, + align_depth_scale, + align_depth_affine, + align_points_scale, +) +from ..utils.tools import key_average, timeit + + +def rel_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + rel = (torch.abs(pred - gt) / (gt + eps)).mean() + return rel.item() + + +def delta1_depth(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + delta1 = (torch.maximum(gt / pred, pred / gt) < 1.25).float().mean() + return delta1.item() + + +def rel_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + dist_gt = torch.norm(gt, dim=-1) + dist_err = torch.norm(pred - gt, dim=-1) + rel = (dist_err / (dist_gt + eps)).mean() + return rel.item() + + +def delta1_point(pred: torch.Tensor, gt: torch.Tensor, eps: float = 1e-6): + dist_pred = torch.norm(pred, dim=-1) + dist_gt = torch.norm(gt, dim=-1) + dist_err = torch.norm(pred - gt, dim=-1) + + delta1 = (dist_err < 0.25 * torch.minimum(dist_gt, dist_pred)).float().mean() + return delta1.item() + + +def rel_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor): + dist_err = torch.norm(pred - gt, dim=-1) + rel = (dist_err / diameter).mean() + return rel.item() + + 
+def delta1_point_local(pred: torch.Tensor, gt: torch.Tensor, diameter: torch.Tensor): + dist_err = torch.norm(pred - gt, dim=-1) + delta1 = (dist_err < 0.25 * diameter).float().mean() + return delta1.item() + + +def boundary_f1(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor, radius: int = 1): + neighbor_x, neight_y = torch.meshgrid( + torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device), + torch.linspace(-radius, radius, 2 * radius + 1, device=pred.device), + indexing='xy' + ) + neighbor_mask = (neighbor_x ** 2 + neight_y ** 2) <= radius ** 2 + 1e-5 + + pred_window = utils3d.torch.sliding_window_2d(pred, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1] + gt_window = utils3d.torch.sliding_window_2d(gt, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1] + mask_window = neighbor_mask & utils3d.torch.sliding_window_2d(mask, window_size=2 * radius + 1, stride=1, dim=(-2, -1)) # [H, W, 2*R+1, 2*R+1] + + pred_rel = pred_window / pred[radius:-radius, radius:-radius, None, None] + gt_rel = gt_window / gt[radius:-radius, radius:-radius, None, None] + valid = mask[radius:-radius, radius:-radius, None, None] & mask_window + + f1_list = [] + w_list = t_list = torch.linspace(0.05, 0.25, 10).tolist() + + for t in t_list: + pred_label = pred_rel > 1 + t + gt_label = gt_rel > 1 + t + TP = (pred_label & gt_label & valid).float().sum() + precision = TP / (gt_label & valid).float().sum().clamp_min(1e-12) + recall = TP / (pred_label & valid).float().sum().clamp_min(1e-12) + f1 = 2 * precision * recall / (precision + recall).clamp_min(1e-12) + f1_list.append(f1.item()) + + f1_avg = sum(w * f1 for w, f1 in zip(w_list, f1_list)) / sum(w_list) + return f1_avg + + +def compute_metrics( + pred: Dict[str, torch.Tensor], + gt: Dict[str, torch.Tensor], + vis: bool = False +) -> Tuple[Dict[str, Dict[str, Number]], Dict[str, torch.Tensor]]: + """ + A unified function to compute metrics for different types of predictions 
and ground truths. + + #### Supported keys in pred: + - `disparity_affine_invariant`: disparity map predicted by a depth estimator with scale and shift invariant. + - `depth_scale_invariant`: depth map predicted by a depth estimator with scale invariant. + - `depth_affine_invariant`: depth map predicted by a depth estimator with scale and shift invariant. + - `depth_metric`: depth map predicted by a depth estimator with no scale or shift. + - `points_scale_invariant`: point map predicted by a point estimator with scale invariant. + - `points_affine_invariant`: point map predicted by a point estimator with scale and xyz shift invariant. + - `points_metric`: point map predicted by a point estimator with no scale or shift. + - `intrinsics`: normalized camera intrinsics matrix. + + #### Required keys in gt: + - `depth`: depth map ground truth (in metric units if `depth_metric` is used) + - `points`: point map ground truth in camera coordinates. + - `mask`: mask indicating valid pixels in the ground truth. + - `intrinsics`: normalized ground-truth camera intrinsics matrix. + - `is_metric`: whether the depth is in metric units. 
+ """ + metrics = {} + misc = {} + + mask = gt['depth_mask'] + gt_depth = gt['depth'] + gt_points = gt['points'] + + height, width = mask.shape[-2:] + _, lr_mask, lr_index = mask_aware_nearest_resize(None, mask, (64, 64), return_index=True) + + only_depth = not any('point' in k for k in pred) + pred_depth_aligned, pred_points_aligned = None, None + + # Metric depth + if 'depth_metric' in pred and gt['is_metric']: + pred_depth, gt_depth = pred['depth_metric'], gt['depth'] + metrics['depth_metric'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = pred_depth + + # Scale-invariant depth + if 'depth_scale_invariant' in pred: + pred_depth_scale_invariant = pred['depth_scale_invariant'] + elif 'depth_metric' in pred: + pred_depth_scale_invariant = pred['depth_metric'] + else: + pred_depth_scale_invariant = None + + if pred_depth_scale_invariant is not None: + pred_depth = pred_depth_scale_invariant + + pred_depth_lr_masked, gt_depth_lr_masked = pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask] + scale = align_depth_scale(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked) + pred_depth = pred_depth * scale + + metrics['depth_scale_invariant'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = pred_depth + + # Affine-invariant depth + if 'depth_affine_invariant' in pred: + pred_depth_affine_invariant = pred['depth_affine_invariant'] + elif 'depth_scale_invariant' in pred: + pred_depth_affine_invariant = pred['depth_scale_invariant'] + elif 'depth_metric' in pred: + pred_depth_affine_invariant = pred['depth_metric'] + else: + pred_depth_affine_invariant = None + + if pred_depth_affine_invariant is not None: + pred_depth = pred_depth_affine_invariant + + pred_depth_lr_masked, gt_depth_lr_masked = 
pred_depth[lr_index][lr_mask], gt_depth[lr_index][lr_mask] + scale, shift = align_depth_affine(pred_depth_lr_masked, gt_depth_lr_masked, 1 / gt_depth_lr_masked) + pred_depth = pred_depth * scale + shift + + metrics['depth_affine_invariant'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = pred_depth + + # Affine-invariant disparity + if 'disparity_affine_invariant' in pred: + pred_disparity_affine_invariant = pred['disparity_affine_invariant'] + elif 'depth_scale_invariant' in pred: + pred_disparity_affine_invariant = 1 / pred['depth_scale_invariant'] + elif 'depth_metric' in pred: + pred_disparity_affine_invariant = 1 / pred['depth_metric'] + else: + pred_disparity_affine_invariant = None + + if pred_disparity_affine_invariant is not None: + pred_disp = pred_disparity_affine_invariant + + scale, shift = align_affine_lstsq(pred_disp[mask], 1 / gt_depth[mask]) + pred_disp = pred_disp * scale + shift + + # NOTE: The alignment is done on the disparity map could introduce extreme outliers at disparities close to 0. + # Therefore we clamp the disparities by minimum ground truth disparity. 
+ pred_depth = 1 / pred_disp.clamp_min(1 / gt_depth[mask].max().item()) + + metrics['disparity_affine_invariant'] = { + 'rel': rel_depth(pred_depth[mask], gt_depth[mask]), + 'delta1': delta1_depth(pred_depth[mask], gt_depth[mask]) + } + + if pred_depth_aligned is None: + pred_depth_aligned = 1 / pred_disp.clamp_min(1e-6) + + # Metric points + if 'points_metric' in pred and gt['is_metric']: + pred_points = pred['points_metric'] + + pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask] + shift = align_points_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1)) + pred_points = pred_points + shift + + metrics['points_metric'] = { + 'rel': rel_point(pred_points[mask], gt_points[mask]), + 'delta1': delta1_point(pred_points[mask], gt_points[mask]) + } + + if pred_points_aligned is None: + pred_points_aligned = pred['points_metric'] + + # Scale-invariant points (in camera space) + if 'points_scale_invariant' in pred: + pred_points_scale_invariant = pred['points_scale_invariant'] + elif 'points_metric' in pred: + pred_points_scale_invariant = pred['points_metric'] + else: + pred_points_scale_invariant = None + + if pred_points_scale_invariant is not None: + pred_points = pred_points_scale_invariant + + pred_points_lr_masked, gt_points_lr_masked = pred_points_scale_invariant[lr_index][lr_mask], gt_points[lr_index][lr_mask] + scale = align_points_scale(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1)) + pred_points = pred_points * scale + + metrics['points_scale_invariant'] = { + 'rel': rel_point(pred_points[mask], gt_points[mask]), + 'delta1': delta1_point(pred_points[mask], gt_points[mask]) + } + + if vis and pred_points_aligned is None: + pred_points_aligned = pred['points_scale_invariant'] * scale + + # Affine-invariant points + if 'points_affine_invariant' in pred: + pred_points_affine_invariant = pred['points_affine_invariant'] + elif 
'points_scale_invariant' in pred: + pred_points_affine_invariant = pred['points_scale_invariant'] + elif 'points_metric' in pred: + pred_points_affine_invariant = pred['points_metric'] + else: + pred_points_affine_invariant = None + + if pred_points_affine_invariant is not None: + pred_points = pred_points_affine_invariant + + pred_points_lr_masked, gt_points_lr_masked = pred_points[lr_index][lr_mask], gt_points[lr_index][lr_mask] + scale, shift = align_points_scale_xyz_shift(pred_points_lr_masked, gt_points_lr_masked, 1 / gt_points_lr_masked.norm(dim=-1)) + pred_points = pred_points * scale + shift + + metrics['points_affine_invariant'] = { + 'rel': rel_point(pred_points[mask], gt_points[mask]), + 'delta1': delta1_point(pred_points[mask], gt_points[mask]) + } + + if vis and pred_points_aligned is None: + pred_points_aligned = pred['points_affine_invariant'] * scale + shift + + # Local points + if 'segmentation_mask' in gt and 'points' in gt and any('points' in k for k in pred.keys()): + pred_points = next(pred[k] for k in pred.keys() if 'points' in k) + gt_points = gt['points'] + segmentation_mask = gt['segmentation_mask'] + segmentation_labels = gt['segmentation_labels'] + segmentation_mask_lr = segmentation_mask[lr_index] + local_points_metrics = [] + for _, seg_id in segmentation_labels.items(): + valid_mask = (segmentation_mask == seg_id) & mask + + pred_points_masked = pred_points[valid_mask] + gt_points_masked = gt_points[valid_mask] + + valid_mask_lr = (segmentation_mask_lr == seg_id) & lr_mask + if valid_mask_lr.sum().item() < 10: + continue + pred_points_masked_lr = pred_points[lr_index][valid_mask_lr] + gt_points_masked_lr = gt_points[lr_index][valid_mask_lr] + diameter = (gt_points_masked.max(dim=0).values - gt_points_masked.min(dim=0).values).max() + scale, shift = align_points_scale_xyz_shift(pred_points_masked_lr, gt_points_masked_lr, 1 / diameter.expand(gt_points_masked_lr.shape[0])) + pred_points_masked = pred_points_masked * scale + shift + + 
local_points_metrics.append({ + 'rel': rel_point_local(pred_points_masked, gt_points_masked, diameter), + 'delta1': delta1_point_local(pred_points_masked, gt_points_masked, diameter), + }) + + metrics['local_points'] = key_average(local_points_metrics) + + # FOV. NOTE: If there is no random augmentation applied to the input images, all GT FOV are generallly the same. + # Fair evaluation of FOV requires random augmentation. + if 'intrinsics' in pred and 'intrinsics' in gt: + pred_intrinsics = pred['intrinsics'] + gt_intrinsics = gt['intrinsics'] + pred_fov_x, pred_fov_y = intrinsics_to_fov(pred_intrinsics) + gt_fov_x, gt_fov_y = intrinsics_to_fov(gt_intrinsics) + metrics['fov_x'] = { + 'mae': torch.rad2deg(pred_fov_x - gt_fov_x).abs().mean().item(), + 'deviation': torch.rad2deg(pred_fov_x - gt_fov_x).item(), + } + + # Boundary F1 + if pred_depth_aligned is not None and gt['has_sharp_boundary']: + metrics['boundary'] = { + 'radius1_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=1), + 'radius2_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=2), + 'radius3_f1': boundary_f1(pred_depth_aligned, gt_depth, mask, radius=3), + } + + if vis: + if pred_points_aligned is not None: + misc['pred_points'] = pred_points_aligned + if only_depth: + misc['pred_points'] = utils3d.torch.depth_to_points(pred_depth_aligned, intrinsics=gt['intrinsics']) + if pred_depth_aligned is not None: + misc['pred_depth'] = pred_depth_aligned + + return metrics, misc \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/train/__init__.py b/Pixel-Perfect-Depth/moge/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/Pixel-Perfect-Depth/moge/train/dataloader.py b/Pixel-Perfect-Depth/moge/train/dataloader.py new file mode 100644 index 0000000000000000000000000000000000000000..a3bfc280844dac602e89bee747e247946dbc6f67 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/train/dataloader.py 
@@ -0,0 +1,338 @@ +import os +from pathlib import Path +import json +import time +import random +from typing import * +import traceback +import itertools +from numbers import Number +import io + +import numpy as np +import cv2 +from PIL import Image +import torch +import torchvision.transforms.v2.functional as TF +import utils3d +from tqdm import tqdm + +from ..utils import pipeline +from ..utils.io import * +from ..utils.geometry_numpy import mask_aware_nearest_resize_numpy, harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy, depth_of_field + + +class TrainDataLoaderPipeline: + def __init__(self, config: dict, batch_size: int, num_load_workers: int = 4, num_process_workers: int = 8, buffer_size: int = 8): + self.config = config + + self.batch_size = batch_size + self.clamp_max_depth = config['clamp_max_depth'] + self.fov_range_absolute = config.get('fov_range_absolute', 0.0) + self.fov_range_relative = config.get('fov_range_relative', 0.0) + self.center_augmentation = config.get('center_augmentation', 0.0) + self.image_augmentation = config.get('image_augmentation', []) + self.depth_interpolation = config.get('depth_interpolation', 'bilinear') + + if 'image_sizes' in config: + self.image_size_strategy = 'fixed' + self.image_sizes = config['image_sizes'] + elif 'aspect_ratio_range' in config and 'area_range' in config: + self.image_size_strategy = 'aspect_area' + self.aspect_ratio_range = config['aspect_ratio_range'] + self.area_range = config['area_range'] + else: + raise ValueError('Invalid image size configuration') + + # Load datasets + self.datasets = {} + for dataset in tqdm(config['datasets'], desc='Loading datasets'): + name = dataset['name'] + content = Path(dataset['path'], dataset.get('index', '.index.txt')).joinpath().read_text() + filenames = content.splitlines() + self.datasets[name] = { + **dataset, + 'path': dataset['path'], + 'filenames': filenames, + } + self.dataset_names = [dataset['name'] for dataset in config['datasets']] + 
self.dataset_weights = [dataset['weight'] for dataset in config['datasets']] + + # Build pipeline + self.pipeline = pipeline.Sequential([ + self._sample_batch, + pipeline.Unbatch(), + pipeline.Parallel([self._load_instance] * num_load_workers), + pipeline.Parallel([self._process_instance] * num_process_workers), + pipeline.Batch(self.batch_size), + self._collate_batch, + pipeline.Buffer(buffer_size), + ]) + + self.invalid_instance = { + 'intrinsics': np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], dtype=np.float32), + 'image': np.zeros((256, 256, 3), dtype=np.uint8), + 'depth': np.ones((256, 256), dtype=np.float32), + 'depth_mask': np.ones((256, 256), dtype=bool), + 'depth_mask_inf': np.zeros((256, 256), dtype=bool), + 'label_type': 'invalid', + } + + def _sample_batch(self): + batch_id = 0 + last_area = None + while True: + # Depending on the sample strategy, choose a dataset and a filename + batch_id += 1 + batch = [] + + # Sample instances + for _ in range(self.batch_size): + dataset_name = random.choices(self.dataset_names, weights=self.dataset_weights)[0] + filename = random.choice(self.datasets[dataset_name]['filenames']) + + path = Path(self.datasets[dataset_name]['path'], filename) + + instance = { + 'batch_id': batch_id, + 'seed': random.randint(0, 2 ** 32 - 1), + 'dataset': dataset_name, + 'filename': filename, + 'path': path, + 'label_type': self.datasets[dataset_name]['label_type'], + } + batch.append(instance) + + # Decide the image size for this batch + if self.image_size_strategy == 'fixed': + width, height = random.choice(self.config['image_sizes']) + elif self.image_size_strategy == 'aspect_area': + area = random.uniform(*self.area_range) + aspect_ratio_ranges = [self.datasets[instance['dataset']].get('aspect_ratio_range', self.aspect_ratio_range) for instance in batch] + aspect_ratio_range = (min(r[0] for r in aspect_ratio_ranges), max(r[1] for r in aspect_ratio_ranges)) + aspect_ratio = random.uniform(*aspect_ratio_range) + width, 
height = int((area * aspect_ratio) ** 0.5), int((area / aspect_ratio) ** 0.5) + else: + raise ValueError('Invalid image size strategy') + + for instance in batch: + instance['width'], instance['height'] = width, height + + yield batch + + def _load_instance(self, instance: dict): + try: + image = read_image(Path(instance['path'], 'image.jpg')) + depth, _ = read_depth(Path(instance['path'], self.datasets[instance['dataset']].get('depth', 'depth.png'))) + + meta = read_meta(Path(instance['path'], 'meta.json')) + intrinsics = np.array(meta['intrinsics'], dtype=np.float32) + depth_mask = np.isfinite(depth) + depth_mask_inf = np.isinf(depth) + depth = np.nan_to_num(depth, nan=1, posinf=1, neginf=1) + data = { + 'image': image, + 'depth': depth, + 'depth_mask': depth_mask, + 'depth_mask_inf': depth_mask_inf, + 'intrinsics': intrinsics + } + instance.update({ + **data, + }) + except Exception as e: + print(f"Failed to load instance {instance['dataset']}/{instance['filename']} because of exception:", e) + instance.update(self.invalid_instance) + return instance + + def _process_instance(self, instance: Dict[str, Union[np.ndarray, str, float, bool]]): + image, depth, depth_mask, depth_mask_inf, intrinsics, label_type = instance['image'], instance['depth'], instance['depth_mask'], instance['depth_mask_inf'], instance['intrinsics'], instance['label_type'] + depth_unit = self.datasets[instance['dataset']].get('depth_unit', None) + + raw_height, raw_width = image.shape[:2] + raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1]) + raw_fov_x, raw_fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics) + raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height + tgt_width, tgt_height = instance['width'], instance['height'] + tgt_aspect = tgt_width / tgt_height + + rng = np.random.default_rng(instance['seed']) + + # 1. 
set target fov + center_augmentation = self.datasets[instance['dataset']].get('center_augmentation', self.center_augmentation) + fov_range_absolute_min, fov_range_absolute_max = self.datasets[instance['dataset']].get('fov_range_absolute', self.fov_range_absolute) + fov_range_relative_min, fov_range_relative_max = self.datasets[instance['dataset']].get('fov_range_relative', self.fov_range_relative) + tgt_fov_x_min = min(fov_range_relative_min * raw_fov_x, fov_range_relative_min * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect)) + tgt_fov_x_max = min(fov_range_relative_max * raw_fov_x, fov_range_relative_max * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect)) + tgt_fov_x_min, tgt_fov_x_max = max(np.deg2rad(fov_range_absolute_min), tgt_fov_x_min), min(np.deg2rad(fov_range_absolute_max), tgt_fov_x_max) + tgt_fov_x = rng.uniform(min(tgt_fov_x_min, tgt_fov_x_max), tgt_fov_x_max) + tgt_fov_y = utils3d.focal_to_fov(utils3d.numpy.fov_to_focal(tgt_fov_x) * tgt_aspect) + + # 2. set target image center (principal point) and the corresponding z-direction in raw camera space + center_dtheta = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_x - tgt_fov_x) + center_dphi = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_y - tgt_fov_y) + cu, cv = 0.5 + 0.5 * np.tan(center_dtheta) / np.tan(raw_fov_x / 2), 0.5 + 0.5 * np.tan(center_dphi) / np.tan(raw_fov_y / 2) + direction = utils3d.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0] + + # 3. obtain the rotation matrix for homography warping + R = utils3d.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32)) + + # 4. 
shrink the target view to fit into the warped image + corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32) + corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T) # corners in viewport's camera plane + corners = corners[:, :2] / corners[:, 2:3] + tgt_horizontal, tgt_vertical = np.tan(tgt_fov_x / 2) * 2, np.tan(tgt_fov_y / 2) * 2 + warp_horizontal, warp_vertical = float('inf'), float('inf') + for i in range(4): + intersection, _ = utils3d.numpy.ray_intersection( + np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]), + corners[i - 1], corners[i] - corners[i - 1], + ) + warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min()) + tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical) + + # 5. obtain the target intrinsics + fx, fy = 1 / tgt_horizontal, 1 / tgt_vertical + tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32) + + # 6. 
do homogeneous transformation + # 6.1 The image and depth are resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling + tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height # (should be exactly the same for x and y axes) + rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h) + image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS)) + + edge_mask = depth_occlusion_edge_numpy(depth, mask=depth_mask, thickness=2, tol=0.01) + _, depth_mask_nearest, resize_index = mask_aware_nearest_resize_numpy(None, depth_mask, (rescaled_w, rescaled_h), return_index=True) + depth_nearest = depth[resize_index] + distance_nearest = norm3d(utils3d.numpy.depth_to_points(depth_nearest, intrinsics=intrinsics)) + edge_mask = edge_mask[resize_index] + + if self.depth_interpolation == 'bilinear': + depth_mask_bilinear = cv2.resize(depth_mask.astype(np.float32), (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR) + depth_bilinear = 1 / cv2.resize(1 / depth, (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR) + distance_bilinear = norm3d(utils3d.numpy.depth_to_points(depth_bilinear, intrinsics=intrinsics)) + + depth_mask_inf = cv2.resize(depth_mask_inf.astype(np.uint8), (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) > 0 + + # 6.2 calculate homography warping + transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics) + uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height) + pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T + uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12) + pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32) + + tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LANCZOS4) + tgt_ray_length = 
norm3d(utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics)) + tgt_depth_mask_nearest = cv2.remap(depth_mask_nearest.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + tgt_depth_nearest = cv2.remap(distance_nearest, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) / tgt_ray_length + tgt_edge_mask = cv2.remap(edge_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + if self.depth_interpolation == 'bilinear': + tgt_depth_mask_bilinear = cv2.remap(depth_mask_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) + tgt_depth_bilinear = cv2.remap(distance_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) / tgt_ray_length + tgt_depth = np.where((tgt_depth_mask_bilinear == 1) & ~tgt_edge_mask, tgt_depth_bilinear, tgt_depth_nearest) + else: + tgt_depth = tgt_depth_nearest + tgt_depth_mask = tgt_depth_mask_nearest + + tgt_depth_mask_inf = cv2.remap(depth_mask_inf.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0 + + # always make sure that mask is not empty + if tgt_depth_mask.sum() / tgt_depth_mask.size < 0.001: + tgt_depth_mask = np.ones_like(tgt_depth_mask) + tgt_depth = np.ones_like(tgt_depth) + instance['label_type'] = 'invalid' + + # Flip augmentation + if rng.choice([True, False]): + tgt_image = np.flip(tgt_image, axis=1).copy() + tgt_depth = np.flip(tgt_depth, axis=1).copy() + tgt_depth_mask = np.flip(tgt_depth_mask, axis=1).copy() + tgt_depth_mask_inf = np.flip(tgt_depth_mask_inf, axis=1).copy() + + # Color augmentation + image_augmentation = self.datasets[instance['dataset']].get('image_augmentation', self.image_augmentation) + if 'jittering' in image_augmentation: + tgt_image = torch.from_numpy(tgt_image).permute(2, 0, 1) + tgt_image = TF.adjust_brightness(tgt_image, rng.uniform(0.7, 1.3)) + tgt_image = TF.adjust_contrast(tgt_image, rng.uniform(0.7, 1.3)) + 
tgt_image = TF.adjust_saturation(tgt_image, rng.uniform(0.7, 1.3)) + tgt_image = TF.adjust_hue(tgt_image, rng.uniform(-0.1, 0.1)) + tgt_image = TF.adjust_gamma(tgt_image, rng.uniform(0.7, 1.3)) + tgt_image = tgt_image.permute(1, 2, 0).numpy() + if 'dof' in image_augmentation: + if rng.uniform() < 0.5: + dof_strength = rng.integers(12) + tgt_disp = np.where(tgt_depth_mask_inf, 0, 1 / tgt_depth) + disp_min, disp_max = tgt_disp[tgt_depth_mask].min(), tgt_disp[tgt_depth_mask].max() + tgt_disp = cv2.inpaint(tgt_disp, (~tgt_depth_mask & ~tgt_depth_mask_inf).astype(np.uint8), 3, cv2.INPAINT_TELEA).clip(disp_min, disp_max) + dof_focus = rng.uniform(disp_min, disp_max) + tgt_image = depth_of_field(tgt_image, tgt_disp, dof_focus, dof_strength) + if 'shot_noise' in image_augmentation: + if rng.uniform() < 0.5: + k = np.exp(rng.uniform(np.log(100), np.log(10000))) / 255 + tgt_image = (rng.poisson(tgt_image * k) / k).clip(0, 255).astype(np.uint8) + if 'jpeg_loss' in image_augmentation: + if rng.uniform() < 0.5: + tgt_image = cv2.imdecode(cv2.imencode('.jpg', tgt_image, [cv2.IMWRITE_JPEG_QUALITY, rng.integers(20, 100)])[1], cv2.IMREAD_COLOR) + if 'blurring' in image_augmentation: + if rng.uniform() < 0.5: + ratio = rng.uniform(0.25, 1) + tgt_image = cv2.resize(cv2.resize(tgt_image, (int(tgt_width * ratio), int(tgt_height * ratio)), interpolation=cv2.INTER_AREA), (tgt_width, tgt_height), interpolation=rng.choice([cv2.INTER_LINEAR_EXACT, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4])) + + # convert depth to metric if necessary + if depth_unit is not None: + tgt_depth *= depth_unit + instance['is_metric'] = True + else: + instance['is_metric'] = False + + # clamp depth maximum values + max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.clamp_max_depth + tgt_depth = np.clip(tgt_depth, 0, max_depth) + tgt_depth = np.nan_to_num(tgt_depth, nan=1.0) + + if self.datasets[instance['dataset']].get('finite_depth_mask', None) == "only_known": + tgt_depth_mask_fin = 
tgt_depth_mask + else: + tgt_depth_mask_fin = ~tgt_depth_mask_inf + + instance.update({ + 'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1), + 'depth': torch.from_numpy(tgt_depth).float(), + 'depth_mask': torch.from_numpy(tgt_depth_mask).bool(), + 'depth_mask_fin': torch.from_numpy(tgt_depth_mask_fin).bool(), + 'depth_mask_inf': torch.from_numpy(tgt_depth_mask_inf).bool(), + 'intrinsics': torch.from_numpy(tgt_intrinsics).float(), + }) + + return instance + + def _collate_batch(self, instances: List[Dict[str, Any]]): + batch = {k: torch.stack([instance[k] for instance in instances], dim=0) for k in ['image', 'depth', 'depth_mask', 'depth_mask_fin', 'depth_mask_inf', 'intrinsics']} + batch = { + 'label_type': [instance['label_type'] for instance in instances], + 'is_metric': [instance['is_metric'] for instance in instances], + 'info': [{'dataset': instance['dataset'], 'filename': instance['filename']} for instance in instances], + **batch, + } + return batch + + def get(self) -> Dict[str, Union[torch.Tensor, str]]: + return self.pipeline.get() + + def start(self): + self.pipeline.start() + + def stop(self): + self.pipeline.stop() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.pipeline.terminate() + self.pipeline.join() + return False + + diff --git a/Pixel-Perfect-Depth/moge/train/losses.py b/Pixel-Perfect-Depth/moge/train/losses.py new file mode 100644 index 0000000000000000000000000000000000000000..2b251b230f4cc86d8358f613acf483badfb49e14 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/train/losses.py @@ -0,0 +1,270 @@ +from typing import * +import math + +import torch +import torch.nn.functional as F +import utils3d + +from ..utils.geometry_torch import ( + weighted_mean, + harmonic_mean, + geometric_mean, + mask_aware_nearest_resize, + normalized_view_plane_uv, + angle_diff_vec3 +) +from ..utils.alignment import ( + align_points_scale_z_shift, + align_points_scale, + 
    align_points_scale_xyz_shift,
    align_points_z_shift,
)


def _smooth(err: torch.FloatTensor, beta: float = 0.0) -> torch.FloatTensor:
    # Huber-style smoothing of a non-negative error: quadratic below `beta`,
    # linear above. With beta == 0 this is the identity (plain L1-style error).
    if beta == 0:
        return err
    else:
        return torch.where(err < beta, 0.5 * err.square() / beta, err - 0.5 * beta)


def affine_invariant_global_loss(
    pred_points: torch.Tensor,
    gt_points: torch.Tensor,
    mask: torch.Tensor,
    align_resolution: int = 64,
    beta: float = 0.0,
    trunc: float = 1.0,
    sparsity_aware: bool = False
):
    """Global affine-invariant point-map loss.

    Solves a per-sample scale and z-shift aligning `pred_points` to `gt_points` on a
    mask-aware low-resolution version of the maps, applies it, then penalizes the
    inverse-depth-weighted absolute difference (optionally Huber-smoothed via `beta`,
    truncated via `trunc` inside the alignment solver).

    Returns ``(loss, misc, scale)``: per-sample loss, scalar monitoring metrics, and
    the detached alignment scale.

    NOTE(review): assumes point maps are (..., H, W, 3) with z in channel 2 —
    consistent with the indexing below; confirm against callers.
    """
    device = pred_points.device

    # Align
    (pred_points_lr, gt_points_lr), lr_mask = mask_aware_nearest_resize((pred_points, gt_points), mask=mask, size=(align_resolution, align_resolution))
    scale, shift = align_points_scale_z_shift(pred_points_lr.flatten(-3, -2), gt_points_lr.flatten(-3, -2), lr_mask.flatten(-2, -1) / gt_points_lr[..., 2].flatten(-2, -1).clamp_min(1e-2), trunc=trunc)
    valid = scale > 0
    # Degenerate alignments (non-positive scale) are zeroed so they contribute no loss.
    scale, shift = torch.where(valid, scale, 0), torch.where(valid[..., None], shift, 0)

    pred_points = scale[..., None, None, None] * pred_points + shift[..., None, None, :]

    # Compute loss
    weight = (valid[..., None, None] & mask).float() / gt_points[..., 2].clamp_min(1e-5)
    weight = weight.clamp_max(10.0 * weighted_mean(weight, mask, dim=(-2, -1), keepdim=True))  # In case your data contains extremely small depth values
    loss = _smooth((pred_points - gt_points).abs() * weight[..., None], beta=beta).mean(dim=(-3, -2, -1))

    if sparsity_aware:
        # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
        sparsity = mask.float().mean(dim=(-2, -1)) / lr_mask.float().mean(dim=(-2, -1))
        loss = loss / (sparsity + 1e-7)

    err = (pred_points.detach() - gt_points).norm(dim=-1) / gt_points[..., 2]

    # Record any scalar metric
    misc = {
        'truncated_error': weighted_mean(err.clamp_max(1.0), mask).item(),
        'delta': weighted_mean((err < 1).float(), mask).item()
    }

    return loss, misc, scale.detach()


def monitoring(points: torch.Tensor):
    # Scalar statistics of the predicted point map, for logging only.
    return {
        'std': points.std().item(),
    }


def compute_anchor_sampling_weight(
    points: torch.Tensor,
    mask: torch.Tensor,
    radius_2d: torch.Tensor,
    radius_3d: torch.Tensor,
    num_test: int = 64
) -> torch.Tensor:
    # Importance sampling to balance the sampled probability of fine structures.
    # NOTE: MoGe-1 uses uniform random sampling instead of importance sampling.
    # This is an incremental trick introduced later than the publication of MoGe-1 paper.
    #
    # For each pixel, randomly probe `num_test` neighbors within `radius_2d` and weight
    # the pixel by the reciprocal of how many probes fall within its 3D `radius_3d`:
    # pixels in dense flat regions get low weight, isolated/fine structures high weight.

    height, width = points.shape[-3:-1]

    pixel_i, pixel_j = torch.meshgrid(
        torch.arange(height, device=points.device),
        torch.arange(width, device=points.device),
        indexing='ij'
    )

    test_delta_i = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device)    # [num_test]
    test_delta_j = torch.randint(-radius_2d, radius_2d + 1, (height, width, num_test,), device=points.device)    # [num_test]
    test_i, test_j = pixel_i[..., None] + test_delta_i, pixel_j[..., None] + test_delta_j   # [height, width, num_test]
    test_mask = (test_i >= 0) & (test_i < height) & (test_j >= 0) & (test_j < width)    # [height, width, num_test]
    test_i, test_j = test_i.clamp(0, height - 1), test_j.clamp(0, width - 1)    # [height, width, num_test]
    test_mask = test_mask & mask[..., test_i, test_j]   # [..., height, width, num_test]
    test_points = points[..., test_i, test_j, :]    # [..., height, width, num_test, 3]
    test_dist = (test_points - points[..., None, :]).norm(dim=-1)   # [..., height, width, num_test]

    weight = 1 / ((test_dist <= radius_3d[..., None]) & test_mask).float().sum(dim=-1).clamp_min(1)
    weight = torch.where(mask, weight, 0)
    weight = weight / weight.sum(dim=(-2, -1), keepdim=True).add(1e-7)  # [..., height, width]
    return weight


def affine_invariant_local_loss(
    pred_points: torch.Tensor,
    gt_points: torch.Tensor,
    gt_mask: torch.Tensor,
    focal: torch.Tensor,
    global_scale: torch.Tensor,
    level: Literal[4, 16, 64],
    align_resolution: int = 32,
    num_patches: int = 16,
    beta: float = 0.0,
    trunc: float = 1.0,
    sparsity_aware: bool = False
):
    """Patch-wise affine-invariant loss at a given pyramid `level`.

    Samples `num_patches` anchor pixels per sample (importance-weighted via
    `compute_anchor_sampling_weight`), cuts a 2D patch of radius ~1/(2*level) of the
    image diagonal around each, keeps only GT points within the corresponding 3D
    radius, aligns each patch with a per-patch scale + xyz shift, and penalizes the
    weighted absolute difference. Returns ``(loss, misc)``.
    """
    device, dtype = pred_points.device, pred_points.dtype
    *batch_shape, height, width, _ = pred_points.shape
    batch_size = math.prod(batch_shape)
    pred_points, gt_points, gt_mask, focal, global_scale = pred_points.reshape(-1, height, width, 3), gt_points.reshape(-1, height, width, 3), gt_mask.reshape(-1, height, width), focal.reshape(-1), global_scale.reshape(-1) if global_scale is not None else None

    # Sample patch anchor points indices [num_total_patches]
    radius_2d = math.ceil(0.5 / level * (height ** 2 + width ** 2) ** 0.5)
    radius_3d = 0.5 / level / focal * gt_points[..., 2]
    anchor_sampling_weights = compute_anchor_sampling_weight(gt_points, gt_mask, radius_2d, radius_3d, num_test=64)
    where_mask = torch.where(gt_mask)
    random_selection = torch.multinomial(anchor_sampling_weights[where_mask], num_patches * batch_size, replacement=True)
    patch_batch_idx, patch_anchor_i, patch_anchor_j = [indices[random_selection] for indices in where_mask]  # [num_total_patches]

    # Get patch indices [num_total_patches, patch_h, patch_w]
    patch_i, patch_j = torch.meshgrid(
        torch.arange(-radius_2d, radius_2d + 1, device=device),
        torch.arange(-radius_2d, radius_2d + 1, device=device),
        indexing='ij'
    )
    patch_i, patch_j = patch_i + patch_anchor_i[:, None, None], patch_j + patch_anchor_j[:, None, None]
    patch_mask = (patch_i >= 0) & (patch_i < height) & (patch_j >= 0) & (patch_j < width)
    patch_i, patch_j = patch_i.clamp(0, height - 1), patch_j.clamp(0, width - 1)

    # Get patch mask and gt patch points
    gt_patch_anchor_points = gt_points[patch_batch_idx, patch_anchor_i, patch_anchor_j]
    gt_patch_radius_3d = 0.5 / level / focal[patch_batch_idx] * gt_patch_anchor_points[:, 2]
    gt_patch_points = gt_points[patch_batch_idx[:, None, None], patch_i, patch_j]
    gt_patch_dist = (gt_patch_points - gt_patch_anchor_points[:, None, None, :]).norm(dim=-1)
    patch_mask &= gt_mask[patch_batch_idx[:, None, None], patch_i, patch_j]
    patch_mask &= gt_patch_dist <= gt_patch_radius_3d[:, None, None]

    # Pick only non-empty patches
    MINIMUM_POINTS_PER_PATCH = 32
    nonempty = torch.where(patch_mask.sum(dim=(-2, -1)) >= MINIMUM_POINTS_PER_PATCH)
    num_nonempty_patches = nonempty[0].shape[0]
    if num_nonempty_patches == 0:
        return torch.tensor(0.0, dtype=dtype, device=device), {}

    # Finalize all patch variables
    patch_batch_idx, patch_i, patch_j = patch_batch_idx[nonempty], patch_i[nonempty], patch_j[nonempty]
    patch_mask = patch_mask[nonempty]  # [num_nonempty_patches, patch_h, patch_w]
    gt_patch_points = gt_patch_points[nonempty]  # [num_nonempty_patches, patch_h, patch_w, 3]
    gt_patch_radius_3d = gt_patch_radius_3d[nonempty]  # [num_nonempty_patches]
    gt_patch_anchor_points = gt_patch_anchor_points[nonempty]  # [num_nonempty_patches, 3]
    pred_patch_points = pred_points[patch_batch_idx[:, None, None], patch_i, patch_j]

    # Align patch points
    (pred_patch_points_lr, gt_patch_points_lr), patch_lr_mask = mask_aware_nearest_resize((pred_patch_points, gt_patch_points), mask=patch_mask, size=(align_resolution, align_resolution))
    local_scale, local_shift = align_points_scale_xyz_shift(pred_patch_points_lr.flatten(-3, -2), gt_patch_points_lr.flatten(-3, -2), patch_lr_mask.flatten(-2) / gt_patch_radius_3d[:, None].add(1e-7), trunc=trunc)
    if global_scale is not None:
        # Reject patches whose local scale deviates >10x from the global alignment scale.
        scale_differ = local_scale / global_scale[patch_batch_idx]
        # NOTE(review): `global_scale` here has per-sample shape while `scale_differ` is
        # per-patch — `global_scale[patch_batch_idx] > 0` looks intended; confirm shapes.
        patch_valid = (scale_differ > 0.1) & (scale_differ < 10.0) & (global_scale > 0)
    else:
        patch_valid = local_scale > 0
    local_scale, local_shift = torch.where(patch_valid, local_scale, 0), torch.where(patch_valid[:, None], local_shift, 0)
    patch_mask &= patch_valid[:, None, None]

    pred_patch_points = local_scale[:, None, None, None] * pred_patch_points + local_shift[:, None, None, :]  # [num_patches_nonempty, patch_h, patch_w, 3]

    # Compute loss
    gt_mean = harmonic_mean(gt_points[..., 2], gt_mask, dim=(-2, -1))
    patch_weight = patch_mask.float() / gt_patch_points[..., 2].clamp_min(0.1 * gt_mean[patch_batch_idx, None, None])  # [num_patches_nonempty, patch_h, patch_w]
    loss = _smooth((pred_patch_points - gt_patch_points).abs() * patch_weight[..., None], beta=beta).mean(dim=(-3, -2, -1))  # [num_patches_nonempty]

    if sparsity_aware:
        # Reweighting improves performance on sparse depth data. NOTE: this is not used in MoGe-1.
        sparsity = patch_mask.float().mean(dim=(-2, -1)) / patch_lr_mask.float().mean(dim=(-2, -1))
        loss = loss / (sparsity + 1e-7)
    loss = torch.scatter_reduce(torch.zeros(batch_size, dtype=dtype, device=device), dim=0, index=patch_batch_idx, src=loss, reduce='sum') / num_patches
    loss = loss.reshape(batch_shape)

    err = (pred_patch_points.detach() - gt_patch_points).norm(dim=-1) / gt_patch_radius_3d[..., None, None]

    # Record any scalar metric
    misc = {
        'truncated_error': weighted_mean(err.clamp_max(1), patch_mask).item(),
        'delta': weighted_mean((err < 1).float(), patch_mask).item()
    }

    return loss, misc

def normal_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, dict]:
    """Angular loss between quad-based surface normals of predicted and GT point maps.

    Normals are cross products over the four triangles of each 2x2 pixel quad; each
    term is masked by the three pixels that define it. Angles are clamped to
    [1 deg, 90 deg] and Huber-smoothed with a 3 deg knee. Returns ``(loss, {})``.
    """
    device, dtype = points.device, points.dtype
    height, width = points.shape[-3:-1]

    leftup, rightup, leftdown, rightdown = points[..., :-1, :-1, :], points[..., :-1, 1:, :], points[..., 1:, :-1, :], points[..., 1:, 1:, :]
    upxleft = torch.cross(rightup - rightdown, leftdown - rightdown, dim=-1)
    leftxdown = torch.cross(leftup - rightup, rightdown - rightup, dim=-1)
    downxright = torch.cross(leftdown - leftup, rightup - leftup, dim=-1)
    rightxup = torch.cross(rightdown - leftdown, leftup - leftdown, dim=-1)

    gt_leftup, gt_rightup, gt_leftdown, gt_rightdown = gt_points[..., :-1, :-1, :], gt_points[..., :-1, 1:, :], gt_points[..., 1:, :-1, :], gt_points[..., 1:, 1:, :]
    gt_upxleft = torch.cross(gt_rightup - gt_rightdown, gt_leftdown - gt_rightdown, dim=-1)
    gt_leftxdown = torch.cross(gt_leftup - gt_rightup, gt_rightdown - gt_rightup, dim=-1)
    gt_downxright = torch.cross(gt_leftdown - gt_leftup, gt_rightup - gt_leftup, dim=-1)
    gt_rightxup = torch.cross(gt_rightdown - gt_leftdown, gt_leftup - gt_leftdown, dim=-1)

    mask_leftup, mask_rightup, mask_leftdown, mask_rightdown = mask[..., :-1, :-1], mask[..., :-1, 1:], mask[..., 1:, :-1], mask[..., 1:, 1:]
    mask_upxleft = mask_rightup & mask_leftdown & mask_rightdown
    mask_leftxdown = mask_leftup & mask_rightdown & mask_rightup
    mask_downxright = mask_leftdown & mask_rightup & mask_leftup
    mask_rightxup = mask_rightdown & mask_leftup & mask_leftdown

    MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(1), math.radians(90), math.radians(3)

    loss = mask_upxleft * _smooth(angle_diff_vec3(upxleft, gt_upxleft).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
        + mask_leftxdown * _smooth(angle_diff_vec3(leftxdown, gt_leftxdown).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
        + mask_downxright * _smooth(angle_diff_vec3(downxright, gt_downxright).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD) \
        + mask_rightxup * _smooth(angle_diff_vec3(rightxup, gt_rightxup).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)

    loss = loss.mean() / (4 * max(points.shape[-3:-1]))

    return loss, {}


def edge_loss(points: torch.Tensor, gt_points: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, dict]:
    """Angular loss between finite-difference edge vectors of predicted and GT point maps.

    Penalizes the angle between neighboring-pixel difference vectors (dx along rows,
    dy along columns), masked to pairs where both pixels are valid. Returns ``(loss, {})``.
    """
    device, dtype = points.device, points.dtype
    height, width = points.shape[-3:-1]

    dx = points[..., :-1, :, :] - points[..., 1:, :, :]
    dy = points[..., :, :-1, :] - points[..., :, 1:, :]

    gt_dx = gt_points[..., :-1, :, :] - gt_points[..., 1:, :, :]
    gt_dy = gt_points[..., :, :-1, :] - gt_points[..., :, 1:, :]

    mask_dx = mask[..., :-1, :] & mask[..., 1:, :]
    mask_dy = mask[..., :, :-1] & mask[..., :, 1:]

    MIN_ANGLE, MAX_ANGLE, BETA_RAD = math.radians(0.1), math.radians(90), math.radians(3)

    loss_dx = mask_dx * _smooth(angle_diff_vec3(dx, gt_dx).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
    loss_dy = mask_dy * _smooth(angle_diff_vec3(dy, gt_dy).clamp(MIN_ANGLE, MAX_ANGLE), beta=BETA_RAD)
    loss = (loss_dx.mean(dim=(-2, -1)) + loss_dy.mean(dim=(-2, -1))) / (2 * max(points.shape[-3:-1]))

    return loss, {}


def mask_l2_loss(pred_mask: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> Tuple[torch.Tensor, dict]:
    # Squared-error mask loss: penalize pred_mask toward 0 on negatives and toward 1
    # on positives; unlabeled pixels (neither pos nor neg) contribute nothing.
    loss = gt_mask_neg.float() * pred_mask.square() + gt_mask_pos.float() * (1 - pred_mask).square()
    loss = loss.mean(dim=(-2, -1))
    return loss, {}


def mask_bce_loss(pred_mask_prob: torch.Tensor, gt_mask_pos: torch.Tensor, gt_mask_neg: torch.Tensor) -> Tuple[torch.Tensor, dict]:
    # Binary cross-entropy restricted to pixels with a definite pos/neg label.
    loss = (gt_mask_pos | gt_mask_neg) * F.binary_cross_entropy(pred_mask_prob, gt_mask_pos.float(), reduction='none')
    loss = loss.mean(dim=(-2, -1))
    return loss, {}
diff --git a/Pixel-Perfect-Depth/moge/train/utils.py b/Pixel-Perfect-Depth/moge/train/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f21e00876b927991381bf2f777a68b02c5b38cc
--- /dev/null
+++ b/Pixel-Perfect-Depth/moge/train/utils.py
@@ -0,0 +1,57 @@
from typing import *
import fnmatch

import sympy
import torch
import torch.nn as nn


def any_match(s: str, patterns: List[str]) -> bool:
    # True if `s` matches any of the glob-style `patterns` (fnmatch semantics).
    return any(fnmatch.fnmatch(s, pat) for pat in patterns)


def build_optimizer(model: nn.Module, optimizer_config: Dict[str, Any]) -> torch.optim.Optimizer:
    """Build a torch optimizer from config, selecting parameters per group by
    include/exclude glob patterns matched against parameter names."""
    named_param_groups = [
        {
            k: p for k, p in model.named_parameters() if any_match(k, param_group_config['params']['include']) and not any_match(k, param_group_config['params'].get('exclude', []))
        } for param_group_config in
optimizer_config['params'] + ] + excluded_params = [k for k, p in model.named_parameters() if p.requires_grad and not any(k in named_params for named_params in named_param_groups)] + assert len(excluded_params) == 0, f'The following parameters require grad but are excluded from the optimizer: {excluded_params}' + optimizer_cls = getattr(torch.optim, optimizer_config['type']) + optimizer = optimizer_cls([ + { + **param_group_config, + 'params': list(params.values()), + } for param_group_config, params in zip(optimizer_config['params'], named_param_groups) + ]) + return optimizer + + +def parse_lr_lambda(s: str) -> Callable[[int], float]: + epoch = sympy.symbols('epoch') + lr_lambda = sympy.sympify(s) + return sympy.lambdify(epoch, lr_lambda, 'math') + + +def build_lr_scheduler(optimizer: torch.optim.Optimizer, scheduler_config: Dict[str, Any]) -> torch.optim.lr_scheduler._LRScheduler: + if scheduler_config['type'] == "SequentialLR": + child_schedulers = [ + build_lr_scheduler(optimizer, child_scheduler_config) + for child_scheduler_config in scheduler_config['params']['schedulers'] + ] + return torch.optim.lr_scheduler.SequentialLR(optimizer, schedulers=child_schedulers, milestones=scheduler_config['params']['milestones']) + elif scheduler_config['type'] == "LambdaLR": + lr_lambda = scheduler_config['params']['lr_lambda'] + if isinstance(lr_lambda, str): + lr_lambda = parse_lr_lambda(lr_lambda) + elif isinstance(lr_lambda, list): + lr_lambda = [parse_lr_lambda(l) for l in lr_lambda] + return torch.optim.lr_scheduler.LambdaLR( + optimizer, + lr_lambda=lr_lambda, + ) + else: + scheduler_cls = getattr(torch.optim.lr_scheduler, scheduler_config['type']) + scheduler = scheduler_cls(optimizer, **scheduler_config.get('params', {})) + return scheduler \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/utils/__init__.py b/Pixel-Perfect-Depth/moge/utils/__init__.py new file mode 100644 index 
def scatter_min(size: int, dim: int, index: torch.LongTensor, src: torch.Tensor) -> torch.return_types.min:
    "Scatter the minimum value along the given dimension of `input` into `src` at the indices specified in `index`."
    shape = src.shape[:dim] + (size,) + src.shape[dim + 1:]
    # amin-reduce src into a +inf-initialized buffer, then recover, for each
    # output slot, the position in `src` that attained that minimum.
    minimum = torch.full(shape, float('inf'), dtype=src.dtype, device=src.device).scatter_reduce(dim=dim, index=index, src=src, reduce='amin', include_self=False)
    minimum_where = torch.where(src == torch.gather(minimum, dim=dim, index=index))
    indices = torch.full(shape, -1, dtype=torch.long, device=src.device)
    # Slots that received no source element keep index -1.
    indices[(*minimum_where[:dim], index[minimum_where], *minimum_where[dim + 1:])] = minimum_where[dim]
    return torch.return_types.min((minimum, indices))


def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
    """
    Run `fn` over the batch (dim 0) in chunks of `chunk_size` and concatenate
    the results, to bound peak memory. Non-tensor args/kwargs are repeated for
    every chunk unchanged.
    """
    batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
    n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
    splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
    # FIX: the original wrapped the split kwargs in an extra one-element list
    # ({k: [v.split(...)]}), so chunk i indexed the wrapper list instead of the
    # chunk list and raised IndexError for any i >= 1.
    splited_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
    results = []
    for i in range(n_chunks):
        chunk_args = tuple(arg[i] for arg in splited_args)
        chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
        results.append(fn(*chunk_args, **chunk_kwargs))

    if isinstance(results[0], tuple):
        return tuple(torch.cat(r, dim=0) for r in zip(*results))
    else:
        return torch.cat(results, dim=0)


def _pad_inf(x_: torch.Tensor):
    # Pad [-inf, x1, ..., xn, inf] along the last dim (sentinels for searchsorted).
    return torch.cat([torch.full_like(x_[..., :1], -torch.inf), x_, torch.full_like(x_[..., :1], torch.inf)], dim=-1)


def _pad_cumsum(cumsum: torch.Tensor):
    # Pad [0, c1, ..., cn, cn] along the last dim, matching _pad_inf's sentinels.
    return torch.cat([torch.zeros_like(cumsum[..., :1]), cumsum, cumsum[..., -1:]], dim=-1)


def _compute_residual(a: torch.Tensor, xyw: torch.Tensor, trunc: float):
    # sum_i min(trunc, w_i * |a * x_i - y_i|), with xyw stacked as (..., n, 3).
    return a.mul(xyw[..., 0]).sub_(xyw[..., 1]).abs_().mul_(xyw[..., 2]).clamp_max_(trunc).sum(dim=-1)


def align(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor, trunc: Optional[Union[float, torch.Tensor]] = None, eps: float = 1e-7) -> Tuple[torch.Tensor, torch.Tensor, torch.LongTensor]:
    """
    If trunc is None, solve `min sum_i w_i * |a * x_i - y_i|`, otherwise solve `min sum_i min(trunc, w_i * |a * x_i - y_i|)`.

    w_i must be >= 0.

    ### Parameters:
    - `x`: tensor of shape (..., n)
    - `y`: tensor of shape (..., n)
    - `w`: tensor of shape (..., n)
    - `trunc`: optional, float or tensor of shape (..., n) or None

    ### Returns:
    - `a`: tensor of shape (...), differentiable
    - `loss`: tensor of shape (...), value of loss function at `a`, detached
    - `index`: tensor of shape (...), where a = y[idx] / x[idx]
    """
    if trunc is None:
        # Untruncated weighted-L1 case: the objective is piecewise linear in a
        # with breakpoints at y_i/x_i; the optimum is the weighted median, found
        # where the subgradient (cumulative wx) crosses zero.
        x, y, w = torch.broadcast_tensors(x, y, w)
        sign = torch.sign(x)
        x, y = x * sign, y * sign  # make x non-negative so slopes are ordered
        y_div_x = y / x.clamp_min(eps)
        y_div_x, argsort = y_div_x.sort(dim=-1)

        wx = torch.gather(x * w, dim=-1, index=argsort)
        derivatives = 2 * wx.cumsum(dim=-1) - wx.sum(dim=-1, keepdim=True)  # non-decreasing
        search = torch.searchsorted(derivatives, torch.zeros_like(derivatives[..., :1]), side='left').clamp_max(derivatives.shape[-1] - 1)

        a = y_div_x.gather(dim=-1, index=search).squeeze(-1)
        index = argsort.gather(dim=-1, index=search).squeeze(-1)
        loss = (w * (a[..., None] * x - y).abs()).sum(dim=-1)

    else:
        # Truncated case: the objective is piecewise linear with extra
        # breakpoints where each residual saturates at trunc (B and C below).
        # Reshape to (batch_size, n) for simplicity
        x, y, w = torch.broadcast_tensors(x, y, w)
        batch_shape = x.shape[:-1]
        batch_size = math.prod(batch_shape)
        x, y, w = x.reshape(-1, x.shape[-1]), y.reshape(-1, y.shape[-1]), w.reshape(-1, w.shape[-1])

        sign = torch.sign(x)
        x, y = x * sign, y * sign
        wx, wy = w * x, w * y
        xyw = torch.stack([x, y, w], dim=-1)  # Stacked for convenient gathering

        y_div_x = A = y / x.clamp_min(eps)          # residual zero-crossings
        B = (wy - trunc) / wx.clamp_min(eps)        # lower saturation breakpoints
        C = (wy + trunc) / wx.clamp_min(eps)        # upper saturation breakpoints
        with torch.no_grad():
            # Caculate prefix sum by orders of A, B, C
            A, A_argsort = A.sort(dim=-1)
            Q_A = torch.cumsum(torch.gather(wx, dim=-1, index=A_argsort), dim=-1)
            A, Q_A = _pad_inf(A), _pad_cumsum(Q_A)  # Pad [-inf, A1, ..., An, inf] and [0, Q1, ..., Qn, Qn] to handle edge cases.

            B, B_argsort = B.sort(dim=-1)
            Q_B = torch.cumsum(torch.gather(wx, dim=-1, index=B_argsort), dim=-1)
            B, Q_B = _pad_inf(B), _pad_cumsum(Q_B)

            C, C_argsort = C.sort(dim=-1)
            Q_C = torch.cumsum(torch.gather(wx, dim=-1, index=C_argsort), dim=-1)
            C, Q_C = _pad_inf(C), _pad_cumsum(Q_C)

            # Caculate left and right derivative of the objective at each candidate A
            j_A = torch.searchsorted(A, y_div_x, side='left').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='left').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='left').sub_(1)
            left_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)
            j_A = torch.searchsorted(A, y_div_x, side='right').sub_(1)
            j_B = torch.searchsorted(B, y_div_x, side='right').sub_(1)
            j_C = torch.searchsorted(C, y_div_x, side='right').sub_(1)
            right_derivative = 2 * torch.gather(Q_A, dim=-1, index=j_A) - torch.gather(Q_B, dim=-1, index=j_B) - torch.gather(Q_C, dim=-1, index=j_C)

            # Find extrema (sign change of the derivative)
            is_extrema = (left_derivative < 0) & (right_derivative >= 0)
            is_extrema[..., 0] |= ~is_extrema.any(dim=-1)  # In case all derivatives are zero, take the first one as extrema.
            where_extrema_batch, where_extrema_index = torch.where(is_extrema)

            # Calculate objective value at extrema
            extrema_a = y_div_x[where_extrema_batch, where_extrema_index]  # (num_extrema,)
            MAX_ELEMENTS = 4096 ** 2  # Split into small batches to avoid OOM in case there are too many extrema.(~1G)
            SPLIT_SIZE = MAX_ELEMENTS // x.shape[-1]
            extrema_value = torch.cat([
                _compute_residual(extrema_a_split[:, None], xyw[extrema_i_split, :, :], trunc)
                for extrema_a_split, extrema_i_split in zip(extrema_a.split(SPLIT_SIZE), where_extrema_batch.split(SPLIT_SIZE))
            ])  # (num_extrema,)

            # Find minima among corresponding extrema
            minima, indices = scatter_min(size=batch_size, dim=0, index=where_extrema_batch, src=extrema_value)  # (batch_size,)
            index = where_extrema_index[indices]

        # Reproduce a = y[idx] / x[idx] outside no_grad for a short compute graph.
        a = torch.gather(y, dim=-1, index=index[..., None]) / torch.gather(x, dim=-1, index=index[..., None]).clamp_min(eps)
        a = a.reshape(batch_shape)
        loss = minima.reshape(batch_shape)
        index = index.reshape(batch_shape)

    return a, loss, index


def align_depth_scale(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with given constant weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...)
    """
    scale, _, _ = align(depth_src, depth_tgt, weight, trunc)

    return scale
def align_depth_affine(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `depth_src` to `depth_tgt` with given constant weights.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N)
    - `trunc: float` or tensor of shape (..., N) or None

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (...).
    """
    dtype, device = depth_src.dtype, depth_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = depth_src.shape[:-1], depth_src.shape[-1]
    batch_size = math.prod(batch_shape)
    depth_src, depth_tgt, weight = depth_src.reshape(batch_size, n), depth_tgt.reshape(batch_size, n), weight.reshape(batch_size, n)

    # Here, we take anchors only for non-zero weights.
    # Although the results will be still correct even anchor points have zero weight,
    # it is wasting computation and may cause instability in some cases, e.g. too many extrema.
    anchors_where_batch, anchors_where_n = torch.where(weight > 0)

    # Stop gradient when solving optimal anchors
    with torch.no_grad():
        # For each anchor point, subtract it out so the affine problem reduces
        # to a pure-scale problem solvable by `align`.
        depth_src_anchor = depth_src[anchors_where_batch, anchors_where_n]  # (anchors)
        depth_tgt_anchor = depth_tgt[anchors_where_batch, anchors_where_n]  # (anchors)

        depth_src_anchored = depth_src[anchors_where_batch, :] - depth_src_anchor[..., None]  # (anchors, n)
        depth_tgt_anchored = depth_tgt[anchors_where_batch, :] - depth_tgt_anchor[..., None]  # (anchors, n)
        weight_anchored = weight[anchors_where_batch, :]  # (anchors, n)

        scale, loss, index = align(depth_src_anchored, depth_tgt_anchored, weight_anchored, trunc)  # (anchors)

        # Keep, per batch element, the anchor with the lowest residual.
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchors_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for shorter compute graph
    index_1 = anchors_where_n[index_anchor]  # (batch_size,) winning anchor point
    index_2 = index[index_anchor]            # (batch_size,) point defining the optimal scale

    tgt_1, src_1 = torch.gather(depth_tgt, dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(depth_tgt, dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(depth_src, dim=1, index=index_2[..., None]).squeeze(-1)

    # The optimal affine map passes exactly through the two selected points.
    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1e-7)
    shift = tgt_1 - scale * src_1

    scale, shift = scale.reshape(batch_shape), shift.reshape(batch_shape)

    return scale, shift


def align_depth_affine_irls(depth_src: torch.Tensor, depth_tgt: torch.Tensor, weight: Optional[torch.Tensor], max_iter: int = 100, eps: float = 1e-12):
    """
    Align `depth_src` to `depth_tgt` with given constant weights using IRLS.

    Iteratively reweighted least squares: repeatedly solve the weighted normal
    equations for (a, b), then set weights to 1/|residual| so the quadratic
    objective approaches the L1 objective.

    ### Parameters:
    - `depth_src: torch.Tensor` of shape (..., N)
    - `depth_tgt: torch.Tensor` of shape (..., N)
    - `weight: torch.Tensor` of shape (..., N), initial weights
    - `max_iter: int` number of IRLS iterations (no early-exit convergence test)
    - `eps: float` clamps residuals to avoid division by zero in the reweighting

    ### Returns:
    - `(a, b)` — scale and shift.
    # NOTE(review): the matmul shapes here look written for unbatched (N,) inputs;
    # confirm batched inputs are supported before relying on leading dims.
    """
    dtype, device = depth_src.dtype, depth_src.device

    w = weight
    x = torch.stack([depth_src, torch.ones_like(depth_src)], dim=-1)
    y = depth_tgt

    for i in range(max_iter):
        # Weighted normal equations: beta = (Xᵀ W y) (Xᵀ W X)⁻¹
        beta = (x.transpose(-1, -2) @ (w * y)) @ (x.transpose(-1, -2) @ (w[..., None] * x)).inverse().transpose(-2, -1)
        # Reweight by inverse absolute residual (L1 surrogate).
        w = 1 / (y - (x @ beta[..., None])[..., 0]).abs().clamp_min(eps)

    return beta[..., 0], beta[..., 1]


def align_points_scale(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weight: torch.Tensor` of shape (..., N)

    ### Returns:
    - `a: torch.Tensor` of shape (...). Only positive solutions are garunteed. You should filter out negative scales before using it.
    - `b: torch.Tensor` of shape (...)
    """
    dtype, device = points_src.dtype, points_src.device

    # Per-point weight is broadcast to all 3 coordinates, then the (N, 3)
    # problem is flattened to a 3N-element scale problem for `align`.
    scale, _, _ = align(points_src.flatten(-2), points_tgt.flatten(-2), weight[..., None].expand_as(points_src).flatten(-2), trunc)

    return scale


def align_points_scale_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and z shift.
    It is similar to `align_affine` but scale and shift are applied to different dimensions.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weights: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3). x and y shifts are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)
    with torch.no_grad():
        # Anchor only the z coordinate (x/y shift stays zero by construction).
        zeros = torch.zeros(anchor_where_batch.shape[0], device=device, dtype=dtype)
        points_src_anchor = torch.stack([zeros, zeros, points_src[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)
        points_tgt_anchor = torch.stack([zeros, zeros, points_tgt[anchor_where_batch, anchor_where_n, 2]], dim=-1)  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor
        MAX_ELEMENTS = 2 ** 20  # chunk the anchor batch to bound memory
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    # Reproduce by indexing for shorter compute graph
    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    # z-only views of the point clouds: the shift acts on z alone.
    zeros = torch.zeros((batch_size, n), device=device, dtype=dtype)
    points_tgt_00z, points_src_00z = torch.stack([zeros, zeros, points_tgt[..., 2]], dim=-1), torch.stack([zeros, zeros, points_src[..., 2]], dim=-1)
    tgt_1, src_1 = torch.gather(points_tgt_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_src_00z.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    tgt_2, src_2 = torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src_00z, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)
    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift
def align_points_scale_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a shared xyz scale and xyz shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weights: torch.Tensor` of shape (..., N)

    ### Returns:
    - `scale: torch.Tensor` of shape (...).
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device

    # Flatten batch dimensions for simplicity
    batch_shape, n = points_src.shape[:-2], points_src.shape[-2]
    batch_size = math.prod(batch_shape)
    points_src, points_tgt, weight = points_src.reshape(batch_size, n, 3), points_tgt.reshape(batch_size, n, 3), weight.reshape(batch_size, n)

    # Take anchors (only where the weight is non-zero)
    anchor_where_batch, anchor_where_n = torch.where(weight > 0)

    with torch.no_grad():
        # Subtract each anchor point so the scale-and-shift problem reduces to
        # a pure-scale problem solvable by `align`.
        points_src_anchor = points_src[anchor_where_batch, anchor_where_n]  # (anchors, 3)
        points_tgt_anchor = points_tgt[anchor_where_batch, anchor_where_n]  # (anchors, 3)

        points_src_anchored = points_src[anchor_where_batch, :, :] - points_src_anchor[..., None, :]  # (anchors, n, 3)
        points_tgt_anchored = points_tgt[anchor_where_batch, :, :] - points_tgt_anchor[..., None, :]  # (anchors, n, 3)
        weight_anchored = weight[anchor_where_batch, :, None].expand(-1, -1, 3)  # (anchors, n, 3)

        # Solve optimal scale and shift for each anchor
        MAX_ELEMENTS = 2 ** 20
        # FIX: chunk by MAX_ELEMENTS // n (each anchor row holds 3n elements),
        # consistent with align_points_scale_z_shift; the previous `// 2`
        # produced chunks far larger than MAX_ELEMENTS. Results are identical —
        # chunking only bounds peak memory.
        scale, loss, index = split_batch_fwd(align, MAX_ELEMENTS // n, points_src_anchored.flatten(-2), points_tgt_anchored.flatten(-2), weight_anchored.flatten(-2), trunc)  # (anchors,)

        # Get optimal scale and shift for each batch element
        loss, index_anchor = scatter_min(size=batch_size, dim=0, index=anchor_where_batch, src=loss)  # (batch_size,)

    index_2 = index[index_anchor]  # (batch_size,) [0, 3n)
    index_1 = anchor_where_n[index_anchor] * 3 + index_2 % 3  # (batch_size,) [0, 3n)

    src_1, tgt_1 = torch.gather(points_src.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_1[..., None]).squeeze(-1)
    src_2, tgt_2 = torch.gather(points_src.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1), torch.gather(points_tgt.flatten(-2), dim=1, index=index_2[..., None]).squeeze(-1)

    # The optimal map passes exactly through the two selected coordinates.
    scale = (tgt_2 - tgt_1) / torch.where(src_2 != src_1, src_2 - src_1, 1.0)
    shift = torch.gather(points_tgt, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2) - scale[..., None] * torch.gather(points_src, dim=1, index=(index_1 // 3)[..., None, None].expand(-1, -1, 3)).squeeze(-2)

    scale, shift = scale.reshape(batch_shape), shift.reshape(*batch_shape, 3)

    return scale, shift


def align_points_z_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a Z-axis shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weights: torch.Tensor` of shape (..., N)

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3). x and y components are zeros.
    """
    dtype, device = points_src.dtype, points_src.device

    # Solving `min sum w |1 * shift - (tgt_z - src_z)|` with `align` yields the
    # weighted-median z offset.
    shift, _, _ = align(torch.ones_like(points_src[..., 2]), points_tgt[..., 2] - points_src[..., 2], weight, trunc)
    shift = torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)

    return shift


def align_points_xyz_shift(points_src: torch.Tensor, points_tgt: torch.Tensor, weight: Optional[torch.Tensor], trunc: Optional[Union[float, torch.Tensor]] = None, max_iters: int = 30, eps: float = 1e-6):
    """
    Align `points_src` to `points_tgt` with respect to a per-axis (x, y, z) shift.

    ### Parameters:
    - `points_src: torch.Tensor` of shape (..., N, 3)
    - `points_tgt: torch.Tensor` of shape (..., N, 3)
    - `weights: torch.Tensor` of shape (..., N)

    ### Returns:
    - `shift: torch.Tensor` of shape (..., 3)
    """
    dtype, device = points_src.dtype, points_src.device

    # Each axis is solved independently: swap (N, 3) -> (3, N) so `align`
    # treats the three coordinates as three batch rows.
    shift, _, _ = align(torch.ones_like(points_src).swapaxes(-2, -1), (points_tgt - points_src).swapaxes(-2, -1), weight[..., None, :], trunc)

    return shift


def align_affine_lstsq(x: torch.Tensor, y: torch.Tensor, w: torch.Tensor = None) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Solve `min sum_i w_i * (a * x_i + b - y_i ) ^ 2`, where `a` and `b` are scalars, with respect to `a` and `b` using least squares.

    ### Parameters:
    - `x: torch.Tensor` of shape (..., N)
    - `y: torch.Tensor` of shape (..., N)
    - `w: torch.Tensor` of shape (..., N)

    ### Returns:
    - `a: torch.Tensor` of shape (...,)
    - `b: torch.Tensor` of shape (...,)
    """
    w_sqrt = torch.ones_like(x) if w is None else w.sqrt()
    # FIX: both design-matrix columns must be scaled by sqrt(w); the original
    # left the intercept column as plain ones, so the solved `b` did not
    # minimize the stated weighted objective.
    A = torch.stack([w_sqrt * x, w_sqrt], dim=-1)
    B = (w_sqrt * y)[..., None]
    a, b = torch.linalg.lstsq(A, B)[0].squeeze(-1).unbind(-1)
    return a, b
def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
    """
    Download `url` to `filepath`, optionally resuming a partial download via an
    HTTP Range request, with a tqdm progress bar.

    ### Parameters:
    - `url`: source URL
    - `filepath`: destination path; existing bytes are kept when `resume` is True
    - `headers`: optional extra request headers (mutated to add `Range` on resume)
    - `resume`: append to an existing partial file instead of restarting

    ### Raises:
    - `requests.HTTPError` on a 4xx/5xx response
    """
    # Ensure headers is a dict if not provided
    headers = headers or {}

    # Initialize local variables
    file_path = Path(filepath)
    downloaded_bytes = 0

    # Check if we should resume the download
    if resume and file_path.exists():
        downloaded_bytes = file_path.stat().st_size
        headers['Range'] = f"bytes={downloaded_bytes}-"

    # Make a GET request to fetch the file
    with requests.get(url, stream=True, headers=headers) as response:
        response.raise_for_status()  # This will raise an HTTPError if the status is 4xx/5xx

        # Calculate the total size to download
        total_size = downloaded_bytes + int(response.headers.get('content-length', 0))

        # Display a progress bar while downloading
        with (
            tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
            open(file_path, 'ab') as file,
        ):
            # Set the initial position of the progress bar
            pbar.update(downloaded_bytes)

            # Write the content to the file in chunks
            for chunk in response.iter_content(chunk_size=4096):
                file.write(chunk)
                pbar.update(len(chunk))


def download_bytes(url: str, headers: dict = None) -> bytes:
    """
    Fetch `url` and return the whole response body as bytes.

    ### Raises:
    - `requests.HTTPError` on a 4xx/5xx response
    """
    # Ensure headers is a dict if not provided
    headers = headers or {}

    # Make a GET request to fetch the file
    with requests.get(url, stream=True, headers=headers) as response:
        response.raise_for_status()  # This will raise an HTTPError if the status is 4xx/5xx

        # Read the content of the response
        return response.content


def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
    """
    Weighted mean of `x` over `axis`: sum(x * w) / sum(w), with the denominator
    clipped by `eps`. Falls back to the plain mean when `w` is None.
    """
    if w is None:
        # FIX: `keepdims` was accepted but silently ignored.
        return np.mean(x, axis=axis, keepdims=keepdims)
    else:
        w = w.astype(x.dtype)
        return (x * w).mean(axis=axis, keepdims=keepdims) / np.clip(w.mean(axis=axis, keepdims=keepdims), eps, None)


def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
    """
    (Optionally weighted) harmonic mean of `x` over `axis`, with `eps`
    regularization against division by zero.
    """
    if w is None:
        # FIX: `keepdims` was accepted but silently ignored in this branch.
        return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis, keepdims=keepdims)
    else:
        w = w.astype(x.dtype)
        return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)


def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5

    # Sample at pixel centers, hence the (size - 1) / size endpoints.
    u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
    v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
    u, v = np.meshgrid(u, v, indexing='xy')
    uv = np.stack([u, v], axis=-1)
    return uv


def focal_to_fov_numpy(focal: np.ndarray):
    """Convert normalized focal length to field of view (radians)."""
    return 2 * np.arctan(0.5 / focal)


def fov_to_focal_numpy(fov: np.ndarray):
    """Convert field of view (radians) to normalized focal length."""
    return 0.5 / np.tan(fov / 2)


def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Extract (fov_x, fov_y) in radians from normalized intrinsics matrices (..., 3, 3)."""
    fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
    fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
    return fov_x, fov_y


def point_map_to_depth_legacy_numpy(points: np.ndarray):
    """
    Legacy closed-form recovery of (depth, fov, z-shift) from a point map of
    shape (..., H, W, 3), assuming a shared focal for x and y. Solves the
    linear least-squares system `focal * xy = uv * (z + shift)`.

    ### Returns:
    - `depth`: (..., H, W) shifted z channel
    - `fov_x`, `fov_y`: (...) fields of view in radians
    - `shift`: (...) recovered z shift
    """
    height, width = points.shape[-3:-1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype)  # (H, W, 2)
    _, uv = np.broadcast_arrays(points[..., :2], uv)

    # Solve least squares problem
    b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1)  # (..., H * W * 2)
    A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2)  # (..., H * W * 2, 2)

    # Regularized normal equations; 1e-6 * I guards against singular M.
    M = A.swapaxes(-2, -1) @ A
    solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
    # FIX: `focal, shift = solution` unpacked along axis 0, which is wrong for
    # batched input; index the trailing solution axis instead.
    focal, shift = solution[..., 0], solution[..., 1]

    depth = points[..., 2] + shift[..., None, None]
    fov_x = np.arctan(width / diagonal / focal) * 2
    fov_y = np.arctan(height / diagonal / focal) * 2
    return depth, fov_x, fov_y, shift


def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
    from scipy.optimize import least_squares
    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)

    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
        xy_proj = xy / (z + shift)[:, None]
        # For a fixed shift, the optimal focal has a closed form (projection).
        f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
        err = (f * xy_proj - uv).ravel()
        return err

    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
    optim_shift = solution['x'].squeeze().astype(np.float32)

    xy_proj = xy / (z + optim_shift)[:, None]
    optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()

    # NOTE: returns (shift, focal) in this order — callers must unpack accordingly.
    return optim_shift, optim_focal


def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
    from scipy.optimize import least_squares
    uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)

    def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
        xy_proj = xy / (z + shift)[:, None]
        err = (focal * xy_proj - uv).ravel()
        return err

    solution = least_squares(partial(fn, uv, xy, z), x0=0, ftol=1e-3, method='lm')
    optim_shift = solution['x'].squeeze().astype(np.float32)

    return optim_shift
def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
    """
    Recover the (focal, z-shift) that best projects a point map onto the
    normalized view plane, optimizing on a downsampled set of points.

    ### Parameters:
    - `points`: point map of shape (H, W, 3)
    - `mask`: optional validity mask of shape (H, W); when given, downsampling is mask-aware
    - `focal`: if provided, only the shift is optimized
    - `downsample_size`: (width, height) of the working resolution

    ### Returns:
    - `(focal, shift)` floats; `(1., 0.)` if there are not enough valid points
    """
    import cv2
    assert points.shape[-1] == 3, "Points should (H, W, 3)"

    height, width = points.shape[-3], points.shape[-2]
    diagonal = (height ** 2 + width ** 2) ** 0.5

    uv = normalized_view_plane_uv_numpy(width=width, height=height)

    if mask is None:
        points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
        uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
    else:
        (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)

    # Degenerate input: too few points to fit anything.
    if points_lr.size < 2:
        return 1., 0.

    if focal is None:
        # FIX: solve_optimal_focal_shift returns (shift, focal) — the original
        # unpacked the pair as (focal, shift), swapping the two outputs.
        shift, focal = solve_optimal_focal_shift(uv_lr, points_lr)
    else:
        shift = solve_optimal_shift(uv_lr, points_lr, focal)

    return focal, shift


def mask_aware_nearest_resize_numpy(
    inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
    mask: np.ndarray,
    size: Tuple[int, int],
    return_index: bool = False
) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
    """
    Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.

    ### Parameters
    - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
    - `mask`: input 2D mask of shape (..., H, W)
    - `size`: target size (width, height)

    ### Returns
    - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
    - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
    - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
    """
    height, width = mask.shape[-2:]
    target_width, target_height = size
    # Footprint of one target pixel in source pixels (>= 1 in each dimension).
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv
    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))

    # Gather the target pixels's local window
    target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
    target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)

    target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)

    # Compute nearest neighbor in the local window for each pixel
    dist = np.square(target_window_centers - target_centers[..., None])
    dist = dist[..., 0, :] + dist[..., 1, :]
    dist = np.where(target_window_mask, dist, np.inf)  # masked-out pixels can never be nearest
    nearest_in_window = np.argmin(dist, axis=-1, keepdims=True)  # (..., target_height, tgt_width, 1)
    nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1)  # (..., target_height, tgt_width)
    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
    target_mask = np.any(target_window_mask, axis=-1)
    batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]

    index = (*batch_indices, nearest_i, nearest_j)

    if inputs is None:
        outputs = None
    elif isinstance(inputs, np.ndarray):
        outputs = inputs[index]
    elif isinstance(inputs, Sequence):
        outputs = tuple(x[index] for x in inputs)
    else:
        raise ValueError(f'Invalid input type: {type(inputs)}')

    if return_index:
        return outputs, target_mask, index
    else:
        return outputs, target_mask
+ + ### Parameters + - `image`: Input 2D image of shape (..., H, W, C) + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + + if image.shape[-2:] == (height, width): + omit_channel_dim = True + else: + omit_channel_dim = False + if omit_channel_dim: + image = image[..., None] + + image = np.where(mask[..., None], image, 0) + + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1 + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1 + + # Window the original mask and uv (non-copy) + uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32) + indices = np.arange(height * width, dtype=np.int32).reshape(height, width) + padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1)) + windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, 
(filter_h_i, filter_w_i), 1, axis=(0, 1)) + + # Gather the target pixels's local window + target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32) + target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32) + + target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + + # Compute pixel area in the local windows + target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None]) + target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None]) + target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None) + target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0) + + # Weighted sum by area + target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1) + target_mask = np.sum(target_window_area, axis=-1) >= 0.25 + target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1) + + if omit_channel_dim: + target_image = target_image[..., 0] + + return target_image, target_mask + + +def 
norm3d(x: np.ndarray) -> np.ndarray: + "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors" + return np.sqrt(np.square(x[..., 0]) + np.square(x[..., 1]) + np.square(x[..., 2])) + + +def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1): + disp = np.where(mask, 1 / depth, 0) + disp_pad = np.pad(disp, (thickness, thickness), constant_values=0) + mask_pad = np.pad(mask, (thickness, thickness), constant_values=False) + kernel_size = 2 * thickness + 1 + disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2] + mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2] + + disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1)) + fg_edge_mask = mask & (disp > (1 + tol) * disp_mean) + bg_edge_mask = mask & (disp_mean > (1 + tol) * disp) + + edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \ + & (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) + + return edge_mask + + +def disk_kernel(radius: int) -> np.ndarray: + """ + Generate disk kernel with given radius. + + Args: + radius (int): Radius of the disk (in pixels). + + Returns: + np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel. + """ + # Create coordinate grid centered at (0,0) + L = np.arange(-radius, radius + 1) + X, Y = np.meshgrid(L, L) + # Generate disk: region inside circle with radius R is 1 + kernel = ((X**2 + Y**2) <= radius**2).astype(np.float32) + # Normalize the kernel + kernel /= np.sum(kernel) + return kernel + + +def disk_blur(image: np.ndarray, radius: int) -> np.ndarray: + """ + Apply disk blur to an image using FFT convolution. + + Args: + image (np.ndarray): Input image, can be grayscale or color. + radius (int): Blur radius (in pixels). 
+ + Returns: + np.ndarray: Blurred image. + """ + if radius == 0: + return image + kernel = disk_kernel(radius) + if image.ndim == 2: + blurred = fftconvolve(image, kernel, mode='same') + elif image.ndim == 3: + channels = [] + for i in range(image.shape[2]): + blurred_channel = fftconvolve(image[..., i], kernel, mode='same') + channels.append(blurred_channel) + blurred = np.stack(channels, axis=-1) + else: + raise ValueError("Image must be 2D or 3D.") + return blurred + + +def depth_of_field( + img: np.ndarray, + disp: np.ndarray, + focus_disp : float, + max_blur_radius : int = 10, +) -> np.ndarray: + """ + Apply depth of field effect to an image. + + Args: + img (numpy.ndarray): (H, W, 3) input image. + depth (numpy.ndarray): (H, W) depth map of the scene. + focus_depth (float): Focus depth of the lens. + strength (float): Strength of the depth of field effect. + max_blur_radius (int): Maximum blur radius (in pixels). + + Returns: + numpy.ndarray: (H, W, 3) output image with depth of field effect applied. 
+ """ + # Precalculate dialated depth map for each blur radius + max_disp = np.max(disp) + disp = disp / max_disp + focus_disp = focus_disp / max_disp + dilated_disp = [] + for radius in range(max_blur_radius + 1): + dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1)) + + # Determine the blur radius for each pixel based on the depth map + blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32) + for radius in range(max_blur_radius + 1): + dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32) + mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp) + blur_radii[mask] = dialted_blur_radii[mask] + blur_radii = np.clip(blur_radii, 0, max_blur_radius) + blur_radii = cv2.blur(blur_radii, (5, 5)) + + # Precalculate the blured image for each blur radius + unique_radii = np.unique(blur_radii) + precomputed = {} + for radius in range(max_blur_radius + 1): + if radius not in unique_radii: + continue + precomputed[radius] = disk_blur(img, radius) + + # Composit the blured image for each pixel + output = np.zeros_like(img) + for r in unique_radii: + mask = blur_radii == r + output[mask] = precomputed[r][mask] + + return output diff --git a/Pixel-Perfect-Depth/moge/utils/geometry_torch.py b/Pixel-Perfect-Depth/moge/utils/geometry_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..ab5dbe965a42d0e0b3cbe53eb213bdcb829f8243 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/utils/geometry_torch.py @@ -0,0 +1,354 @@ +from typing import * +import math +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types +import utils3d + +from .tools import timeit +from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift + + +def 
weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.mean(dim=dim, keepdim=keepdim) + else: + w = w.to(x.dtype) + return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps) + + +def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal() + + +def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).log().mean(dim=dim).exp() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp() + + +def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor: + kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, 
kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2)) + kernel = kernel / kernel.sum() + kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size) + input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate') + input = F.conv2d(input, kernel, groups=input.shape[1]) + return input + + +def focal_to_fov(focal: torch.Tensor): + return 2 * torch.atan(0.5 / focal) + + +def fov_to_focal(fov: torch.Tensor): + return 0.5 / torch.tan(fov / 2) + + +def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12): + return torch.atan2(torch.cross(v1, v2, dim=-1).norm(dim=-1) + eps, (v1 * v2).sum(dim=-1)) + +def intrinsics_to_fov(intrinsics: torch.Tensor): + """ + Returns field of view in radians from normalized intrinsics matrix. + ### Parameters: + - intrinsics: torch.Tensor of shape (..., 3, 3) + + ### Returns: + - fov_x: torch.Tensor of shape (...) + - fov_y: torch.Tensor of shape (...) 
+ """ + focal_x = intrinsics[..., 0, 0] + focal_y = intrinsics[..., 1, 1] + return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y) + + +def point_map_to_depth_legacy(points: torch.Tensor): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + # Solve least squares problem + b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2) + A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2) + + M = A.transpose(-2, -1) @ A + solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution.unbind(-1) + + depth = points[..., 2] + shift[..., None, None] + fov_x = torch.atan(width / diagonal / focal) * 2 + fov_y = torch.atan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def view_plane_uv_to_focal(uv: torch.Tensor): + normed_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype) + focal = (uv * normed_uv).sum() / uv.square().sum().add(1e-12) + return focal + + +def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)): + """ + Recover the depth map and FoV from a point map with unknown z shift and focal. + + Note that it assumes: + - the optical center is at the center of the map + - the map is undistorted + - the map is isometric in the x and y directions + + ### Parameters: + - `points: torch.Tensor` of shape (..., H, W, 3) + - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps. + + ### Returns: + - `focal`: torch.Tensor of shape (...) 
the estimated focal length, relative to the half diagonal of the map + - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space + """ + shape = points.shape + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + points = points.reshape(-1, *shape[-3:]) + mask = None if mask is None else mask.reshape(-1, *shape[-3:-1]) + focal = focal.reshape(-1) if focal is not None else None + uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1) + uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0) + mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0 + + uv_lr_np = uv_lr.cpu().numpy() + points_lr_np = points_lr.detach().cpu().numpy() + focal_np = focal.cpu().numpy() if focal is not None else None + mask_lr_np = None if mask is None else mask_lr.cpu().numpy() + optim_shift, optim_focal = [], [] + for i in range(points.shape[0]): + points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]] + uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]] + if uv_lr_i_np.shape[0] < 2: + optim_focal.append(1) + optim_shift.append(0) + continue + if focal is None: + optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np) + optim_focal.append(float(optim_focal_i)) + else: + optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i]) + optim_shift.append(float(optim_shift_i)) + optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + + if focal is None: + optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3]) + else: + optim_focal = 
focal.reshape(shape[:-3]) + + return optim_focal, optim_shift + + +def mask_aware_nearest_resize( + inputs: Union[torch.Tensor, Sequence[torch.Tensor], None], + mask: torch.BoolTensor, + size: Tuple[int, int], + return_index: bool = False +) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...). + - `mask`: input 2D mask of shape (..., H, W) + - `size`: target size (target_width, target_height) + + ### Returns + - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...). + - `resized_mask`: mask of the resized map of shape (..., target_height, target_width) + - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, . 
+ """ + height, width = mask.shape[-2:] + target_width, target_height = size + device = mask.device + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1 + + # Window the original mask and uv + uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device) + indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width) + padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1)) + windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device) + target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device) + target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, 
device=device) + + target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + target_window_indices = target_window_indices.expand_as(target_window_mask) + + # Compute nearest neighbor in the local window for each pixel + dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size) + nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1) + nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width) + target_mask = torch.any(target_window_mask, dim=-1) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + index = (*batch_indices, nearest_i, nearest_j) + + if inputs is None: + outputs = None + elif isinstance(inputs, torch.Tensor): + outputs = inputs[index] + elif isinstance(inputs, Sequence): + outputs = tuple(x[index] for x in inputs) + else: + raise ValueError(f'Invalid input type: {type(inputs)}') + + if return_index: + return outputs, target_mask, index + else: + return outputs, target_mask + + +def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3): + *batch_shape, height, width = depth.shape + depth = depth.reshape(-1, 1, 
height, width) + mask = mask.reshape(-1, 1, height, width) + if pooler =='max': + pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + output_mask = pooled_depth > depth * (1 + rtol) + elif pooler =='min': + pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + output_mask = pooled_depth < depth * (1 - rtol) + else: + raise ValueError(f'Unsupported pooler: {pooler}') + output_mask = output_mask.reshape(*batch_shape, height, width) + return output_mask + + +def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1): + device, dtype = depth.device, depth.dtype + + disp = torch.where(mask, 1 / depth, 0) + disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0) + mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False) + disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2] + mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2) # [..., H, W, kernel_size ** 2] + + x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype) + A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3) # [kernel_size ** 2, 3] + A = mask_window[..., None] * A + I = torch.eye(3, device=device, dtype=dtype) + + affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :] # [..., H, W, kernel_size ** 2] + diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0) + + edge_mask = mask & 
(diff > tol).any(dim=-1) + + disp_mean = weighted_mean(disp_window, mask_window, dim=-1) + fg_edge_mask = edge_mask & (disp > disp_mean) + # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size) + bg_edge_mask = edge_mask & ~fg_edge_mask + return fg_edge_mask, bg_edge_mask + + +def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1): + device, dtype = depth.device, depth.dtype + + disp = torch.where(mask, 1 / depth, 0) + disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0) + mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False) + disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2] + mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)) # [..., H, W, kernel_size ** 2] + + disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1)) + fg_edge_mask = mask & (disp / disp_mean > 1 + tol) + bg_edge_mask = mask & (disp_mean / disp > 1 + tol) + + fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool() + bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool() + + return fg_edge_mask, bg_edge_mask + + +def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor: + kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool) + for _ in range(iterations): + input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1)) + mask_window = kernel & 
utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1)) + if filter =='min': + input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).min(dim=(-2, -1)).values) + elif filter =='max': + input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).max(dim=(-2, -1)).values) + elif filter == 'mean': + input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1))) + elif filter =='median': + input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values) + mask = mask_window.any(dim=(-2, -1)) + return input, mask + + +def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor: + device, dtype = depth.device, depth.dtype + height, width = depth.shape[-2:] + radius = kernel_size // 2 + + duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device) + + log_depth = depth.clamp_min_(eps).log() + log_depth_diff = utils3d.torch.sliding_window_2d(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None] + + weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square()) + tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps) + + uv = utils3d.torch.image_uv(height=height, width=width, device=device, dtype=dtype) + K_inv = torch.inverse(intrinsics) + + grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \ + / (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + 
K_inv[..., None, None, :2, 2:])).squeeze(-2) + laplacian = (weight * ((utils3d.torch.sliding_window_2d(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1)) + + laplacian = laplacian.clamp(-0.1, 0.1) + log_depth_refine = log_depth.clone() + + for _ in range(iterations): + log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.torch.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp) + + depth_refine = log_depth_refine.exp() + + return depth_refine \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/utils/io.py b/Pixel-Perfect-Depth/moge/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..108548caaa34dfcbf394ed4b021874c5ac12edf8 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/utils/io.py @@ -0,0 +1,236 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from typing import IO +import zipfile +import json +import io +from typing import * +from pathlib import Path +import re +from PIL import Image, PngImagePlugin + +import numpy as np +import cv2 + +from .tools import timeit + + +def save_glb( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_uvs: np.ndarray, + texture: np.ndarray, + vertex_normals: Optional[np.ndarray] = None, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + vertex_normals=vertex_normals, + faces=faces, + visual = trimesh.visual.texture.TextureVisuals( + uv=vertex_uvs, + material=trimesh.visual.material.PBRMaterial( + baseColorTexture=Image.fromarray(texture), + metallicFactor=0.5, + roughnessFactor=1.0 + ) + ), + process=False + ).export(save_path) + + 
+def save_ply( + save_path: Union[str, os.PathLike], + vertices: np.ndarray, + faces: np.ndarray, + vertex_colors: np.ndarray, + vertex_normals: Optional[np.ndarray] = None, +): + import trimesh + import trimesh.visual + from PIL import Image + + trimesh.Trimesh( + vertices=vertices, + faces=faces, + vertex_colors=vertex_colors, + vertex_normals=vertex_normals, + process=False + ).export(save_path) + + +def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray: + """ + Read a image, return uint8 RGB array of shape (H, W, 3). + """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB) + return image + + +def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95): + """ + Write a image, input uint8 RGB array of shape (H, W, 3). + """ + data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes() + if isinstance(path, (str, os.PathLike)): + Path(path).write_bytes(data) + else: + path.write(data) + + +def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]: + """ + Read a depth image, return float32 depth array of shape (H, W). 
+ """ + if isinstance(path, (str, os.PathLike)): + data = Path(path).read_bytes() + else: + data = path.read() + pil_image = Image.open(io.BytesIO(data)) + near = float(pil_image.info.get('near')) + far = float(pil_image.info.get('far')) + unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None + depth = np.array(pil_image) + mask_nan, mask_inf = depth == 0, depth == 65535 + depth = (depth.astype(np.float32) - 1) / 65533 + depth = near ** (1 - depth) * far ** depth + depth[mask_nan] = np.nan + depth[mask_inf] = np.inf + return depth, unit + + +def write_depth( + path: Union[str, os.PathLike, IO], + depth: np.ndarray, + unit: float = None, + max_range: float = 1e5, + compression_level: int = 7, +): + """ + Encode and write a depth image as 16-bit PNG format. + ### Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to write to. + - `depth: np.ndarray` + The depth array, float32 array of shape (H, W). + May contain `NaN` for invalid values and `Inf` for infinite values. + - `unit: float = None` + The unit of the depth values. 
+ + Depth values are encoded as follows: + - 0: unknown + - 1 ~ 65534: depth values in logarithmic + - 65535: infinity + + metadata is stored in the PNG file as text fields: + - `near`: the minimum depth value + - `far`: the maximum depth value + - `unit`: the unit of the depth values (optional) + """ + mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth),np.isinf(depth) + + depth = depth.astype(np.float32) + mask_finite = depth + near = max(depth[mask_values].min(), 1e-5) + far = max(near * 1.1, min(depth[mask_values].max(), near * max_range)) + depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16) # 1~65534 + depth[mask_nan] = 0 + depth[mask_inf] = 65535 + + pil_image = Image.fromarray(depth) + pnginfo = PngImagePlugin.PngInfo() + pnginfo.add_text('near', str(near)) + pnginfo.add_text('far', str(far)) + if unit is not None: + pnginfo.add_text('unit', str(unit)) + pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level) + + +def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]: + """ + Read a segmentation mask + ### Parameters: + - `path: Union[str, os.PathLike, IO]` + The file path or file object to read from. + ### Returns: + - `Tuple[np.ndarray, Dict[str, int]]` + A tuple containing: + - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W). + - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}. 
def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Read a segmentation mask written by `write_segmentation`.

    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to read from.

    ### Returns:
    - `Tuple[np.ndarray, Dict[str, int]]`
        A tuple containing:
        - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
        - `labels`: the {label_name: label_id} mapping, or None if absent.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    pil_image = Image.open(io.BytesIO(data))
    labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
    mask = np.array(pil_image)
    return mask, labels


def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
    """
    Write a segmentation mask and label mapping, as PNG format.

    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `mask: np.ndarray`
        The segmentation mask, uint8 or uint16 array of shape (H, W).
    - `labels: Dict[str, int] = None`
        The label mapping, a dictionary of {label_name: label_id}.
    - `compression_level: int = 7`
        The compression level for PNG compression.
    """
    # Raise a real exception instead of `assert`, which is stripped under `python -O`.
    if mask.dtype not in (np.uint8, np.uint16):
        raise ValueError(f"Unsupported dtype {mask.dtype}")
    pil_image = Image.fromarray(mask)
    pnginfo = PngImagePlugin.PngInfo()
    if labels is not None:
        labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
        pnginfo.add_text('labels', labels_json)
    pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)


def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a normal image, return float32 normal array of shape (H, W, 3).
    Pixels that are all-zero in the file decode to NaN (invalid).
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
    mask_nan = np.all(normal == 0, axis=-1)
    # Decode from 16-bit to [-1, 1], flipping the Y and Z axis conventions.
    normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
    normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
    normal[mask_nan] = np.nan
    return normal


def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
    """
    Write a normal image, input float32 normal array of shape (H, W, 3).
    NaN pixels are stored as all-zero.
    """
    mask_nan = np.isnan(normal).any(axis=-1)
    normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
    normal[mask_nan] = 0
    data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)


def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]:
    """Read a JSON metadata dictionary from a file path or file object."""
    # FIX: the annotation allows IO objects, but `Path(path)` rejected them.
    if isinstance(path, (str, os.PathLike)):
        return json.loads(Path(path).read_text())
    return json.load(path)


def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]):
    """Write a JSON metadata dictionary to a file path or file object."""
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_text(json.dumps(meta))
    else:
        json.dump(meta, path)
def get_panorama_cameras():
    """
    Return (extrinsics, intrinsics) for virtual 90-degree-FoV cameras placed at the
    origin, one looking toward each icosahedron vertex (20 views, shared intrinsics).
    """
    vertices, _ = utils3d.numpy.icosahedron()
    intrinsics = utils3d.numpy.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90))
    extrinsics = utils3d.numpy.extrinsics_look_at([0, 0, 0], vertices, [0, 0, 1]).astype(np.float32)
    return extrinsics, [intrinsics] * len(vertices)


def spherical_uv_to_directions(uv: np.ndarray):
    """Map equirectangular UV coordinates in [0, 1]^2 to unit direction vectors (..., 3)."""
    theta, phi = (1 - uv[..., 0]) * (2 * np.pi), uv[..., 1] * np.pi
    directions = np.stack([np.sin(phi) * np.cos(theta), np.sin(phi) * np.sin(theta), np.cos(phi)], axis=-1)
    return directions


def directions_to_spherical_uv(directions: np.ndarray):
    """Inverse of `spherical_uv_to_directions`; directions need not be normalized."""
    directions = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    u = 1 - np.arctan2(directions[..., 1], directions[..., 0]) / (2 * np.pi) % 1.0
    v = np.arccos(directions[..., 2]) / np.pi
    return np.stack([u, v], axis=-1)


def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int):
    """
    Resample a panorama into one square perspective image per (extrinsics, intrinsics) pair.
    Returns a list of (resolution, resolution) images.
    """
    height, width = image.shape[:2]
    uv = utils3d.numpy.image_uv(width=resolution, height=resolution)
    splitted_images = []
    for i in range(len(extrinsics)):
        spherical_uv = directions_to_spherical_uv(utils3d.numpy.unproject_cv(uv, extrinsics=extrinsics[i], intrinsics=intrinsics[i]))
        pixels = utils3d.numpy.uv_to_pixel(spherical_uv, width=width, height=height).astype(np.float32)
        splitted_image = cv2.remap(image, pixels[..., 0], pixels[..., 1], interpolation=cv2.INTER_LINEAR)
        splitted_images.append(splitted_image)
    return splitted_images


# FIX: both equation builders were annotated `-> Tuple[csr_array, ndarray]` but
# return a single sparse matrix; the annotations are corrected below.

def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> csr_array:
    """
    Build the sparse 5-point Laplacian operator A (shape (H*W, H*W)) for a grid,
    with wrap or replicate boundary handling per axis. Each row is [-4, 1, 1, 1, 1]
    over a pixel and its four neighbors.
    """
    grid_index = np.arange(height * width).reshape(height, width)
    grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge')
    grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge')

    data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1)
    indices = np.stack([
        grid_index[1:-1, 1:-1],
        grid_index[:-2, 1:-1],   # up
        grid_index[2:, 1:-1],    # down
        grid_index[1:-1, :-2],   # left
        grid_index[1:-1, 2:]     # right
    ], axis=-1).reshape(-1)
    indptr = np.arange(0, height * width * 5 + 1, 5)
    A = csr_array((data, indices, indptr), shape=(height * width, height * width))

    return A


def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> csr_array:
    """
    Build the sparse forward-difference gradient operator for a grid: first all
    x-direction rows, then all y-direction rows. Each row is [1, -1] over a
    pixel pair; wrapping adds an extra column/row of differences across the seam.
    """
    grid_index = np.arange(width * height).reshape(height, width)
    if wrap_x:
        grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap')
    if wrap_y:
        grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap')

    data = np.concatenate([
        np.concatenate([
            np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),    # x[i,j]
            -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),   # x[i,j-1]
        ], axis=1).reshape(-1),
        np.concatenate([
            np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),    # x[i,j]
            -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),   # x[i-1,j]
        ], axis=1).reshape(-1),
    ])
    indices = np.concatenate([
        np.concatenate([
            grid_index[:, :-1].reshape(-1, 1),
            grid_index[:, 1:].reshape(-1, 1),
        ], axis=1).reshape(-1),
        np.concatenate([
            grid_index[:-1, :].reshape(-1, 1),
            grid_index[1:, :].reshape(-1, 1),
        ], axis=1).reshape(-1),
    ])
    indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2)
    A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width))

    return A
List[np.ndarray], intrinsics: List[np.ndarray]): + if max(width, height) > 256: + panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics) + panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR) + else: + panorama_depth_init = None + + uv = utils3d.numpy.image_uv(width=width, height=height) + spherical_directions = spherical_uv_to_directions(uv) + + # Warp each view to the panorama + panorama_log_distance_grad_maps, panorama_grad_masks = [], [] + panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], [] + panorama_pred_masks = [] + for i in range(len(distance_maps)): + projected_uv, projected_depth = utils3d.numpy.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i]) + projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1) + + projected_pixels = utils3d.numpy.uv_to_pixel(np.clip(projected_uv, 0, 1), width=distance_maps[i].shape[1], height=distance_maps[i].shape[0]).astype(np.float32) + + log_splitted_distance = np.log(distance_maps[i]) + panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0) + panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0) + + # calculate gradient map + padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap') + grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :] + + padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap') + mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :] + + panorama_log_distance_grad_maps.append((grad_x, grad_y)) + panorama_grad_masks.append((mask_x, 
mask_y)) + + # calculate laplacian map + padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge') + padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') + laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1] + + padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge') + padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap') + mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5 + + panorama_log_distance_laplacian_maps.append(laplacian) + panorama_laplacian_masks.append(mask) + + panorama_pred_masks.append(panorama_pred_mask) + + panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0) + panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0) + panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0) + panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0) + + panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3) + panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3) + + panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0) + panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0) + panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3) + + grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1) + grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1) + grad_mask = np.concatenate([grad_x_mask, grad_y_mask]) + laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1) + + # Solve overdetermined system + A 
= vstack([ + grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask], + poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask], + ]) + b = np.concatenate([ + panorama_log_distance_grad_x.reshape(-1)[grad_x_mask], + panorama_log_distance_grad_y.reshape(-1)[grad_y_mask], + panorama_laplacian_map.reshape(-1)[laplacian_mask] + ]) + x, *_ = lsmr( + A, b, + atol=1e-5, btol=1e-5, + x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None, + show=False, + ) + + panorama_depth = np.exp(x).reshape(height, width).astype(np.float32) + panorama_mask = np.any(panorama_pred_masks, axis=0) + + return panorama_depth, panorama_mask + diff --git a/Pixel-Perfect-Depth/moge/utils/pipeline.py b/Pixel-Perfect-Depth/moge/utils/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..daa522e987317e949899d4159e61d7a7066e1fba --- /dev/null +++ b/Pixel-Perfect-Depth/moge/utils/pipeline.py @@ -0,0 +1,503 @@ +from typing import * +from abc import abstractmethod +from queue import Empty, Full +from threading import Thread +from queue import Queue +from multiprocessing import Process +from threading import Thread, Event +import multiprocessing +import threading +import inspect +import time +import uuid +from copy import deepcopy +import itertools +import functools + +__all__ = [ + 'Node', + 'Link', + 'ConcurrentNode', + 'Worker', + 'WorkerFunction', + 'Provider', + 'ProviderFunction', + 'Sequential', + 'Batch', + 'Unbatch', + 'Parallel', + 'Graph', + 'Buffer', +] + +TERMINATE_CHECK_INTERVAL = 0.5 + + +class _ItemWrapper: + def __init__(self, data: Any, id: Union[int, List[int]] = None): + self.data = data + self.id = id + + +class Terminate(Exception): + pass + + +def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper: + while True: + try: + item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, 
TERMINATE_CHECK_INTERVAL)) + if terminate_flag.is_set(): + raise Terminate() + return item + except Empty: + if terminate_flag.is_set(): + raise Terminate() + + if timeout is not None: + timeout -= TERMINATE_CHECK_INTERVAL + if timeout <= 0: + raise Empty() + + +def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event): + while True: + try: + queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL) + if terminate_flag.is_set(): + raise Terminate() + return + except Full: + if terminate_flag.is_set(): + raise Terminate() + +class Node: + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + self.input: Queue = Queue(maxsize=in_buffer_size) + self.output: Queue = Queue(maxsize=out_buffer_size) + self.in_buffer_size = in_buffer_size + self.out_buffer_size = out_buffer_size + + @abstractmethod + def start(self): + pass + + @abstractmethod + def terminate(self): + pass + + def stop(self): + self.terminate() + self.join() + + @abstractmethod + def join(self): + pass + + def put(self, data: Any, key: str = None, block: bool = True) -> None: + item = _ItemWrapper(data) + self.input.put(item, block=block) + + def get(self, key: str = None, block: bool = True) -> Any: + item: _ItemWrapper = self.output.get(block=block) + return item.data + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.terminate() + self.join() + + +class ConcurrentNode(Node): + job: Union[Thread, Process] + + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(in_buffer_size, out_buffer_size) + self.running_as = running_as + + @abstractmethod + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + pass + + def start(self): + if self.running_as == 'thread': + terminate_flag = threading.Event() + job = Thread(target=self._loop_fn, args=(self.input, self.output, 
terminate_flag)) + elif self.running_as == 'process': + terminate_flag = multiprocessing.Event() + job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag)) + job.start() + self.job = job + self.terminate_flag = terminate_flag + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.job.join() + + +class Worker(ConcurrentNode): + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread is started, to initialize any resources that is only held in the thread. + """ + pass + + @abstractmethod + def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]: + """ + This method defines the job that the node should do for each input item. + A item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue. + The method is executed concurrently with other nodes. + """ + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + while True: + item = _get_queue_item(input, terminate_flag) + result = self.work(item.data) + _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag) + + except Terminate: + return + + +class Provider(ConcurrentNode): + """ + A node that provides data to successive nodes. It takes no input and provides data to the output queue. + """ + def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None: + super().__init__(running_as, 0, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread or process is started, to initialize any resources that is only held in the thread or process. 
+ """ + pass + + @abstractmethod + def provide(self) -> Generator[Any, None, None]: + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + for data in self.provide(): + _put_queue_item(output, _ItemWrapper(data), terminate_flag) + except Terminate: + return + + +class WorkerFunction(Worker): + def __init__(self, fn: Callable, running_as: 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + self.fn = fn + + def work(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + +class ProviderFunction(Provider): + def __init__(self, fn: Callable, running_as: 'thread', out_buffer_size: int = 1) -> None: + super().__init__(running_as, out_buffer_size) + self.fn = fn + + def provide(self): + for item in self.fn(): + yield item + + +class Link: + def __init__(self, src: Queue, dst: Queue): + self.src = src + self.dst = dst + + def _thread_fn(self): + try: + while True: + item = _get_queue_item(self.src, self.terminate_flag) + _put_queue_item(self.dst, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.thread = Thread(target=self._thread_fn) + self.thread.start() + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.thread.join() + + +class Graph(Node): + """ + Graph pipeline of nodes and links + """ + nodes: List[Node] + links: List[Link] + + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + self.links = [] + + def add(self, node: Node): + self.nodes.append(node) + + def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]): + """ + Links the output of the source node to the input of the destination node. + If the source or destination node is None, the pipeline's input or output is used. 
+ """ + src_queue = self.input if src is None else src.output + dst_queue = self.output if dst is None else dst.input + self.links.append(Link(src_queue, dst_queue)) + + def chain(self, nodes: Iterable[Node]): + """ + Link the output of each node to the input of the next node. + """ + nodes = list(nodes) + for i in range(len(nodes) - 1): + self.link(nodes[i], nodes[i + 1]) + + def start(self): + for node in self.nodes: + node.start() + for link in self.links: + link.start() + + def terminate(self): + for node in self.nodes: + node.terminate() + for link in self.links: + link.terminate() + + def join(self): + for node in self.nodes: + node.join() + for link in self.links: + link.join() + + def __iter__(self): + providers = [node for node in self.nodes if isinstance(node, Provider)] + if len(providers) == 0: + raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.") + with self: + # while all(provider.job.is_alive() for provider in providers): + while True: + yield self.get() + + def __call__(self, data: Any) -> Any: + """ + Submit data to the pipeline's input queue, and return the output data asynchronously. + NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work. + """ + # TODO + + +class Sequential(Graph): + """ + Pipeline of nodes in sequential order, where each node takes the output of the previous node as input. + The order of input and output items is preserved (FIFO) + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute sequentially. + ### Parameters: + - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. 
+ - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 0 (unlimited). + - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + self.chain([None, *self.nodes, None]) + + +class Parallel(Node): + """ + A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available. + NOTE: It is FIFO if and only if all the nested nodes are FIFO. + """ + nodes: List[Node] + + def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.nodes.append(node) + self.output_order = Queue() + self.lock = threading.Lock() + + def _in_thread_fn(self, node: Node): + try: + while True: + with self.lock: + # A better idea: first make sure its node is vacant, then get it a new item. + # Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node. + # This could lead to suboptimal scheduling. 
+ item = _get_queue_item(self.input, self.terminate_flag) + self.output_order.put(node.output) + _put_queue_item(node.input, item, self.terminate_flag) + except Terminate: + return + + def _out_thread_fn(self): + try: + while True: + queue = _get_queue_item(self.output_order, self.terminate_flag) + item = _get_queue_item(queue, self.terminate_flag) + _put_queue_item(self.output, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.in_threads = [] + for node in self.nodes: + thread = Thread(target=self._in_thread_fn, args=(node,)) + thread.start() + self.in_threads.append(thread) + thread = Thread(target=self._out_thread_fn) + thread.start() + self.out_thread = thread + for node in self.nodes: + node.start() + + def terminate(self): + self.terminate_flag.set() + for node in self.nodes: + node.terminate() + + def join(self): + for thread in self.in_threads: + thread.join() + self.out_thread.join() + + +class UnorderedParallel(Graph): + """ + Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available. + NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input. + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node. + ### Parameters: + - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. + - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 0 (unlimited). 
+ - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + for i in range(len(nodes)): + self.chain([None, self.nodes[i], None]) + + +class Batch(ConcurrentNode): + """ + Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes. + The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node, + i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size. + """ + def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1): + assert batch_size > 0, "Batch size must be greater than 0." 
+ super().__init__('thread', in_buffer_size, out_buffer_size) + self.batch_size = batch_size + self.patience = patience + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch_id, batch_data = [], [] + # Try to fill the batch + for i in range(self.batch_size): + if i == 0 or self.patience is None: + timeout = None + else: + timeout = self.patience - (time.time() - earliest_time) + if timeout < 0: + break + try: + item = _get_queue_item(input, terminate_flag, timeout) + except Empty: + break + + if i == 0: + earliest_time = time.time() + batch_data.append(item.data) + batch_id.append(item.id) + + batch = _ItemWrapper(batch_data, batch_id) + _put_queue_item(output, batch, terminate_flag) + except Terminate: + return + + +class Unbatch(ConcurrentNode): + """ + Ungroups every batch (a list of items) into individual items and passes them to successive nodes. + """ + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__('thread', in_buffer_size, out_buffer_size) + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch = _get_queue_item(input, terminate_flag) + for id, data in zip(batch.id or itertools.repeat(None), batch.data): + item = _ItemWrapper(data, id) + _put_queue_item(output, item, terminate_flag) + except Terminate: + return + + +class Buffer(Node): + "A FIFO node that buffers items in a queue. Usefull achieve better temporal balance when its successor node has a variable processing time." 
+ def __init__(self, size: int): + super().__init__(size, size) + self.size = size + self.input = self.output = Queue(maxsize=size) \ No newline at end of file diff --git a/Pixel-Perfect-Depth/moge/utils/tools.py b/Pixel-Perfect-Depth/moge/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..3687f6938fe34433d149a1a8405be7eed5f23c37 --- /dev/null +++ b/Pixel-Perfect-Depth/moge/utils/tools.py @@ -0,0 +1,289 @@ +from typing import * +import time +from pathlib import Path +from numbers import Number +from functools import wraps +import warnings +import math +import json +import os +import importlib +import importlib.util + + +def catch_exception(fn): + @wraps(fn) + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + import traceback + print(f"Exception in {fn.__name__}", end='r') + # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())}) + traceback.print_exc(chain=False) + time.sleep(0.1) + return None + return wrapper + + +class CallbackOnException: + def __init__(self, callback: Callable, exception: type): + self.exception = exception + self.callback = callback + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if isinstance(exc_val, self.exception): + self.callback() + return True + return False + +def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]: + for k, v in d.items(): + if isinstance(v, dict): + for sub_key in traverse_nested_dict_keys(v): + yield (k, ) + sub_key + else: + yield (k, ) + + +def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None): + for k in keys: + d = d.get(k, default) + if d is None: + break + return d + +def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any): + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value + + +def key_average(list_of_dicts: list) -> Dict[str, Any]: + """ + 
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """Flatten a nested dictionary into a single-level dict whose keys are tuple paths.

    Example: {'a': {'b': 1}, 'c': 2} -> {('a', 'b'): 1, ('c',): 2}.
    Inverse of `unflatten_nested_dict`.
    """
    if parent_key is None:
        parent_key = ()
    items = []
    for k, v in d.items():
        new_key = parent_key + (k,)
        if isinstance(v, MutableMapping):
            # Recurse into sub-mappings, extending the tuple path.
            items.extend(flatten_nested_dict(v, new_key).items())
        else:
            items.append((new_key, v))
    return dict(items)


def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """Rebuild a nested dictionary from a single-level dict with tuple keys.

    Inverse of `flatten_nested_dict`.
    """
    result = {}
    for k, v in d.items():
        sub_dict = result
        for k_ in k[:-1]:
            # Create intermediate levels on demand.
            if k_ not in sub_dict:
                sub_dict[k_] = {}
            sub_dict = sub_dict[k_]
        sub_dict[k[-1]] = v
    return result


def read_jsonl(file):
    """Read a JSON-Lines file and return the list of decoded objects."""
    import json
    with open(file, 'r') as f:
        data = f.readlines()
    return [json.loads(line) for line in data]


def write_jsonl(data: List[dict], file):
    """Write a list of JSON-serializable dicts to `file` in JSON-Lines format."""
    import json
    with open(file, 'w') as f:
        for item in data:
            f.write(json.dumps(item) + '\n')


def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """Build a pandas DataFrame with a hierarchical (MultiIndex) column index
    from a list of nested dictionaries.

    NOTE: the name keeps the historical spelling ("hierachical") for
    backward compatibility with existing callers.
    """
    import pandas as pd
    data = [flatten_nested_dict(d) for d in data]
    df = pd.DataFrame(data)
    df = df.sort_index(axis=1)
    df.columns = pd.MultiIndex.from_tuples(df.columns)
    return df


def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """Recursively apply `str.replace(old, new)` for every pair in `mapping`
    to all strings found in `d` (lists and dicts are modified in place).

    Returns the (possibly replaced) value; non-str/list/dict leaves pass through.
    """
    if isinstance(d, str):
        for old, new in mapping.items():
            d = d.replace(old, new)
    elif isinstance(d, list):
        for i, item in enumerate(d):
            d[i] = recursive_replace(item, mapping)
    elif isinstance(d, dict):
        for k, v in d.items():
            d[k] = recursive_replace(v, mapping)
    return d


class timeit:
    """Simple wall-clock timer usable as a context manager or a decorator.

    With `average=True`, completed runs under the same `name` are accumulated
    in the class-level `_history` so `average_time` reports the running mean.
    """
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.average = average
        if average and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        """Decorator form: times each call of `func` (sync or async).

        NOTE(review): the inner timer is created with default flags, so
        `average`/`verbose` of this instance are not propagated — looks
        intentional (per-call printout), but worth confirming.
        """
        import inspect
        if inspect.iscoroutinefunction(func):
            @wraps(func)  # preserve __name__/__doc__ of the wrapped coroutine
            async def wrapper(*args, **kwargs):
                with timeit(self.name or func.__qualname__):
                    ret = await func(*args, **kwargs)
                return ret
            return wrapper
        else:
            @wraps(func)  # preserve __name__/__doc__ of the wrapped function
            def wrapper(*args, **kwargs):
                with timeit(self.name or func.__qualname__):
                    ret = func(*args, **kwargs)
                return ret
            return wrapper

    def __enter__(self):
        self.start = time.time()
        return self

    @property
    def time(self) -> float:
        """Elapsed seconds of the last completed run."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    @property
    def average_time(self) -> float:
        """Mean elapsed seconds over all recorded runs with this name."""
        assert self.average, "Average time not available."
        return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])

    @property
    def history(self) -> List['timeit']:
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        if self.average:
            timeit._history[self.name].append(self)
        if self.verbose:
            if self.average:
                avg = self.average_time
                print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time:.6f} seconds.")


def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """Strip the longest common prefix and suffix shared by all strings.

    Fixes over the previous version:
    - empty input returns [] instead of raising IndexError;
    - strings shorter than strings[0] no longer cause IndexError
      (comparison is bounded by the shortest string);
    - when a scan finds no mismatch (e.g. all strings identical), the
      full common run is stripped instead of leaving a stray character
      (the old loop fell through with a stale loop variable);
    - prefix and suffix are prevented from overlapping.
    """
    if not strings:
        return []
    first = strings[0]
    shortest = min(len(s) for s in strings)

    prefix = 0
    while prefix < shortest and all(s[prefix] == first[prefix] for s in strings):
        prefix += 1

    suffix = 0
    while suffix < shortest - prefix and all(
        s[len(s) - 1 - suffix] == first[len(first) - 1 - suffix] for s in strings
    ):
        suffix += 1

    return [s[prefix:len(s) - suffix] for s in strings]


def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    """Decorator factory: run the decorated function over `inputs` with a thread pool.

    NOTE: the name keeps the historical spelling ("multithead") for
    backward compatibility with existing callers.
    """
    from concurrent.futures import ThreadPoolExecutor
    from tqdm import tqdm

    if pbar is not None:
        pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
    else:
        pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)

    def decorator(fn: Callable):
        with (
            ThreadPoolExecutor(max_workers=num_workers) as executor,
            pbar
        ):
            pbar.refresh()
            @catch_exception
            @suppress_traceback
            def _fn(input):
                ret = fn(input)
                pbar.update()
                return ret
            # ThreadPoolExecutor.map submits all items eagerly; results are discarded.
            executor.map(_fn, inputs)
            executor.shutdown(wait=True)

    return decorator


def suppress_traceback(fn):
    """Decorator that trims the two innermost wrapper frames off exceptions
    raised by `fn`, so reported tracebacks start at the user's code."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # Guard: the old code crashed with AttributeError when the
            # traceback chain had fewer than two frames.
            tb = e.__traceback__
            if tb is not None and tb.tb_next is not None:
                e.__traceback__ = tb.tb_next.tb_next
            raise
    return wrapper
class no_warnings:
    """Apply a `warnings.simplefilter` while active.

    Usable either as a context manager or as a decorator; by default all
    warnings are ignored inside the guarded region.
    """

    def __init__(self, action: str = 'ignore', **kwargs):
        self.action = action
        self.filter_kwargs = kwargs

    def __call__(self, fn):
        @wraps(fn)
        def guarded(*args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter(self.action, **self.filter_kwargs)
                return fn(*args, **kwargs)
        return guarded

    def __enter__(self):
        self.warnings_manager = warnings.catch_warnings()
        self.warnings_manager.__enter__()
        warnings.simplefilter(self.action, **self.filter_kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)


def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
    """Load the Python file at `file_path` and return it as a module named `module_name`."""
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


from typing import *

import numpy as np


def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Render a metric depth map as an RGB uint8 image via its disparity (1/depth)."""
    import matplotlib
    # Non-positive (and masked-out) depths become NaN and render as black.
    valid = depth > 0 if mask is None else (depth > 0) & mask
    disp = 1 / np.where(valid, depth, np.nan)
    if normalize:
        # NOTE(review): upper quantile is 0.99 here but 0.999 in the sibling
        # colorize functions — possibly intentional for depth, worth confirming.
        lo, hi = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
        disp = (disp - lo) / (hi - lo)
    # nan_to_num with copy=False maps NaNs to 0.0 in place.
    rgb = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], copy=False)
    return np.ascontiguousarray((np.clip(rgb, 0, 1) * 255).astype(np.uint8))


def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
    """Render an affine-invariant depth map as an RGB uint8 image (no disparity inversion)."""
    import matplotlib
    if mask is not None:
        depth = np.where(mask, depth, np.nan)

    lo, hi = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
    depth = (depth - lo) / (hi - lo)
    rgb = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], copy=False)
    return np.ascontiguousarray((np.clip(rgb, 0, 1) * 255).astype(np.uint8))


def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Render a disparity map as an RGB uint8 image."""
    import matplotlib
    if mask is not None:
        disparity = np.where(mask, disparity, np.nan)

    if normalize:
        lo, hi = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
        disparity = (disparity - lo) / (hi - lo)
    rgb = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], copy=False)
    return np.ascontiguousarray((np.clip(rgb, 0, 1) * 255).astype(np.uint8))


def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray:
    """Map integer segment labels to RGB uint8 colors, cycling every 20 labels."""
    import matplotlib
    rgb = matplotlib.colormaps[cmap]((segmentation % 20) / 20)[..., :3]
    return np.ascontiguousarray((np.clip(rgb, 0, 1) * 255).astype(np.uint8))


def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
    """Render a (..., 3) normal map as RGB uint8, flipping Y/Z into camera convention."""
    if mask is not None:
        normal = np.where(mask[..., None], normal, 0)
    shaded = normal * [0.5, -0.5, -0.5] + 0.5
    return (np.clip(shaded, 0, 1) * 255).astype(np.uint8)


def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
    """Render a scalar error map as RGB uint8; masked-out pixels become black."""
    import matplotlib
    if value_range is not None:
        vmin, vmax = value_range
    else:
        vmin, vmax = np.nanmin(error_map), np.nanmax(error_map)
    colormap = matplotlib.colormaps[cmap]
    rgb = colormap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3]
    if mask is not None:
        rgb = np.where(mask[..., None], rgb, 0)
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
import requests
from typing import *

__all__ = ["WebFile"]


class WebFile:
    """Read-only, seekable file-like object backed by an HTTP(S) URL.

    Byte ranges are fetched lazily with HTTP `Range` requests, so a large
    remote file can be read piecewise without downloading it entirely.
    """

    def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None):
        self.url = url
        self.session = session or requests.Session()
        self.session.headers.update(headers or {})
        self._offset = 0
        # Passing `size` skips the extra round-trip of _fetch_size().
        self.size = size if size is not None else self._fetch_size()

    def _fetch_size(self):
        """Determine the remote file size from the Content-Length header."""
        # stream=True so only headers are read; the body is never consumed.
        with self.session.get(self.url, stream=True) as response:
            response.raise_for_status()
            content_length = response.headers.get("Content-Length")
            if content_length is None:
                raise ValueError("Missing Content-Length in header")
            return int(content_length)

    def _fetch_data(self, offset: int, n: int) -> bytes:
        """Fetch up to `n` bytes starting at `offset` with a ranged GET."""
        # FIX: per RFC 7233 the last-byte-pos is INCLUSIVE, so it must be
        # capped at size - 1. The previous cap of `self.size` requested one
        # byte past EOF (most servers clamp it, but it is off-spec).
        headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size - 1)}"}
        response = self.session.get(self.url, headers=headers)
        response.raise_for_status()
        return response.content

    def seekable(self) -> bool:
        return True

    def tell(self) -> int:
        return self._offset

    def available(self) -> int:
        """Number of bytes between the current offset and EOF."""
        return self.size - self._offset

    def seek(self, offset: int, whence: int = 0) -> int:
        """Move the read offset; `whence` follows the io convention (0/1/2).

        The resulting offset is clamped to [0, size]. Returns the new offset
        (the `io` convention; previously returned None — callers that ignored
        the return value are unaffected).
        """
        if whence == 0:
            new_offset = offset
        elif whence == 1:
            new_offset = self._offset + offset
        elif whence == 2:
            new_offset = self.size + offset
        else:
            raise ValueError("Invalid value for whence")

        self._offset = max(0, min(new_offset, self.size))
        return self._offset

    def read(self, n: Optional[int] = None) -> bytes:
        """Read up to `n` bytes (all remaining bytes if `n` is None or negative)."""
        if n is None or n < 0:
            n = self.available()
        else:
            n = min(n, self.available())

        if n == 0:
            return b''

        data = self._fetch_data(self._offset, n)
        self._offset += len(data)

        return data

    def close(self) -> None:
        # The shared session is intentionally left open; nothing to release.
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass


from typing import *
import io
import os
from zipfile import (
    ZipInfo, BadZipFile, ZipFile, ZipExtFile,
    sizeFileHeader, structFileHeader, stringFileHeader,
    _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS,
    _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED
)
import struct
from requests import Session

from .webfile import WebFile


class _SharedWebFile(WebFile):
    """Independent cursor over an existing WebFile (shares URL, session, size)."""
    def __init__(self, webfile: WebFile, pos: int):
        super().__init__(webfile.url, webfile.session, size=webfile.size)
        self.seek(pos)


class WebZipFile(ZipFile):
    "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads."
    def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None):
        """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x',
        or append 'a'."""
        webf = WebFile(url, session=session, headers=headers)
        super().__init__(webf, mode='r')

    def open(self, name, mode="r", pwd=None, *, force_zip64=False):
        """Return file-like object for 'name'.

        name is a string for the file name within the ZIP file, or a ZipInfo
        object.

        mode should be 'r' to read a file already in the ZIP file, or 'w' to
        write to a file newly added to the archive.

        pwd is the password to decrypt files (only used for reading).

        When writing, if the file size is not known in advance but may exceed
        2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large
        files.  If the size is known in advance, it is best to pass a ZipInfo
        instance for name, with zinfo.file_size set.

        Adapted from CPython's ZipFile.open: instead of locking and seeking
        the shared file object, each read gets its own _SharedWebFile cursor,
        so multiple members can be read concurrently.
        """
        if mode not in {"r", "w"}:
            raise ValueError('open() requires mode "r" or "w"')
        if pwd and (mode == "w"):
            raise ValueError("pwd is only supported for reading files")
        if not self.fp:
            raise ValueError(
                "Attempt to use ZIP archive that was already closed")

        assert mode == "r", "Only read mode is supported for now"

        # Make sure we have an info object
        if isinstance(name, ZipInfo):
            # 'name' is already an info object
            zinfo = name
        elif mode == 'w':
            zinfo = ZipInfo(name)
            zinfo.compress_type = self.compression
            zinfo._compresslevel = self.compresslevel
        else:
            # Get info object for name
            zinfo = self.getinfo(name)

        if mode == 'w':
            return self._open_to_write(zinfo, force_zip64=force_zip64)

        if self._writing:
            raise ValueError("Can't read from the ZIP file while there "
                             "is an open writing handle on it. "
                             "Close the writing handle before trying to read.")

        # Open for reading: give this member its own independent cursor.
        self._fileRefCnt += 1
        zef_file = _SharedWebFile(self.fp, zinfo.header_offset)

        try:
            # Skip the file header:
            fheader = zef_file.read(sizeFileHeader)
            if len(fheader) != sizeFileHeader:
                raise BadZipFile("Truncated file header")
            fheader = struct.unpack(structFileHeader, fheader)
            if fheader[_FH_SIGNATURE] != stringFileHeader:
                raise BadZipFile("Bad magic number for file header")

            fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
            if fheader[_FH_EXTRA_FIELD_LENGTH]:
                zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1)

            if zinfo.flag_bits & _MASK_COMPRESSED_PATCH:
                # Zip 2.7: compressed patched data
                raise NotImplementedError("compressed patched data (flag bit 5)")

            if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION:
                # strong encryption
                raise NotImplementedError("strong encryption (flag bit 6)")

            if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME:
                # UTF-8 filename
                fname_str = fname.decode("utf-8")
            else:
                fname_str = fname.decode(self.metadata_encoding or "cp437")

            if fname_str != zinfo.orig_filename:
                raise BadZipFile(
                    'File name in directory %r and header %r differ.'
                    % (zinfo.orig_filename, fname))

            # check for encrypted flag & handle password
            is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED
            if is_encrypted:
                if not pwd:
                    pwd = self.pwd
                if pwd and not isinstance(pwd, bytes):
                    raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__)
                if not pwd:
                    raise RuntimeError("File %r is encrypted, password "
                                       "required for extraction" % name)
            else:
                pwd = None

            return ZipExtFile(zef_file, mode, zinfo, pwd, True)
        except:
            # Intentional bare except (mirrors CPython): close the per-member
            # cursor on ANY failure, then re-raise unchanged.
            zef_file.close()
            raise
+ % (zinfo.orig_filename, fname)) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED + if is_encrypted: + if not pwd: + pwd = self.pwd + if pwd and not isinstance(pwd, bytes): + raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__) + if not pwd: + raise RuntimeError("File %r is encrypted, password " + "required for extraction" % name) + else: + pwd = None + + return ZipExtFile(zef_file, mode, zinfo, pwd, True) + except: + zef_file.close() + raise \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/models/attention.py b/Pixel-Perfect-Depth/ppd/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c84aee905934d05ea07656a30f4f6b7d4832744a --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/attention.py @@ -0,0 +1,59 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +class Attention(nn.Module): + + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + qk_norm: bool = False, + rope=None, + fused_attn: bool = True, # use F.scaled_dot_product_attention or not + attn_drop: float = 0., + proj_drop: float = 0., + norm_layer: nn.Module = nn.LayerNorm, + ) -> None: + super().__init__() + assert dim % num_heads == 0, 'dim should be divisible by num_heads' + self.num_heads = num_heads + self.head_dim = dim // num_heads + self.scale = self.head_dim ** -0.5 + self.fused_attn = fused_attn + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.rope = rope + + def forward(self, x: torch.Tensor, pos=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4) + q, k, v = qkv.unbind(0) + q, 
k = self.q_norm(q), self.k_norm(k) + + if self.rope is not None: + q = self.rope(q, pos) + k = self.rope(k, pos) + + if self.fused_attn: + x = F.scaled_dot_product_attention( + q, k, v, + dropout_p=self.attn_drop.p if self.training else 0., + ) + else: + q = q * self.scale + attn = q @ k.transpose(-2, -1) + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = attn @ v + + x = x.transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2.py b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..e12ddffdbf922505f5a99cfaa655c10685b47ee0 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2.py @@ -0,0 +1,416 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. 
# References:
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

from functools import partial
import math
import logging
from typing import Sequence, Tuple, Union, Callable

import torch
import torch.nn as nn
import torch.utils.checkpoint
from torch.nn.init import trunc_normal_

from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block


logger = logging.getLogger("dinov2")


def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` to ``module`` and all submodules.

    ``name`` accumulates the dotted path; ``depth_first`` controls whether the
    root/parents are visited before or after their children. Returns ``module``.
    """
    if not depth_first and include_root:
        fn(module=module, name=name)
    for child_name, child_module in module.named_children():
        child_name = ".".join((name, child_name)) if name else child_name
        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
    if depth_first and include_root:
        fn(module=module, name=name)
    return module


class BlockChunk(nn.ModuleList):
    # Runs its contained blocks sequentially; used to group blocks into
    # chunks that FSDP can wrap as units (see block_chunks below).
    def forward(self, x):
        for b in self:
            x = b(x)
        return x


class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone (vendored from facebookresearch/dinov2)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            weight_init (str): weight init scheme
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        # +1 slot for the class token in the positional embedding.
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    def init_weights(self):
        """Truncated-normal init for pos/cls/register tokens, timm init for the rest."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, w, h):
        """Bicubically resize the patch positional embedding to the (w, h) input size.

        Assumes the pretrained pos-embed grid is square (sqrt_N x sqrt_N).
        `interpolate_offset` is the upstream work-around for floating point
        error in scale_factor-based interpolation.
        """
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0]
        patch_pos_embed = pos_embed[:, 1:]
        dim = x.shape[-1]
        w0 = w // self.patch_size
        h0 = h // self.patch_size
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
        # w0, h0 = w0 + 0.1, h0 + 0.1

        sqrt_N = math.sqrt(N)
        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
            scale_factor=(sx, sy),
            # (int(w0), int(h0)), # to solve the upsampling shape issue
            mode="bicubic",
            antialias=self.interpolate_antialias
        )

        assert int(w0) == patch_pos_embed.shape[-2]
        assert int(h0) == patch_pos_embed.shape[-1]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify `x`, apply optional token masking, and prepend cls/register tokens."""
        B, nc, w, h = x.shape
        x = self.patch_embed(x)
        if masks is not None:
            # Masked patches are replaced by the learned mask token.
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)

        x = x + self.interpolate_pos_encoding(x, w, h)

        if self.register_tokens is not None:
            # Register tokens go between the cls token and the patch tokens.
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Batched-list variant of forward_features (nested-tensor blocks process the list jointly)."""
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone; returns a dict of normalized cls/register/patch tokens."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # Chunked layout pads each chunk with nn.Identity up to its start
        # index, so the global block index `i` stays consistent.
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return patch tokens from selected blocks, optionally normalized,
        reshaped to (B, C, H, W), and/or paired with their class tokens."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Training mode returns the full feature dict; eval returns head(cls token)."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])


def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)"""
    if isinstance(module, nn.Linear):
        trunc_normal_(module.weight, std=0.02)
        if module.bias is not None:
            nn.init.zeros_(module.bias)


def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S/16-style DINOv2 backbone (embed 384, 12 layers, 6 heads)."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B-style DINOv2 backbone (embed 768, 12 layers, 12 heads)."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L-style DINOv2 backbone (embed 1024, 24 layers, 16 heads)."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model


def DINOv2(model_name):
    """Factory: build a DINOv2 backbone ('vits'/'vitb'/'vitl'/'vitg') with the
    standard pretrained configuration (518 img size, patch 14, layerscale 1.0)."""
    model_zoo = {
        "vits": vit_small,
        "vitb": vit_base,
        "vitl": vit_large,
        "vitg": vit_giant2
    }

    return model_zoo[model_name](
        img_size=518,
        patch_size=14,
        init_values=1.0,
        ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
        block_chunks=0,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1
    )
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2_layers/attention.py b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2_layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..815a2bf53dbec496f6a184ed7d03bcecb7124262 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2_layers/attention.py @@ -0,0 +1,83 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import memory_efficient_attention, unbind, fmha + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = 
self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2_layers/block.py b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2_layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..25488f57cc0ad3c692f86b62555f6668e2a66db1 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/dinov2_layers/block.py @@ -0,0 +1,252 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
class Block(nn.Module):
    """Pre-norm ViT block: attention branch + MLP branch, each with optional
    LayerScale and stochastic depth (drop path).

    For drop-path rates above 0.1 the training forward switches to a batched
    "drop whole samples" formulation (``drop_add_residual_stochastic_depth``)
    that skips computation for dropped samples instead of merely zeroing
    their residual.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        # Attention branch.
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy (None/0 disables it).
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        # FFN branch.
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # the overhead is compensated only for a drop path rate larger than 0.1
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            # Fix: the FFN residual was gated by drop_path1 (flagged by the
            # original "# FIXME: drop_path2" comment). Both modules carry the
            # same drop rate, so the sampled distribution is unchanged.
            x = x + self.drop_path2(ffn_residual_func(x))
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Batched stochastic depth.

    Runs ``residual_func`` on a random subset of the batch only, then
    scatter-adds the rescaled residual back, so dropped samples never pay
    for the branch computation.
    """
    # 1) extract subset using permutation; always keep at least one sample
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    # Rescale so the residual's expected magnitude matches full-batch training.
    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    """Return (kept-batch indices, compensating scale) for one tensor."""
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add ``residual`` into rows ``brange`` of ``x``.

    Without a ``scaling_vector`` this returns the result flattened to
    (B, N*D); with one, xFormers' fused ``scaled_index_add`` also applies a
    per-channel LayerScale gamma and keeps the original shape.
    """
    if scaling_vector is None:
        flat_x = x.flatten(1)
        flat_res = residual.flatten(1)
        return torch.index_add(
            flat_x, 0, brange, flat_res.to(dtype=x.dtype), alpha=residual_scale_factor
        )
    # Fused kernel: index_add + per-channel scaling in one pass.
    return scaled_index_add(
        x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
    )


# Cache of block-diagonal attention biases, keyed by the tuple of
# (batch_size, seq_len) pairs — building a BlockDiagonalMask is not free.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache:
        # One seqlen entry per (possibly subsampled) batch element.
        seqlens = [x.shape[1] for b, x in zip(batch_sizes, x_list) for _ in range(b)]
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Gather the kept rows of each tensor and pack them as one sequence.
        cat_tensors = index_select_cat(
            [x.flatten(1) for x in x_list], branges
        ).view(1, -1, x_list[0].shape[-1])
    else:
        per_tensor = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(per_tensor, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """List variant of ``drop_add_residual_stochastic_depth``: packs all
    tensors into one block-diagonal sequence, runs ``residual_func`` once,
    then splits and adds each residual back (optionally LayerScale-fused).
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    """Block that can also process a *list* of token tensors in one pass via
    an xFormers block-diagonal attention bias."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                # LayerScale is applied inside add_residual via scaling_vector.
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # Fix: gate ls2's gamma on ls2 itself (was isinstance(self.ls1, ...)).
                # Behaviour is unchanged in practice: ls1 and ls2 are always
                # both LayerScale or both Identity (same init_values test).
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero entire samples of a batch (stochastic depth), scaling
    survivors by 1/keep_prob so the expected value is preserved.

    A no-op at evaluation time or when drop_prob is 0.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over every remaining dim.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Delegates to the functional form; self.training gates it.
        return drop_path(x, self.drop_prob, self.training)
class LayerScale(nn.Module):
    """Learnable per-channel scaling (CaiT-style LayerScale).

    Multiplies the last dimension of the input by a trained vector ``gamma``
    initialised to ``init_values``; optionally in place.
    """

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # gamma starts as a constant vector of init_values, one per channel.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
class Mlp(nn.Module):
    """Two-layer transformer MLP: fc1 -> activation -> dropout -> fc2 -> dropout.

    Hidden and output widths default to the input width when not given.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        # A single Dropout module is reused after both linear layers.
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        x = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(x))
def make_2tuple(x):
    """Promote an int to an (x, x) pair; pass 2-tuples through unchanged."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x
    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
        flatten_embedding: when False, forward returns (B, H', W', D) instead
            of the flattened (B, N, D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)

        self.img_size = image_HW
        self.patch_size = patch_HW
        # Grid of patches for the nominal image size.
        self.patches_resolution = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )
        self.num_patches = self.patches_resolution[0] * self.patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # Non-overlapping conv == per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Rough FLOP count of the projection (+ norm, when present)."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: a fused gate/value projection followed by
    SiLU gating and an output projection (``act_layer``/``drop`` are kept
    for interface parity and ignored)."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # w12 produces gate and value halves in a single matmul.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)


try:
    from xformers.ops import SwiGLU

    XFORMERS_AVAILABLE = True
except ImportError:
    # Pure-PyTorch fallback with an identical interface.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False


class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN whose hidden width is shrunk to 2/3 (to keep parameter
    count comparable to a plain MLP) and rounded up to a multiple of 8."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # 2/3 scaling, then round up to the next multiple of 8.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
def _make_fusion_block(features, use_bn, size=None):
    """Build one RefineNet-style fusion block with the settings used here."""
    return FeatureFusionBlock(
        features,
        nn.ReLU(False),
        deconv=False,
        bn=use_bn,
        expand=False,
        align_corners=True,
        size=size,
    )


class ConvBlock(nn.Module):
    """Conv3x3 -> BatchNorm -> ReLU."""

    def __init__(self, in_feature, out_feature):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(out_feature),
            nn.ReLU(True),
        )

    def forward(self, x):
        return self.conv_block(x)


class DPTHead(nn.Module):
    """DPT decoder head over four ViT feature maps.

    NOTE(review): this PPD variant intentionally stops after the second
    fusion stage — ``forward`` returns the fused ``path_3`` map flattened
    back to tokens, and the remaining refinement/output stages (kept below
    as commented-out upstream code) are constructed but never run.
    """

    def __init__(
        self,
        in_channels,
        features=256,
        use_bn=False,
        out_channels=[256, 512, 1024, 1024],
        use_clstoken=False,
    ):
        super().__init__()

        self.use_clstoken = use_clstoken

        # One 1x1 projection per ViT stage, mapping tokens to CNN channels.
        self.projects = nn.ModuleList(
            nn.Conv2d(in_channels, oc, kernel_size=1, stride=1, padding=0)
            for oc in out_channels
        )

        # Stage-dependent rescaling: x4 up, x2 up, identity, x2 down.
        self.resize_layers = nn.ModuleList([
            nn.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0),
            nn.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0),
            nn.Identity(),
            nn.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1),
        ])

        if use_clstoken:
            # Fuse the class token back into every patch token per stage.
            self.readout_projects = nn.ModuleList(
                nn.Sequential(nn.Linear(2 * in_channels, in_channels), nn.GELU())
                for _ in range(len(self.projects))
            )

        self.scratch = _make_scratch(out_channels, features, groups=1, expand=False)
        self.scratch.stem_transpose = None

        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)

        head_features_1 = features
        head_features_2 = 32

        # Output convs are unused in this variant but kept for checkpoint
        # compatibility with the full Depth Anything V2 head.
        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
        self.scratch.output_conv2 = nn.Sequential(
            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True),
            nn.Identity(),
        )

    def forward(self, out_features, patch_h, patch_w):
        """Fuse four (tokens, cls) stage outputs; return (B, N', features) tokens."""
        stages = []
        for i, feat in enumerate(out_features):
            if self.use_clstoken:
                tokens, cls_token = feat[0], feat[1]
                readout = cls_token.unsqueeze(1).expand_as(tokens)
                tokens = self.readout_projects[i](torch.cat((tokens, readout), -1))
            else:
                tokens = feat[0]

            # Tokens back to a 2-D feature map, then project and rescale.
            fmap = tokens.permute(0, 2, 1).reshape((tokens.shape[0], tokens.shape[-1], patch_h, patch_w))
            fmap = self.projects[i](fmap)
            fmap = self.resize_layers[i](fmap)
            stages.append(fmap)

        layer_1, layer_2, layer_3, layer_4 = stages

        layer_1_rn = self.scratch.layer1_rn(layer_1)
        layer_2_rn = self.scratch.layer2_rn(layer_2)
        layer_3_rn = self.scratch.layer3_rn(layer_3)
        layer_4_rn = self.scratch.layer4_rn(layer_4)

        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
        # Upstream continuation, disabled here on purpose:
        # path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
        # path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
        # out = self.scratch.output_conv1(path_1)
        # out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
        # out = self.scratch.output_conv2(out)

        return path_3.flatten(2).transpose(1, 2)
class DepthAnythingV2(nn.Module):
    """Stripped-down Depth Anything V2 used by PPD: only the DINOv2 backbone
    is instantiated, and ``forward`` returns normalized patch tokens
    (semantic features) rather than a depth map — the DPT head and the
    intermediate-layer machinery are disabled (see commented-out code).
    """

    def __init__(
        self,
        encoder='vitl',
        features=256,
        out_channels=[256, 512, 1024, 1024],
        use_bn=False,
        use_clstoken=False,
    ):
        super().__init__()
        # Head construction intentionally disabled in this variant:
        # self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn,
        #                           out_channels=out_channels, use_clstoken=use_clstoken)
        self.pretrained = DINOv2(model_name=encoder)

    def forward(self, x):
        ori_h, ori_w = x.shape[-2:]

        # ImageNet normalization, built on the input's device.
        mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(x.device)
        std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(x.device)
        x = (x - mean) / std

        # Map each 16-px cell of the input onto one 14-px DINOv2 patch.
        new_h = (ori_h // 16) * 14
        new_w = (ori_w // 16) * 14
        x = F.interpolate(x, size=(new_h, new_w), mode='bicubic', align_corners=False)

        # (B, N, C) patch tokens from the backbone.
        return self.pretrained.forward_features(x)["x_norm_patchtokens"]

    @torch.no_grad()
    def infer_image(self, raw_image, input_size=518):
        # NOTE(review): forward() now returns (B, N, C) patch tokens, but this
        # path still treats its output as a 2-D depth map — looks like stale
        # upstream code; confirm before relying on it.
        image, (h, w) = self.image2tensor(raw_image, input_size)
        depth = self.forward(image)

        depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]

        return depth.cpu().numpy()

    def image2tensor(self, raw_image, input_size=518):
        """Preprocess a BGR uint8 image into a normalized network input;
        returns the tensor and the original (h, w)."""
        transform = Compose([
            Resize(
                width=input_size,
                height=input_size,
                resize_target=False,
                keep_aspect_ratio=True,
                ensure_multiple_of=14,
                resize_method='lower_bound',
                image_interpolation_method=cv2.INTER_CUBIC,
            ),
            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            PrepareForNet(),
        ])
        h, w = raw_image.shape[:2]
        image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0

        image = transform({'image': image})['image']
        image = torch.from_numpy(image).unsqueeze(0)

        DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu'
        image = image.to(DEVICE)

        return image, (h, w)
image, (h, w) = self.image2tensor(raw_image, input_size) + depth = self.forward(image) + + depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0] + + return depth.cpu().numpy() + + def image2tensor(self, raw_image, input_size=518): + transform = Compose([ + Resize( + width=input_size, + height=input_size, + resize_target=False, + keep_aspect_ratio=True, + ensure_multiple_of=14, + resize_method='lower_bound', + image_interpolation_method=cv2.INTER_CUBIC, + ), + NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + PrepareForNet(), + ]) + h, w = raw_image.shape[:2] + image = cv2.cvtColor(raw_image, cv2.COLOR_BGR2RGB) / 255.0 + + image = transform({'image': image})['image'] + image = torch.from_numpy(image).unsqueeze(0) + + DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' + image = image.to(DEVICE) + + return image, (h, w) diff --git a/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/util/blocks.py b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/util/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..382ea183a40264056142afffc201c992a2b01d37 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/util/blocks.py @@ -0,0 +1,148 @@ +import torch.nn as nn + + +def _make_scratch(in_shape, out_shape, groups=1, expand=False): + scratch = nn.Module() + + out_shape1 = out_shape + out_shape2 = out_shape + out_shape3 = out_shape + if len(in_shape) >= 4: + out_shape4 = out_shape + + if expand: + out_shape1 = out_shape + out_shape2 = out_shape * 2 + out_shape3 = out_shape * 4 + if len(in_shape) >= 4: + out_shape4 = out_shape * 8 + + scratch.layer1_rn = nn.Conv2d(in_shape[0], out_shape1, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer2_rn = nn.Conv2d(in_shape[1], out_shape2, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + scratch.layer3_rn = nn.Conv2d(in_shape[2], out_shape3, 
kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + if len(in_shape) >= 4: + scratch.layer4_rn = nn.Conv2d(in_shape[3], out_shape4, kernel_size=3, stride=1, padding=1, bias=False, groups=groups) + + return scratch + + +class ResidualConvUnit(nn.Module): + """Residual convolution module. + """ + + def __init__(self, features, activation, bn): + """Init. + + Args: + features (int): number of features + """ + super().__init__() + + self.bn = bn + + self.groups=1 + + self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True, groups=self.groups) + + if self.bn == True: + self.bn1 = nn.BatchNorm2d(features) + self.bn2 = nn.BatchNorm2d(features) + + self.activation = activation + + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + """Forward pass. + + Args: + x (tensor): input + + Returns: + tensor: output + """ + + out = self.activation(x) + out = self.conv1(out) + if self.bn == True: + out = self.bn1(out) + + out = self.activation(out) + out = self.conv2(out) + if self.bn == True: + out = self.bn2(out) + + if self.groups > 1: + out = self.conv_merge(out) + + return self.skip_add.add(out, x) + + +class FeatureFusionBlock(nn.Module): + """Feature fusion block. + """ + + def __init__( + self, + features, + activation, + deconv=False, + bn=False, + expand=False, + align_corners=True, + size=None + ): + """Init. 
+ + Args: + features (int): number of features + """ + super(FeatureFusionBlock, self).__init__() + + self.deconv = deconv + self.align_corners = align_corners + + self.groups=1 + + self.expand = expand + out_features = features + if self.expand == True: + out_features = features // 2 + + self.out_conv = nn.Conv2d(features, out_features, kernel_size=1, stride=1, padding=0, bias=True, groups=1) + + self.resConfUnit1 = ResidualConvUnit(features, activation, bn) + self.resConfUnit2 = ResidualConvUnit(features, activation, bn) + + self.skip_add = nn.quantized.FloatFunctional() + + self.size=size + + def forward(self, *xs, size=None): + """Forward pass. + + Returns: + tensor: output + """ + output = xs[0] + + if len(xs) == 2: + res = self.resConfUnit1(xs[1]) + output = self.skip_add.add(output, res) + + output = self.resConfUnit2(output) + + if (size is None) and (self.size is None): + modifier = {"scale_factor": 2} + elif size is None: + modifier = {"size": self.size} + else: + modifier = {"size": size} + + output = nn.functional.interpolate(output, **modifier, mode="bilinear", align_corners=self.align_corners) + + output = self.out_conv(output) + + return output diff --git a/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/util/transform.py b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/util/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..b14aacd44ea086b01725a9ca68bb49eadcf37d73 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/depth_anything_v2/util/transform.py @@ -0,0 +1,158 @@ +import numpy as np +import cv2 + + +class Resize(object): + """Resize sample to given size (width, height). + """ + + def __init__( + self, + width, + height, + resize_target=True, + keep_aspect_ratio=False, + ensure_multiple_of=1, + resize_method="lower_bound", + image_interpolation_method=cv2.INTER_AREA, + ): + """Init. 
class Resize(object):
    """Resize sample to given size (width, height)."""

    def __init__(
        self,
        width,
        height,
        resize_target=True,
        keep_aspect_ratio=False,
        ensure_multiple_of=1,
        resize_method="lower_bound",
        image_interpolation_method=cv2.INTER_AREA,
    ):
        """Init.

        Args:
            width (int): desired output width
            height (int): desired output height
            resize_target (bool, optional): when True, resize depth/mask
                targets along with the image; otherwise resize the image only.
                Defaults to True.
            keep_aspect_ratio (bool, optional): keep the input aspect ratio;
                output size then depends on `resize_method`. Defaults to False.
            ensure_multiple_of (int, optional): constrain output width and
                height to multiples of this value. Defaults to 1.
            resize_method (str, optional):
                "lower_bound": output at least as large as the given size.
                "upper_bound": output at most as large as the given size.
                "minimal": scale as little as possible.
                Defaults to "lower_bound".
        """
        self.__width = width
        self.__height = height

        self.__resize_target = resize_target
        self.__keep_aspect_ratio = keep_aspect_ratio
        self.__multiple_of = ensure_multiple_of
        self.__resize_method = resize_method
        self.__image_interpolation_method = image_interpolation_method

    def constrain_to_multiple_of(self, x, min_val=0, max_val=None):
        """Round x to the nearest multiple, respecting optional bounds."""
        y = (np.round(x / self.__multiple_of) * self.__multiple_of).astype(int)

        # Rounding up may overshoot max_val; fall back to flooring.
        if max_val is not None and y > max_val:
            y = (np.floor(x / self.__multiple_of) * self.__multiple_of).astype(int)

        # Flooring may undershoot min_val; fall back to ceiling.
        if y < min_val:
            y = (np.ceil(x / self.__multiple_of) * self.__multiple_of).astype(int)

        return y

    def get_size(self, width, height):
        """Compute the output (width, height) for an input of the given size."""
        scale_y = self.__height / height
        scale_x = self.__width / width

        if self.__keep_aspect_ratio:
            # Collapse to a single scale according to the chosen policy.
            if self.__resize_method == "lower_bound":
                # Both dims must reach the target: take the larger scale.
                if scale_x > scale_y:
                    scale_y = scale_x
                else:
                    scale_x = scale_y
            elif self.__resize_method == "upper_bound":
                # Neither dim may exceed the target: take the smaller scale.
                if scale_x < scale_y:
                    scale_y = scale_x
                else:
                    scale_x = scale_y
            elif self.__resize_method == "minimal":
                # Scale as little as possible: pick the scale closer to 1.
                if abs(1 - scale_x) < abs(1 - scale_y):
                    scale_y = scale_x
                else:
                    scale_x = scale_y
            else:
                raise ValueError(f"resize_method {self.__resize_method} not implemented")

        if self.__resize_method == "lower_bound":
            new_height = self.constrain_to_multiple_of(scale_y * height, min_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_x * width, min_val=self.__width)
        elif self.__resize_method == "upper_bound":
            new_height = self.constrain_to_multiple_of(scale_y * height, max_val=self.__height)
            new_width = self.constrain_to_multiple_of(scale_x * width, max_val=self.__width)
        elif self.__resize_method == "minimal":
            new_height = self.constrain_to_multiple_of(scale_y * height)
            new_width = self.constrain_to_multiple_of(scale_x * width)
        else:
            raise ValueError(f"resize_method {self.__resize_method} not implemented")

        return (new_width, new_height)

    def __call__(self, sample):
        width, height = self.get_size(sample["image"].shape[1], sample["image"].shape[0])

        sample["image"] = cv2.resize(sample["image"], (width, height), interpolation=self.__image_interpolation_method)

        if self.__resize_target:
            # Targets get nearest-neighbour so labels are never interpolated.
            if "depth" in sample:
                sample["depth"] = cv2.resize(sample["depth"], (width, height), interpolation=cv2.INTER_NEAREST)

            if "mask" in sample:
                sample["mask"] = cv2.resize(sample["mask"].astype(np.float32), (width, height), interpolation=cv2.INTER_NEAREST)

        return sample


class NormalizeImage(object):
    """Normalize image by given mean and std."""

    def __init__(self, mean, std):
        self.__mean = mean
        self.__std = std

    def __call__(self, sample):
        sample["image"] = (sample["image"] - self.__mean) / self.__std
        return sample


class PrepareForNet(object):
    """Prepare sample for usage as network input."""

    def __init__(self):
        pass

    def __call__(self, sample):
        # HWC -> CHW, contiguous float32, for the image and any targets.
        image = np.transpose(sample["image"], (2, 0, 1))
        sample["image"] = np.ascontiguousarray(image).astype(np.float32)

        if "depth" in sample:
            sample["depth"] = np.ascontiguousarray(sample["depth"].astype(np.float32))

        if "mask" in sample:
            sample["mask"] = np.ascontiguousarray(sample["mask"].astype(np.float32))

        return sample
+ These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. + """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, rope=None, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True, rope=rope, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0 + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c, pos=None): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation( + c + ).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa), pos=pos) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final 
layer of DiT. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Cascade diffusion model with a transformer backbone. + """ + + def __init__( + self, + in_channels=4, + out_channels=1, + hidden_size=1024, + depth=24, + num_heads=16, + mlp_ratio=4.0, + ): + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.num_heads = num_heads + + rope_freq = 100 + self.rope = RotaryPositionEmbedding2D(frequency=rope_freq) if rope_freq > 0 else None + self.position_getter = PositionGetter() if self.rope is not None else None + + self.x_embedder = PatchEmbed(in_chans=in_channels, embed_dim=hidden_size) + self.t_embedder = TimestepEmbedder(hidden_size) + + self.blocks = nn.ModuleList( + [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio, rope=self.rope) for _ in range(depth)] + ) + + self.proj_fusion = nn.Sequential( + nn.Linear(hidden_size*2, hidden_size*4), + nn.SiLU(), + nn.Linear(hidden_size*4, hidden_size*4), + nn.SiLU(), + nn.Linear(hidden_size*4, hidden_size*4), + ) + + self.final_layer = FinalLayer(hidden_size, 8, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = 
self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x, height, width): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = 8 + h = height // p + w = width // p + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + def forward(self, x=None, semantics=None, timestep=None, dropout=0.1): + """ + Forward pass of SP-DiT. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + """ + + N, C, H, W = x.shape + if len(timestep.shape) == 0: + timestep = timestep[None] + + pos0 = None + pos1 = None + if self.rope is not None: + pos0 = self.position_getter(N, H // 16, W // 16, device=x.device) + pos1 = self.position_getter(N, H // 8, W // 8, device=x.device) + + x = self.x_embedder(x) + N, T, D = x.shape + t = self.t_embedder(timestep) # (N, D) + + # for block in self.blocks: + for i, block in enumerate(self.blocks): + if i < 12: + x = block(x, t, pos0) # (N, T, D) + else: + x = block(x, t, pos1) # (N, T, D) + + if i == 11: + + semantics = F.normalize(semantics, dim=-1) + x = self.proj_fusion(torch.cat([x, semantics], dim=-1)) + p = 16 + x = x.reshape(shape=(N, H//p, W//p, 2, 2, D)) + x = torch.einsum("nhwpqc->nchpwq", x) + x = x.reshape(shape=(N, D, (H//p)*2, (W//p)*2)) + x = x.flatten(2).transpose(1, 2) + + x = self.final_layer(x, t) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x, height=H, width=W) # (N, out_channels, H, W) + return x + diff --git a/Pixel-Perfect-Depth/ppd/models/dit_wo_rope.py b/Pixel-Perfect-Depth/ppd/models/dit_wo_rope.py new file mode 100644 index 0000000000000000000000000000000000000000..2cd8da536c50175b1e32d216c636cf507324e77d --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/dit_wo_rope.py @@ -0,0 +1,376 @@ +import math +from typing import NamedTuple +import numpy as np +import torch +import torch.nn as nn +from timm.models.vision_transformer import Attention, PatchEmbed +import torch.nn.functional as F +from timm.layers import resample_abs_pos_embed + +from .mlp import Mlp + + +class DitOutput(NamedTuple): + sample: torch.Tensor + +def build_mlp(hidden_size, projector_dim, z_dim): + return nn.Sequential( + nn.Linear(hidden_size, projector_dim), + nn.SiLU(), + nn.Linear(projector_dim, projector_dim), + nn.SiLU(), + nn.Linear(projector_dim, z_dim), + ) + + +def modulate(x, 
shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + + +################################################################################# +# Embedding Layers for Timesteps and Class Labels # +################################################################################# + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_emb = self.mlp(t_freq) + return t_emb + + +# class LabelEmbedder(nn.Module): +# """ +# Embeds class labels into vector representations. Also handles label dropout for cfg. 
+# """ + +# def __init__(self, num_classes, hidden_size, use_cfg_embedding): +# super().__init__() +# self.embedding_table = nn.Embedding(num_classes + use_cfg_embedding, hidden_size) +# self.num_classes = num_classes + +# def token_drop(self, labels, dropout_prob, force_drop_ids=None): +# """ +# Drops labels to enable classifier-free guidance. +# """ +# if force_drop_ids is None: +# drop_ids = torch.rand(labels.shape[0], device=labels.device) < dropout_prob +# else: +# drop_ids = force_drop_ids == 1 +# labels = torch.where(drop_ids, self.num_classes, labels) +# return labels + +# def forward(self, labels, dropout_prob=0.1, force_drop_ids=None): +# if dropout_prob > 0: +# labels = self.token_drop(labels, dropout_prob, force_drop_ids) +# embeddings = self.embedding_table(labels) +# return embeddings + + +################################################################################# +# Core DiT Model # +################################################################################# + + +class DiTBlock(nn.Module): + """ + A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning. 
+ """ + + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, qk_norm=True, **block_kwargs) + self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + mlp_hidden_dim = int(hidden_size * mlp_ratio) + approx_gelu = nn.GELU(approximate="tanh") + self.mlp = Mlp( + in_features=hidden_size, hidden_features=mlp_hidden_dim, act_layer=approx_gelu, drop=0 + ) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 6 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.adaLN_modulation( + c + ).chunk(6, dim=1) + x = x + gate_msa.unsqueeze(1) * self.attn(modulate(self.norm1(x), shift_msa, scale_msa)) + x = x + gate_mlp.unsqueeze(1) * self.mlp(modulate(self.norm2(x), shift_mlp, scale_mlp)) + return x + + +class FinalLayer(nn.Module): + """ + The final layer of DiT. + """ + + def __init__(self, hidden_size, patch_size, out_channels): + super().__init__() + self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6) + self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True) + self.adaLN_modulation = nn.Sequential( + nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True) + ) + + def forward(self, x, c): + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x = modulate(self.norm_final(x), shift, scale) + x = self.linear(x) + return x + + +class DiT(nn.Module): + """ + Diffusion model with a Transformer backbone. 
+ """ + + def __init__( + self, + input_size=32, + patch_size=2, + in_channels=4, + out_channels=4, + hidden_size=1152, + depth=28, + num_heads=16, + mlp_ratio=4.0, + use_cfg_embedding=True, + num_classes=1000, + learn_sigma=True, + ): + super().__init__() + self.learn_sigma = learn_sigma + self.in_channels = in_channels + self.out_channels = out_channels * 2 if learn_sigma else out_channels + self.patch_size = patch_size + self.num_heads = num_heads + self.input_size = input_size + + self.x_embedder = PatchEmbed(input_size, patch_size*2, in_channels, hidden_size, bias=True) + self.t_embedder = TimestepEmbedder(hidden_size) + # self.y_embedder = LabelEmbedder(num_classes, hidden_size, use_cfg_embedding) + num_patches = self.x_embedder.num_patches + # Will use fixed sin-cos embedding: + num_patches = (512//16) ** 2 + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches, hidden_size), requires_grad=False + ) + + self.blocks = nn.ModuleList( + [DiTBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)] + ) + # self.projector = build_mlp(hidden_size, projector_dim=2048, z_dim=1024) + # self.mlp_fusion = nn.Sequential( + # nn.Linear(hidden_size*2, hidden_size), + # nn.SiLU(), + # nn.Linear(hidden_size, hidden_size), + # ) + self.proj_fusion = nn.Sequential( + nn.Linear(hidden_size*2, hidden_size*4), + nn.SiLU(), + nn.Linear(hidden_size*4, hidden_size*4), + nn.SiLU(), + nn.Linear(hidden_size*4, hidden_size*4), + ) + + # self.proj_fusion_ = nn.Sequential( + # nn.Linear(hidden_size*2, hidden_size*4), + # nn.SiLU(), + # ) + self.final_layer = FinalLayer(hidden_size, patch_size, self.out_channels) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize (and freeze) pos_embed by sin-cos 
embedding: + pos_embed = get_2d_sincos_pos_embed( + self.pos_embed.shape[-1], + (512//16, 512//16) + ) + self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0)) + + # Initialize patch_embed like nn.Linear (instead of nn.Conv2d): + w = self.x_embedder.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.x_embedder.proj.bias, 0) + + # Initialize label embedding table: + # nn.init.normal_(self.y_embedder.embedding_table.weight, std=0.02) + + # Initialize timestep embedding MLP: + nn.init.normal_(self.t_embedder.mlp[0].weight, std=0.02) + nn.init.normal_(self.t_embedder.mlp[2].weight, std=0.02) + + # Zero-out adaLN modulation layers in DiT blocks: + for block in self.blocks: + nn.init.constant_(block.adaLN_modulation[-1].weight, 0) + nn.init.constant_(block.adaLN_modulation[-1].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.final_layer.adaLN_modulation[-1].weight, 0) + nn.init.constant_(self.final_layer.adaLN_modulation[-1].bias, 0) + nn.init.constant_(self.final_layer.linear.weight, 0) + nn.init.constant_(self.final_layer.linear.bias, 0) + + def unpatchify(self, x, height, width): + """ + x: (N, T, patch_size**2 * C) + imgs: (N, H, W, C) + """ + c = self.out_channels + p = self.x_embedder.patch_size[0] // 2 + h = height // p + w = width // p + assert h * w == x.shape[1] + + x = x.reshape(shape=(x.shape[0], h, w, p, p, c)) + x = torch.einsum("nhwpqc->nchpwq", x) + imgs = x.reshape(shape=(x.shape[0], c, h * p, w * p)) + return imgs + + def forward(self, x=None, z_latent=None, timestep=None, label=None, dropout=0.1): + """ + Forward pass of DiT. 
+ x: (N, C, H, W) tensor of spatial inputs (images or latent representations of images) + t: (N,) tensor of diffusion timesteps + y: (N,) tensor of class labels + """ + # if cfg_scale > 1.0: + # half = sample[: len(x) // 2] + # sample = torch.cat([half, half], dim=0) + N, C, H, W = x.shape + if len(timestep.shape) == 0: + timestep = timestep[None] + + x = self.x_embedder(x) + self.pos_embed # (N, T, D), where T=H*W/patch_size ** 2 + N, T, D = x.shape + timestep = self.t_embedder(timestep) # (N, D) + c = timestep # + label # (N, D) + + # for block in self.blocks: + for i, block in enumerate(self.blocks): + x = block(x, c) # (N, T, D) + if (i+1) == 12: + + z_latent = F.normalize(z_latent, dim=-1) + x = self.proj_fusion(torch.cat([x, z_latent], dim=-1)) + p = self.x_embedder.patch_size[0] + x = x.reshape(shape=(N, H//p, W//p, 2, 2, D)) + x = torch.einsum("nhwpqc->nchpwq", x) + x = x.reshape(shape=(N, D, (H//p)*2, (W//p)*2)) + x = x.flatten(2).transpose(1, 2) + + x = self.final_layer(x, c) # (N, T, patch_size ** 2 * out_channels) + x = self.unpatchify(x, height=H, width=W) # (N, out_channels, H, W) + + return x + +def get_pos_embed(pos_embed, H, W): + # ζ£€ζŸ₯当前 pos_embed ηš„ shape + if pos_embed.shape[1] != (H // 16) * (W // 16): + return resample_abs_pos_embed(pos_embed, new_size=[H // 16, W // 16], num_prefix_tokens=0) + return pos_embed + + +################################################################################# +# Sine/Cosine Positional Embedding Functions # +################################################################################# +# https://github.com/facebookresearch/mae/blob/main/util/pos_embed.py + + +def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False, extra_tokens=0): + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] + """ + + if isinstance(grid_size, int): + h, w = grid_size, grid_size + else: + h, w = grid_size + grid_h = 
np.arange(h, dtype=np.float32) + grid_w = np.arange(w, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + + grid = grid.reshape([2, 1, h, w]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid) + if cls_token and extra_tokens > 0: + pos_embed = np.concatenate([np.zeros([extra_tokens, embed_dim]), pos_embed], axis=0) + return pos_embed + + +def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) + + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) + out: (M, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + return emb diff --git a/Pixel-Perfect-Depth/ppd/models/mlp.py b/Pixel-Perfect-Depth/ppd/models/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..f73209fb401c1ec8c51bef3b05c102fc127cf87f --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/mlp.py @@ -0,0 +1,261 @@ +""" MLP module w/ dropout and configurable activation layer + +Hacked together by / Copyright 2020 Ross Wightman +""" + +from functools import partial +from timm.layers.grn import GlobalResponseNorm +from timm.layers.helpers import to_2tuple +from torch import nn as nn + + +class Mlp(nn.Module): + """MLP as used in Vision Transformer, MLP-Mixer and related 
networks""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class GluMlp(nn.Module): + """MLP w/ GLU style gating + See: https://arxiv.org/abs/1612.08083, https://arxiv.org/abs/2002.05202 + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.Sigmoid, + norm_layer=None, + bias=True, + drop=0.0, + use_conv=False, + gate_last=True, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + assert hidden_features % 2 == 0 + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + self.chunk_dim = 1 if use_conv else -1 + self.gate_last = gate_last # use second half of width for gate + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features // 2) if norm_layer is not None else nn.Identity() + self.fc2 = linear_layer(hidden_features // 2, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + 
def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + fc1_mid = self.fc1.bias.shape[0] // 2 + nn.init.ones_(self.fc1.bias[fc1_mid:]) + nn.init.normal_(self.fc1.weight[fc1_mid:], std=1e-6) + + def forward(self, x): + x = self.fc1(x) + x1, x2 = x.chunk(2, dim=self.chunk_dim) + x = x1 * self.act(x2) if self.gate_last else self.act(x1) * x2 + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +SwiGLUPacked = partial(GluMlp, act_layer=nn.SiLU, gate_last=False) + + +class SwiGLU(nn.Module): + """SwiGLU + NOTE: GluMLP above can implement SwiGLU, but this impl has split fc1 and + better matches some other common impl which makes mapping checkpoints simpler. + """ + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.SiLU, + norm_layer=None, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1_g = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.fc1_x = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def init_weights(self): + # override init of fc1 w/ gate portion set to weight near zero, bias=1 + nn.init.ones_(self.fc1_g.bias) + nn.init.normal_(self.fc1_g.weight, std=1e-6) + + def forward(self, x): + x_gate = self.fc1_g(x) + x = self.fc1_x(x) + x = self.act(x_gate) * x + x = self.drop1(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class GatedMlp(nn.Module): + """MLP as used in gMLP""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + 
act_layer=nn.GELU, + norm_layer=None, + gate_layer=None, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + if gate_layer is not None: + assert hidden_features % 2 == 0 + self.gate = gate_layer(hidden_features) + hidden_features = hidden_features // 2 # FIXME base reduction on gate property? + else: + self.gate = nn.Identity() + self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.gate(x) + x = self.norm(x) + x = self.fc2(x) + x = self.drop2(x) + return x + + +class ConvMlp(nn.Module): + """MLP using 1x1 convs that keeps spatial dims""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.ReLU, + norm_layer=None, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + + self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1, bias=bias[0]) + self.norm = norm_layer(hidden_features) if norm_layer else nn.Identity() + self.act = act_layer() + self.drop = nn.Dropout(drop) + self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.norm(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + return x + + +class GlobalResponseNormMlp(nn.Module): + """MLP w/ Global Response Norm (see grn.py), nn.Linear or 1x1 Conv2d""" + + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + 
act_layer=nn.GELU, + bias=True, + drop=0.0, + use_conv=False, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + drop_probs = to_2tuple(drop) + linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop_probs[0]) + self.grn = GlobalResponseNorm(hidden_features, channels_last=not use_conv) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + self.drop2 = nn.Dropout(drop_probs[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.grn(x) + x = self.fc2(x) + x = self.drop2(x) + return x diff --git a/Pixel-Perfect-Depth/ppd/models/patch_embed.py b/Pixel-Perfect-Depth/ppd/models/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..f26830f45bcf2a0626874cb99dd3a729e2af22d5 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/patch_embed.py @@ -0,0 +1,86 @@ +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/Pixel-Perfect-Depth/ppd/models/ppd.py b/Pixel-Perfect-Depth/ppd/models/ppd.py new file mode 100644 index 0000000000000000000000000000000000000000..79da9e848da71d24c15f653f6c12ab85eb29355c --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/ppd.py @@ -0,0 +1,102 @@ +from PIL import Image +import numpy as np +import os +import torch +import torch.nn as nn +import 
torch.nn.functional as F +import cv2 +import random +from huggingface_hub import hf_hub_download +from ppd.utils.timesteps import Timesteps +from ppd.utils.schedule import LinearSchedule +from ppd.utils.sampler import EulerSampler +from ppd.utils.transform import image2tensor, resize_1024, resize_1024_crop, resize_keep_aspect + +from ppd.models.depth_anything_v2.dpt import DepthAnythingV2 +from ppd.models.dit import DiT + +class PixelPerfectDepth(nn.Module): + def __init__( + self, + semantics_pth='checkpoints/depth_anything_v2_vitl.pth', + sampling_steps=10, + + ): + super(PixelPerfectDepth, self).__init__() + + DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = DEVICE + + self.semantics_encoder = DepthAnythingV2( + encoder='vitl', + features=256, + out_channels=[256, 512, 1024, 1024] + ) + semantics_pth = hf_hub_download( + repo_id="depth-anything/Depth-Anything-V2-Large", + filename="depth_anything_v2_vitl.pth", + repo_type="model") + self.semantics_encoder.load_state_dict(torch.load(semantics_pth, map_location='cpu'), strict=False) + self.semantics_encoder = self.semantics_encoder.to(self.device).eval() + self.dit = DiT() + + self.sampling_steps = sampling_steps + + self.schedule = LinearSchedule(T=1000) + self.sampling_timesteps = Timesteps( + T=self.schedule.T, + steps=self.sampling_steps, + device=self.device, + ) + self.sampler = EulerSampler( + schedule=self.schedule, + timesteps=self.sampling_timesteps, + prediction_type='velocity' + ) + + @torch.no_grad() + def infer_image(self, image, sampling_steps=None, use_fp16: bool = True): + h, w = image.shape[:2] + resize_image = resize_keep_aspect(image) + image = image2tensor(resize_image) + image = image.to(self.device) + + if sampling_steps is not None and sampling_steps != self.sampling_steps: + self.sampling_steps = sampling_steps + self.sampling_timesteps = Timesteps( + T=self.schedule.T, + steps=self.sampling_steps, + device=self.device, + ) + self.sampler = 
EulerSampler( + schedule=self.schedule, + timesteps=self.sampling_timesteps, + prediction_type='velocity' + ) + + with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=True): + depth = self.forward_test(image) + # depth = F.interpolate(depth, size=(h, w), mode='bilinear', align_corners=False)[0, 0] + + return depth.squeeze().cpu().numpy(), resize_image + + @torch.no_grad() + def forward_test(self, image): + + semantics = self.semantics_prompt(image) + cond = image - 0.5 + latent = torch.randn(size=[cond.shape[0], 1, cond.shape[2], cond.shape[3]]).to(self.device) + + for timestep in self.sampling_timesteps: + input = torch.cat([latent, cond], dim=1) + pred = self.dit(x=input, semantics=semantics, timestep=timestep) + latent = self.sampler.step(pred=pred, x_t=latent, t=timestep) + + return latent + 0.5 + + + @torch.no_grad() + def semantics_prompt(self, image): + with torch.no_grad(): + semantics = self.semantics_encoder(image) + return semantics diff --git a/Pixel-Perfect-Depth/ppd/models/rope.py b/Pixel-Perfect-Depth/ppd/models/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..75398521f55711ceb741f8d5a669606a2bea1c70 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/models/rope.py @@ -0,0 +1,186 @@ +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +# Implementation of 2D Rotary Position Embeddings (RoPE). + +# This module provides a clean implementation of 2D Rotary Position Embeddings, +# which extends the original RoPE concept to handle 2D spatial positions. + +# Inspired by: +# https://github.com/meta-llama/codellama/blob/main/llama/model.py +# https://github.com/naver-ai/rope-vit + + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from typing import Dict, Tuple + + +class PositionGetter: + """Generates and caches 2D spatial positions for patches in a grid. 
+ + This class efficiently manages the generation of spatial coordinates for patches + in a 2D grid, caching results to avoid redundant computations. + + Attributes: + position_cache: Dictionary storing precomputed position tensors for different + grid dimensions. + """ + + def __init__(self): + """Initializes the position generator with an empty cache.""" + self.position_cache: Dict[Tuple[int, int], torch.Tensor] = {} + + def __call__(self, batch_size: int, height: int, width: int, device: torch.device) -> torch.Tensor: + """Generates spatial positions for a batch of patches. + + Args: + batch_size: Number of samples in the batch. + height: Height of the grid in patches. + width: Width of the grid in patches. + device: Target device for the position tensor. + + Returns: + Tensor of shape (batch_size, height*width, 2) containing y,x coordinates + for each position in the grid, repeated for each batch item. + """ + if (height, width) not in self.position_cache: + y_coords = torch.arange(height, device=device) + x_coords = torch.arange(width, device=device) + positions = torch.cartesian_prod(y_coords, x_coords) + self.position_cache[height, width] = positions + + cached_positions = self.position_cache[height, width] + return cached_positions.view(1, height * width, 2).expand(batch_size, -1, -1).clone() + + +class RotaryPositionEmbedding2D(nn.Module): + """2D Rotary Position Embedding implementation. + + This module applies rotary position embeddings to input tokens based on their + 2D spatial positions. It handles the position-dependent rotation of features + separately for vertical and horizontal dimensions. + + Args: + frequency: Base frequency for the position embeddings. Default: 100.0 + scaling_factor: Scaling factor for frequency computation. Default: 1.0 + + Attributes: + base_frequency: Base frequency for computing position embeddings. + scaling_factor: Factor to scale the computed frequencies. 
+ frequency_cache: Cache for storing precomputed frequency components. + """ + + def __init__(self, frequency: float = 100.0, scaling_factor: float = 1.0): + """Initializes the 2D RoPE module.""" + super().__init__() + self.base_frequency = frequency + self.scaling_factor = scaling_factor + self.frequency_cache: Dict[Tuple, Tuple[torch.Tensor, torch.Tensor]] = {} + + def _compute_frequency_components( + self, dim: int, seq_len: int, device: torch.device, dtype: torch.dtype + ) -> Tuple[torch.Tensor, torch.Tensor]: + """Computes frequency components for rotary embeddings. + + Args: + dim: Feature dimension (must be even). + seq_len: Maximum sequence length. + device: Target device for computations. + dtype: Data type for the computed tensors. + + Returns: + Tuple of (cosine, sine) tensors for frequency components. + """ + cache_key = (dim, seq_len, device, dtype) + if cache_key not in self.frequency_cache: + # Compute frequency bands + exponents = torch.arange(0, dim, 2, device=device).float() / dim + inv_freq = 1.0 / (self.base_frequency**exponents) + + # Generate position-dependent frequencies + positions = torch.arange(seq_len, device=device, dtype=inv_freq.dtype) + angles = torch.einsum("i,j->ij", positions, inv_freq) + + # Compute and cache frequency components + angles = angles.to(dtype) + angles = torch.cat((angles, angles), dim=-1) + cos_components = angles.cos().to(dtype) + sin_components = angles.sin().to(dtype) + self.frequency_cache[cache_key] = (cos_components, sin_components) + + return self.frequency_cache[cache_key] + + @staticmethod + def _rotate_features(x: torch.Tensor) -> torch.Tensor: + """Performs feature rotation by splitting and recombining feature dimensions. + + Args: + x: Input tensor to rotate. + + Returns: + Rotated feature tensor. 
+ """ + feature_dim = x.shape[-1] + x1, x2 = x[..., : feature_dim // 2], x[..., feature_dim // 2 :] + return torch.cat((-x2, x1), dim=-1) + + def _apply_1d_rope( + self, tokens: torch.Tensor, positions: torch.Tensor, cos_comp: torch.Tensor, sin_comp: torch.Tensor + ) -> torch.Tensor: + """Applies 1D rotary position embeddings along one dimension. + + Args: + tokens: Input token features. + positions: Position indices. + cos_comp: Cosine components for rotation. + sin_comp: Sine components for rotation. + + Returns: + Tokens with applied rotary position embeddings. + """ + # Embed positions with frequency components + cos = F.embedding(positions, cos_comp)[:, None, :, :] + sin = F.embedding(positions, sin_comp)[:, None, :, :] + + # Apply rotation + return (tokens * cos) + (self._rotate_features(tokens) * sin) + + def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor: + """Applies 2D rotary position embeddings to input tokens. + + Args: + tokens: Input tensor of shape (batch_size, n_heads, n_tokens, dim). + The feature dimension (dim) must be divisible by 4. + positions: Position tensor of shape (batch_size, n_tokens, 2) containing + the y and x coordinates for each token. + + Returns: + Tensor of same shape as input with applied 2D rotary position embeddings. + + Raises: + AssertionError: If input dimensions are invalid or positions are malformed. 
+ """ + # Validate inputs + assert tokens.size(-1) % 2 == 0, "Feature dimension must be even" + assert positions.ndim == 3 and positions.shape[-1] == 2, "Positions must have shape (batch_size, n_tokens, 2)" + + # Compute feature dimension for each spatial direction + feature_dim = tokens.size(-1) // 2 + + # Get frequency components + max_position = int(positions.max()) + 1 + cos_comp, sin_comp = self._compute_frequency_components(feature_dim, max_position, tokens.device, tokens.dtype) + + # Split features for vertical and horizontal processing + vertical_features, horizontal_features = tokens.chunk(2, dim=-1) + + # Apply RoPE separately for each dimension + vertical_features = self._apply_1d_rope(vertical_features, positions[..., 0], cos_comp, sin_comp) + horizontal_features = self._apply_1d_rope(horizontal_features, positions[..., 1], cos_comp, sin_comp) + + # Combine processed features + return torch.cat((vertical_features, horizontal_features), dim=-1) diff --git a/Pixel-Perfect-Depth/ppd/utils/align_depth_func.py b/Pixel-Perfect-Depth/ppd/utils/align_depth_func.py new file mode 100644 index 0000000000000000000000000000000000000000..2a2ce5af914c62dea4adefa6e58d2914aead3127 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/utils/align_depth_func.py @@ -0,0 +1,40 @@ +import torch +import numpy as np +import cv2 +from sklearn.linear_model import RANSACRegressor +from sklearn.preprocessing import PolynomialFeatures +from sklearn.pipeline import make_pipeline + +degree = 1 +poly_features = PolynomialFeatures(degree=degree, include_bias=False) +ransac = RANSACRegressor(max_trials=1000) +model = make_pipeline(poly_features, ransac) + +def recover_metric_depth_ransac(pred, gt, mask): + pred = pred.astype(np.float32) + gt = gt.astype(np.float32) + + mask_gt = gt[mask].astype(np.float32) + mask_pred = pred[mask].astype(np.float32) + + ## depth -> log depth + mask_gt = np.log(mask_gt + 1.) 
+ + try: + model.fit(mask_pred[:, None], mask_gt[:, None]) + a, b = model.named_steps['ransacregressor'].estimator_.coef_, model.named_steps['ransacregressor'].estimator_.intercept_ + a = a.item() + b = b.item() + except Exception: + a, b = 1, 0 + + if a > 0: + pred_metric = a * pred + b + else: + pred_mean = np.mean(mask_pred) + gt_mean = np.mean(mask_gt) + pred_metric = pred * (gt_mean / pred_mean) + + ## log depth -> depth + pred_metric = np.exp(pred_metric) - 1. + return pred_metric \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/utils/depth2pcd.py b/Pixel-Perfect-Depth/ppd/utils/depth2pcd.py new file mode 100644 index 0000000000000000000000000000000000000000..f601893b80e6bac26e11dc1cf3ff700f58880a12 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/utils/depth2pcd.py @@ -0,0 +1,55 @@ +import numpy as np +import open3d as o3d + +def depth2pcd(depth, intrinsic, color=None, input_mask=None, ret_pcd=False): + """ + Convert a depth map into a 3D point cloud. + + Args: + depth (np.ndarray): (H, W) depth map in meters. + intrinsic (np.ndarray): (3, 3) camera intrinsic matrix. + color (np.ndarray, optional): (H, W, 3) RGB image aligned with the depth map. + input_mask (np.ndarray, optional): (H, W) boolean mask indicating valid pixels. + ret_pcd (bool, optional): If True, returns an Open3D PointCloud object; + otherwise returns NumPy arrays. + + Returns: + - If ret_pcd=True: returns `o3d.geometry.PointCloud()` + - Otherwise: returns (N, 3) point coordinates and (N, 3) color arrays.
+ """ + H, W = depth.shape + x, y = np.meshgrid(np.arange(W), np.arange(H)) + xx, yy = x.reshape(-1), y.reshape(-1) + zz = depth.reshape(-1) + + # Create a valid pixel mask + mask = np.ones_like(zz, dtype=bool) + if input_mask is not None: + mask &= input_mask.reshape(-1) + + # Form homogeneous pixel coordinates + pixels = np.stack([xx, yy, np.ones_like(xx)], axis=1) + + # Back-project pixels into 3D camera coordinates + points = pixels * zz[:, None] + points = np.dot(points, np.linalg.inv(intrinsic).T) + + # Keep only valid points + points = points[mask] + + # Process color information + if color is not None: + color = color.astype(np.float32) / 255.0 + colors = color.reshape(-1, 3)[mask] + else: + colors = None + + # Return Open3D point cloud or NumPy arrays + if ret_pcd: + pcd = o3d.geometry.PointCloud() + pcd.points = o3d.utility.Vector3dVector(points) + if colors is not None: + pcd.colors = o3d.utility.Vector3dVector(colors) + return pcd + else: + return points, colors \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/utils/sampler.py b/Pixel-Perfect-Depth/ppd/utils/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..1c8fe3e76092137fb2557c1908ce9660943bab2d --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/utils/sampler.py @@ -0,0 +1,73 @@ +import torch +from enum import Enum +from ppd.utils.timesteps import Timesteps +from ppd.utils.schedule import LinearSchedule + + +class EulerSampler: + """ + The Euler method is the simplest ODE solver. + """ + + def __init__( + self, + schedule: LinearSchedule, + timesteps: Timesteps, + prediction_type: 'velocity', + ): + self.schedule = schedule + self.timesteps = timesteps + self.prediction_type = prediction_type + + + def step( + self, + pred: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """ + Step to the next timestep. 
+ """ + return self.step_to(pred, x_t, t, self.get_next_timestep(t), **kwargs) + + def step_to( + self, + pred: torch.Tensor, + x_t: torch.Tensor, + t: torch.Tensor, + s: torch.Tensor, + **kwargs, + ) -> torch.Tensor: + """ + Steps from x_t at timestep t to x_s at timestep s. Returns x_s. + """ + t = t[(...,) + (None,) * (x_t.ndim - t.ndim)] if t.ndim < x_t.ndim else t + s = s[(...,) + (None,) * (x_t.ndim - s.ndim)] if s.ndim < x_t.ndim else s + T = self.schedule.T + # Step from x_t to x_s. + pred_x_0, pred_x_T = self.schedule.convert_from_pred(pred, self.prediction_type, x_t, t) + pred_x_s = self.schedule.forward(pred_x_0, pred_x_T, s.clamp(0, T)) + # Clamp x_s to x_0 and x_T if s is out of bound. + pred_x_s = pred_x_s.where(s >= 0, pred_x_0) + pred_x_s = pred_x_s.where(s <= T, pred_x_T) + return pred_x_s + + def get_next_timestep( + self, + t: torch.Tensor, + ) -> torch.Tensor: + """ + Get the next sample timestep. + Support multiple different timesteps t in a batch. + If no more steps, return out of bound value -1 or T+1. + """ + T = self.timesteps.T + steps = len(self.timesteps) + curr_idx = self.timesteps.index(t) + next_idx = curr_idx + 1 + + s = self.timesteps[next_idx.clamp_max(steps - 1)] + s = s.where(next_idx < steps, -1) + return s \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/utils/schedule.py b/Pixel-Perfect-Depth/ppd/utils/schedule.py new file mode 100644 index 0000000000000000000000000000000000000000..9c2833cc91af17b69fc95d3fc5b741bb8a0157ce --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/utils/schedule.py @@ -0,0 +1,54 @@ +""" +Linear interpolation schedule (lerp). +""" + +from typing import Tuple, Union +import torch +from enum import Enum + + +class LinearSchedule: + """ + Linear interpolation schedule (lerp) is proposed by flow matching and rectified flow. + It leads to straighter probability flow theoretically. It is also used by Stable Diffusion 3. 
+ + x_t = (1 - t) * x_0 + t * x_T + + """ + + def __init__(self, T: Union[int, float] = 1.0): + self.T = T + + def forward(self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor) -> torch.Tensor: + """ + Diffusion forward function. + """ + t = t[(...,) + (None,) * (x_0.ndim - t.ndim)] if t.ndim < x_0.ndim else t + return (1 - t / self.T) * x_0 + (t / self.T) * x_T + + def convert_from_pred( + self, pred: torch.Tensor, pred_type: str, x_t: torch.Tensor, t: torch.Tensor + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Convert from velocity prediction. Return predicted x_0 and x_T. + """ + t = t[(...,) + (None,) * (x_t.ndim - t.ndim)] if t.ndim < x_t.ndim else t + A_t = 1 - t / self.T + B_t = t / self.T + + # pred_type = 'velocity' + pred_x_0 = x_t - B_t * pred + pred_x_T = x_t + A_t * pred + + return pred_x_0, pred_x_T + + def convert_to_pred( + self, x_0: torch.Tensor, x_T: torch.Tensor, t: torch.Tensor, pred_type: str + ) -> torch.FloatTensor: + """ + Convert to velocity prediction target given x_0 and x_T. + Predict velocity dx/dt based on the lerp schedule (x_T - x_0).
+ Proposed by rectified flow (https://arxiv.org/abs/2209.03003) + """ + # pred_type = 'velocity' + return x_T - x_0 diff --git a/Pixel-Perfect-Depth/ppd/utils/set_seed.py b/Pixel-Perfect-Depth/ppd/utils/set_seed.py new file mode 100644 index 0000000000000000000000000000000000000000..c03635e993187e3da2603e9e0adf90aa920ce598 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/utils/set_seed.py @@ -0,0 +1,13 @@ +import random +import numpy as np +import torch + +def set_seed(seed=666): + import random, numpy as np, torch + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed_all(seed) + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/utils/timesteps.py b/Pixel-Perfect-Depth/ppd/utils/timesteps.py new file mode 100644 index 0000000000000000000000000000000000000000..79aaa2fd9b12ca5380f3d508390cc51e2c5027c9 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/utils/timesteps.py @@ -0,0 +1,39 @@ +from typing import Union +import torch + + +class Timesteps: + """ + Sampling timesteps. + It defines the discretization of sampling steps. + """ + + def __init__( + self, + T: int, + steps: int, + device: torch.device = "cpu", + ): + self.T = T + timesteps = torch.arange(T, -1, -(T + 1) / steps, device=device).round().int() + self.timesteps = timesteps + + def __len__(self) -> int: + """ + Number of sampling steps. + """ + return len(self.timesteps) + + def __getitem__(self, idx: Union[int, torch.IntTensor]) -> torch.Tensor: + return self.timesteps[idx] + + def index(self, t: torch.Tensor) -> torch.Tensor: + """ + Find index by t. + Return index of the same shape as t. + Index is -1 if t not found in timesteps. 
+ """ + i, j = t.reshape(-1, 1).eq(self.timesteps).nonzero(as_tuple=True) + idx = torch.full_like(t, fill_value=-1, dtype=torch.int) + idx.view(-1)[i] = j.int() + return idx \ No newline at end of file diff --git a/Pixel-Perfect-Depth/ppd/utils/transform.py b/Pixel-Perfect-Depth/ppd/utils/transform.py new file mode 100644 index 0000000000000000000000000000000000000000..76dba3f00d14f4d3831b77692cf0d235e44ff6d1 --- /dev/null +++ b/Pixel-Perfect-Depth/ppd/utils/transform.py @@ -0,0 +1,68 @@ +import cv2 +import numpy as np +import torch +import torch.nn.functional as F + + + +def image2tensor(image): + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image = np.asarray(image / 255.).astype(np.float32) + image = np.transpose(image, (2, 0, 1)) + image = np.ascontiguousarray(image).astype(np.float32) + image = torch.from_numpy(image).unsqueeze(0) + + return image + +def resize_1024(image): + image = cv2.resize(image, (1024, 768), interpolation=cv2.INTER_LINEAR) + return image + +def resize_1024_crop(image): + ori_h, ori_w = image.shape[:2] + tar_w, tar_h = 1024, 768 + if ori_h > ori_w: + resize_h = int(tar_w / ori_w * ori_h) + image = cv2.resize(image, (tar_w, resize_h), interpolation=cv2.INTER_LINEAR) + if resize_h > tar_h: + top = (resize_h - tar_h) // 2 + image = image[top:top+tar_h, :] + else: + image = cv2.resize(image, (tar_w, tar_h), interpolation=cv2.INTER_LINEAR) + + else: + resize_w = int(tar_h / ori_h * ori_w) + image = cv2.resize(image, (resize_w, tar_h), interpolation=cv2.INTER_LINEAR) + + if resize_w > tar_w: + left = (resize_w - tar_w) // 2 + image = image[:, left:left+tar_w] + else: + image = cv2.resize(image, (tar_w, tar_h), interpolation=cv2.INTER_LINEAR) + + return image + +def resize_keep_aspect(image): + ori_h, ori_w = image.shape[:2] + tar_w, tar_h = 1024, 768 + ori_area = ori_h * ori_w + tar_area = tar_h * tar_w + scale = scale = (tar_area / ori_area) ** 0.5 + resize_h = ori_h * scale + resize_w = ori_w * scale + resize_h = max(16, 
int(round(resize_h / 16)) * 16) + resize_w = max(16, int(round(resize_w / 16)) * 16) + if scale < 1: + image = cv2.resize(image, (resize_w, resize_h), interpolation=cv2.INTER_AREA) + else: + image = cv2.resize(image, (resize_w, resize_h), interpolation=cv2.INTER_CUBIC) + return image + + + + + + + + + diff --git a/README.md b/README.md index 22763d1e20e169367e68fe242b0fe316855557ca..57077ae047ec0bd7678b0766cd0d477595fbee7a 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,5 @@ --- -title: Depth Anything Compare Demo +title: Depth Estimation Compare Demo emoji: πŸ‘€ colorFrom: indigo colorTo: indigo @@ -9,185 +9,99 @@ app_file: app.py pinned: false --- -# Depth Anything v1 vs v2 Comparison Demo +# Depth Estimation Comparison Demo -A comprehensive comparison tool for **Depth Anything v1** and **Depth Anything v2** models, built with Gradio and optimized for HuggingFace Spaces with ZeroGPU support. +A ZeroGPU-friendly Gradio interface for comparing **Depth Anything v1**, **Depth Anything v2**, and **Pixel-Perfect Depth (PPD)** on the same image. Switch between side-by-side layouts, a slider overlay, or single-model inspection to understand how different pipelines perceive scene geometry. -## πŸš€ Features +## πŸš€ Highlights +- **Three interactive views**: draggable slider, labeled side-by-side comparison, and original vs depth for any single model. +- **Multi-family depth models**: run ViT variants from Depth Anything v1/v2 alongside Pixel-Perfect Depth with MoGe metric alignment. +- **ZeroGPU aware**: on-demand loading, model cache clearing, and torch CUDA cleanup keep GPU usage inside HuggingFace Spaces limits. +- **Curated examples**: reusable demo images sourced from each model family plus local assets to quickly validate behaviour. -### Three Comparison Modes +## πŸ” Supported Pipelines +- **Depth Anything v1** (`LiheYoung/depth_anything_*`): ViT-S/B/L with fast transformer backbones and colorized outputs via `Spectral_r` colormap. 
+- **Depth Anything v2** (`Depth-Anything-V2/checkpoints/*.pth`): ViT-Small/Base/Large with HF Hub fallback, configurable feature channels, and improved edge handling. +- **Pixel-Perfect Depth**: Diffusion-based relative depth refined by the **MoGe** metric surface model and RANSAC alignment to recover metric depth; customizable denoising steps. -1. **🎚️ Slider Comparison**: Interactive side-by-side comparison with a draggable slider -2. **πŸ” Method Comparison**: Traditional side-by-side view with model labels -3. **πŸ”¬ Single Model**: Run individual models for detailed analysis - -### Supported Models - -#### Depth Anything v1 -- **ViT-S (Small)**: Fastest inference, good quality -- **ViT-B (Base)**: Balanced speed and quality -- **ViT-L (Large)**: Best quality, slower inference - -#### Depth Anything v2 -- **ViT-Small**: Enhanced small model with improved accuracy -- **ViT-Base**: Balanced performance with v2 improvements -- **ViT-Large**: State-of-the-art depth estimation quality - -## πŸ–ΌοΈ Example Images - -The demo includes 20+ carefully selected example images showcasing various scenarios: -- Indoor and outdoor scenes -- Different lighting conditions -- Various object types and compositions -- Challenging depth estimation scenarios - -## πŸ› οΈ Technical Details - -### Architecture -- **Framework**: Gradio 4.0+ with modern UI components -- **Backend**: PyTorch with CUDA acceleration -- **Deployment**: ZeroGPU-optimized for HuggingFace Spaces -- **Memory Management**: Automatic model loading/unloading for efficient GPU usage - -### ZeroGPU Optimizations -- `@spaces.GPU` decorators for GPU-intensive functions -- Automatic memory cleanup between inferences -- On-demand model loading to prevent OOM errors -- Efficient resource allocation and deallocation - -### Depth Visualization -- **Colormap**: Spectral_r colormap for intuitive depth representation -- **Normalization**: Min-max scaling for consistent visualization -- **Resolution**: Maintains original 
image aspect ratios +## πŸ–₯️ App Experience +- **Slider Comparison**: drag between two predictions with automatically labeled overlays. +- **Method Comparison**: view models side-by-side with synchronized layout and captions rendered in OpenCV. +- **Single Model**: inspect the RGB input versus one model output using the Gradio `ImageSlider` component. +- **Example Gallery**: natural-number sorting across `assets/examples`, `Depth-Anything/assets/examples`, `Depth-Anything-V2/assets/examples`, and `Pixel-Perfect-Depth/assets/examples`. ## πŸ“¦ Installation & Setup ### Local Development - -1. **Clone the repository**: -```bash -git clone -cd Depth-Anything-Compare-demo -``` - -2. **Install dependencies**: -```bash -pip install -r requirements.txt -``` - -3. **Download model checkpoints** (for local usage): -```bash -# Depth Anything v1 models are downloaded automatically from HuggingFace Hub -# For v2 models, download checkpoints to Depth-Anything-V2/checkpoints/ -``` - -4. **Run locally**: -```bash -python app_local.py # For local development -python app.py # For ZeroGPU deployment -``` - -### HuggingFace Spaces Deployment - -This app is optimized for HuggingFace Spaces with ZeroGPU. Simply: - -1. Upload the repository to your HuggingFace Space -2. Set hardware to "ZeroGPU" -3. The app will automatically handle GPU allocation and model loading +1. **Clone & enter**: + ```bash + git clone + cd Depth-Estimation-Compare-demo + ``` +2. **Install dependencies** (includes `gradio`, `torch`, `gradio_imageslider`, `open3d`, `scikit-learn`, and MoGe utilities): + ```bash + pip install -r requirements.txt + ``` +3. **Model assets**: + - Depth Anything v1 checkpoints stream automatically from the HuggingFace Hub. + - Download Depth Anything v2 weights into `Depth-Anything-V2/checkpoints/` if they are not already present (`depth_anything_v2_vits.pth`, `depth_anything_v2_vitb.pth`, `depth_anything_v2_vitl.pth`). 
+ - Pixel-Perfect Depth pulls the diffusion checkpoint (`ppd.pth`) from `gangweix/Pixel-Perfect-Depth` on first use and loads MoGe weights (`Ruicheng/moge-2-vitl-normal`). +4. **Run the app**: + ```bash + python app_local.py # Local UI with live reload tweaks + python app.py # ZeroGPU-ready launch script + ``` + +### HuggingFace Spaces (ZeroGPU) +1. Push the repository contents to a Gradio Space. +2. Select the **ZeroGPU** hardware preset. +3. The app will download required checkpoints on demand and aggressively free memory after each inference via `clear_model_cache()`. ## πŸ“ Project Structure - ``` -Depth-Anything-Compare-demo/ -β”œβ”€β”€ app.py # ZeroGPU-optimized main application -β”œβ”€β”€ app_local.py # Local development version -β”œβ”€β”€ requirements.txt # Python dependencies -β”œβ”€β”€ README.md # This file +Depth-Estimation-Compare-demo/ +β”œβ”€β”€ app.py # ZeroGPU deployment entrypoint +β”œβ”€β”€ app_local.py # Local-friendly launch script +β”œβ”€β”€ requirements.txt # Python dependencies (Gradio, Torch, PPD stack) β”œβ”€β”€ assets/ -β”‚ └── examples/ # Example images for testing -β”œβ”€β”€ Depth-Anything/ # Depth Anything v1 implementation -β”‚ β”œβ”€β”€ depth_anything/ -β”‚ β”‚ β”œβ”€β”€ dpt.py # v1 model architecture -β”‚ β”‚ └── util/ # v1 utilities and transforms -β”‚ └── torchhub/ # Required dependencies -└── Depth-Anything-V2/ # Depth Anything v2 implementation - β”œβ”€β”€ depth_anything_v2/ - β”‚ β”œβ”€β”€ dpt.py # v2 model architecture - β”‚ └── dinov2_layers/ # DINOv2 components - └── assets/ - └── examples/ # v2-specific examples +β”‚ └── examples/ # Shared demo imagery +β”œβ”€β”€ Depth-Anything/ # Depth Anything v1 implementation + utilities +β”œβ”€β”€ Depth-Anything-V2/ # Depth Anything v2 implementation & checkpoints +β”œβ”€β”€ Pixel-Perfect-Depth/ # Pixel-Perfect Depth diffusion + MoGe helpers +└── README.md # You are here ``` -## πŸ”§ Configuration - -### Model Configuration -Models are configured in the respective config dictionaries: 
-- `V1_MODEL_CONFIGS`: HuggingFace Hub model identifiers -- `V2_MODEL_CONFIGS`: Local checkpoint paths and architecture parameters - -### Environment Variables -- `DEVICE`: Automatically detects CUDA availability -- GPU memory is managed automatically by ZeroGPU - -## πŸ“Š Performance - -### Inference Times (Approximate) -- **ViT-S models**: ~1-2 seconds -- **ViT-B models**: ~2-4 seconds -- **ViT-L models**: ~4-8 seconds - -*Times vary based on image resolution and GPU availability* +## βš™οΈ Configuration Notes +- Model dropdown labels come from `V1_MODEL_CONFIGS`, `V2_MODEL_CONFIGS`, and the PPD entry in `app.py`. +- `clear_model_cache()` resets every model and flushes CUDA to respect ZeroGPU constraints. +- Pixel-Perfect Depth inference aligns relative depth to metric scale through `recover_metric_depth_ransac()` for consistent visualization. +- Depth visualizations use a normalized `Spectral_r` colormap; PPD uses a dedicated matplotlib colormap for metric maps. -### Memory Usage -- Optimized for ZeroGPU's memory constraints -- Automatic model unloading prevents OOM errors -- Efficient batch processing for multiple comparisons +## πŸ“Š Performance Expectations +- **Depth Anything v1**: ViT-S ~1–2 s, ViT-B ~2–4 s, ViT-L ~4–8 s (image dependent). +- **Depth Anything v2**: similar to v1 with improved sharpness; HF downloads add one-time startup overhead. +- **Pixel-Perfect Depth**: diffusion + metric refinement typically takes longer (10–20 denoise steps) but returns metrically-aligned depth suitable for downstream 3D tasks. -## 🎯 Usage Examples - -### Compare v1 vs v2 Models -1. Upload an image or select from examples -2. Choose models from both v1 and v2 families -3. Click "Compare" or "Slider Compare" -4. Analyze the depth estimation differences - -### Analyze Single Model Performance -1. Select "Single Model" tab -2. Choose any available model -3. Upload image and click "Run" -4. 
Examine detailed depth map output +## 🎯 Usage Tips +- Mix-and-match any two models in comparison tabs to highlight qualitative differences. +- Use the Single Model tab to corroborate PPD metric depth versus RGB input. +- Leverage the provided examples to benchmark indoor/outdoor, lighting extremes, and complex geometry scenarios before running custom images. ## 🀝 Contributing - -Contributions are welcome! Areas for improvement: -- Additional model variants -- New visualization options -- Performance optimizations -- UI/UX enhancements +Enhancements are welcomeβ€”new model backends, visualization modes, or memory optimizations are especially valuable for ZeroGPU deployments. Please follow the coding style in `app.py` and keep documentation in sync with new capabilities. ## πŸ“š References - -- **Depth Anything v1**: [LiheYoung/Depth-Anything](https://github.com/LiheYoung/Depth-Anything) -- **Depth Anything v2**: [DepthAnything/Depth-Anything-V2](https://github.com/DepthAnything/Depth-Anything-V2) -- **Original Papers**: - - [Depth Anything: Unleashing the Power of Large-Scale Unlabeled Data](https://arxiv.org/abs/2401.10891) - - [Depth Anything V2: More Efficient, Better Supervised](https://arxiv.org/abs/2406.09414) +- [Depth Anything v1](https://github.com/LiheYoung/Depth-Anything) +- [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2) +- [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth) +- [MoGe](https://huggingface.co/Ruicheng/moge-2-vitl-normal) ## πŸ“„ License - -This project combines implementations from: - Depth Anything v1: MIT License - Depth Anything v2: Apache 2.0 License -- Demo code: MIT License - -Please check individual component licenses for specific terms. 
- -## πŸ™ Acknowledgments - -- Original Depth Anything authors and contributors -- HuggingFace team for Spaces and ZeroGPU infrastructure -- Gradio team for the excellent UI framework +- Pixel-Perfect Depth: see upstream repository for licensing +- Demo scaffolding in this repo: MIT License (follow individual component terms) --- -**Note**: This is a demonstration/comparison tool. For production use of the Depth Anything models, please refer to the original repositories and follow their recommended practices. +Built as a hands-on playground for exploring modern monocular depth estimators. Adjust tabs, compare outputs, and plug results into your 3D workflows. diff --git a/app.py b/app.py index 8fa9d131f4f44ecd8ed0fc594d8397e75a60357d..aa74506cc1eca4b4f676d41ec70412f0f92e8464 100644 --- a/app.py +++ b/app.py @@ -1,7 +1,7 @@ """ -Depth Anything Comparison Demo (v1 vs v2) - ZeroGPU Version +Depth Estimation Comparison Demo (ZeroGPU) -Compare different Depth Anything models (v1 and v2) side-by-side or with a slider using Gradio. +Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider using Gradio. Optimized for HuggingFace Spaces with ZeroGPU support. 
""" @@ -9,19 +9,17 @@ import os import sys import logging import gc -import tempfile -from pathlib import Path -from typing import Optional, Tuple, Dict, List +from typing import Optional, Tuple, List import numpy as np import cv2 import gradio as gr -from PIL import Image from huggingface_hub import hf_hub_download import spaces # Import v1 and v2 model code sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything")) sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything-V2")) +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Pixel-Perfect-Depth")) # v1 imports from depth_anything.dpt import DepthAnything as DepthAnythingV1 @@ -35,11 +33,20 @@ from depth_anything_v2.dpt import DepthAnythingV2 import matplotlib +# Pixel-Perfect Depth imports +from ppd.utils.set_seed import set_seed +from ppd.utils.align_depth_func import recover_metric_depth_ransac +from moge.model.v2 import MoGeModel +from ppd.models.ppd import PixelPerfectDepth + # Logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') # Device selection - ZeroGPU will handle GPU allocation DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu' +TORCH_DEVICE = torch.device(DEVICE) + +set_seed(666) # Model configs V1_MODEL_CONFIGS = { @@ -78,6 +85,11 @@ V2_MODEL_CONFIGS = { # Model cache - cleared after each inference for ZeroGPU _v1_models = {} _v2_models = {} +_ppd_model: Optional[PixelPerfectDepth] = None +_moge_model: Optional[MoGeModel] = None + +PPD_DEFAULT_STEPS = 20 +_ppd_cmap = matplotlib.colormaps.get_cmap('Spectral') # v1 transform v1_transform = Compose([ @@ -150,13 +162,15 @@ def load_v2_model(key: str): def clear_model_cache(): """Clear model cache to free GPU memory for ZeroGPU""" - global _v1_models, _v2_models + global _v1_models, _v2_models, _ppd_model, _moge_model for model in _v1_models.values(): del model for model in _v2_models.values(): del model _v1_models.clear() _v2_models.clear() + _ppd_model = 
None + _moge_model = None gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -183,12 +197,76 @@ def colorize_depth(depth: np.ndarray) -> np.ndarray: colored = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8) return colored + +def _normalize_depth_to_rgb(depth: np.ndarray) -> np.ndarray: + depth_vis = (depth - depth.min()) / (depth.max() - depth.min() + 1e-5) * 255.0 + depth_vis = depth_vis.astype(np.uint8) + colored = (_ppd_cmap(depth_vis)[:, :, :3] * 255).astype(np.uint8) + return colored + + +def load_ppd_model() -> PixelPerfectDepth: + global _ppd_model + if _ppd_model is not None: + return _ppd_model + + model = PixelPerfectDepth(sampling_steps=PPD_DEFAULT_STEPS) + ckpt_path = hf_hub_download( + repo_id="gangweix/Pixel-Perfect-Depth", + filename="ppd.pth", + repo_type="model" + ) + state_dict = torch.load(ckpt_path, map_location="cpu") + model.load_state_dict(state_dict, strict=False) + model = model.to(TORCH_DEVICE).eval() + _ppd_model = model + return _ppd_model + + +def load_moge_model() -> MoGeModel: + global _moge_model + if _moge_model is not None: + return _moge_model + + model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").eval() + model = model.to(TORCH_DEVICE) + _moge_model = model + return _moge_model + + +def pixel_perfect_depth_inference(image_bgr: np.ndarray, denoise_steps: int = PPD_DEFAULT_STEPS) -> Tuple[np.ndarray, np.ndarray]: + if image_bgr is None: + raise ValueError("Pixel-Perfect Depth received an empty image.") + + ppd_model = load_ppd_model() + moge_model = load_moge_model() + + H, W = image_bgr.shape[:2] + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + + with torch.no_grad(): + depth_rel, resize_image = ppd_model.infer_image(image_bgr, sampling_steps=denoise_steps) + + rgb_tensor = torch.tensor(cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) / 255, dtype=torch.float32, device=TORCH_DEVICE).permute(2, 0, 1) + + with torch.no_grad(): + metric_depth, mask, intrinsics = 
moge_model.infer(rgb_tensor) + + metric_depth[~mask] = metric_depth[mask].max() + metric_depth_aligned = recover_metric_depth_ransac(depth_rel, metric_depth, mask) + + depth_full = cv2.resize(metric_depth_aligned, (W, H), interpolation=cv2.INTER_LINEAR) + colored_depth = _normalize_depth_to_rgb(depth_full) + + return image_rgb, colored_depth + def get_model_choices() -> List[Tuple[str, str]]: choices = [] for k, v in V1_MODEL_CONFIGS.items(): choices.append((v['display_name'], f'v1_{k}')) for k, v in V2_MODEL_CONFIGS.items(): choices.append((v['display_name'], f'v2_{k}')) + choices.append(("Pixel-Perfect Depth", "ppd")) return choices @spaces.GPU @@ -200,14 +278,21 @@ def run_model(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, str]: model = load_v1_model(key) depth = predict_v1(model, image) label = V1_MODEL_CONFIGS[key]['display_name'] - else: + colored = colorize_depth(depth) + return colored, label + elif model_key.startswith('v2_'): key = model_key[3:] model = load_v2_model(key) depth = predict_v2(model, image) label = V2_MODEL_CONFIGS[key]['display_name'] - - colored = colorize_depth(depth) - return colored, label + colored = colorize_depth(depth) + return colored, label + elif model_key == 'ppd': + clear_model_cache() + _, colored = pixel_perfect_depth_inference(image) + return colored, "Pixel-Perfect Depth" + else: + raise ValueError(f"Unknown model key: {model_key}") finally: # Clean up GPU memory after inference if torch.cuda.is_available(): @@ -229,6 +314,10 @@ def compare_models(image, model1: str, model2: str, progress=gr.Progress()) -> T image = np.array(image) if len(image.shape) == 3 and image.shape[2] == 3: image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + else: + image = np.array(image) + if len(image.shape) == 3 and image.shape[2] == 3: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) progress(0.1, desc=f"Running {model1}") out1, label1 = run_model(model1, image) @@ -271,6 +360,10 @@ def slider_compare(image, model1: str, model2: str, 
progress=gr.Progress()): image = np.array(image) if len(image.shape) == 3 and image.shape[2] == 3: image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + else: + image = np.array(image) + if len(image.shape) == 3 and image.shape[2] == 3: + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) progress(0.1, desc=f"Running {model1}") out1, label1 = run_model(model1, image) @@ -346,7 +439,12 @@ def get_example_images() -> List[str]: # Try both v1 and v2 examples examples = [] - for ex_dir in ["assets/examples", "Depth-Anything/assets/examples", "Depth-Anything-V2/assets/examples"]: + for ex_dir in [ + "assets/examples", + "Depth-Anything/assets/examples", + "Depth-Anything-V2/assets/examples", + "Pixel-Perfect-Depth/assets/examples", + ]: ex_path = os.path.join(os.path.dirname(__file__), ex_dir) if os.path.exists(ex_path): # Get all image files and sort them naturally @@ -370,13 +468,17 @@ def get_paginated_examples(examples: List[str], page: int = 0, per_page: int = 6 def create_app(): model_choices = get_model_choices() - default1 = model_choices[0][1] - default2 = model_choices[1][1] + default1 = next((value for _, value in model_choices if value.startswith('v1_')), model_choices[0][1]) + default2 = next((value for _, value in model_choices if value == 'ppd'), None) + if default2 is None: + default2 = next((value for _, value in model_choices if value.startswith('v2_') and value != default1), model_choices[min(1, len(model_choices) - 1)][1]) - with gr.Blocks(title="Depth Anything v1 vs v2 Comparison", theme=gr.themes.Soft()) as app: + example_images = get_example_images() + + with gr.Blocks(title="Depth Estimation Comparison", theme=gr.themes.Soft()) as app: gr.Markdown(""" - # Depth Anything v1 vs v2 Comparison - Compare different Depth Anything models (v1 and v2) side-by-side or with a slider. + # Depth Estimation Comparison + Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider. 
⚑ **Running on ZeroGPU** - GPU resources are allocated automatically for inference. """) @@ -384,7 +486,7 @@ def create_app(): with gr.Tabs(): with gr.Tab("🎚️ Slider Comparison"): with gr.Row(): - img_input2 = gr.Image(label="Input Image") + img_input2 = gr.Image(label="Input Image", type="numpy") with gr.Column(): m1s = gr.Dropdown(choices=model_choices, label="Model A", value=default1) m2s = gr.Dropdown(choices=model_choices, label="Model B", value=default2) @@ -394,15 +496,14 @@ def create_app(): btn2.click(slider_compare, inputs=[img_input2, m1s, m2s], outputs=[slider, slider_status], show_progress=True) # Examples for slider comparison - ex_imgs = get_example_images() - if ex_imgs: + if example_images: def slider_example_fn(image): return slider_compare(image, default1, default2) - examples2 = gr.Examples(examples=ex_imgs, inputs=[img_input2], outputs=[slider, slider_status], fn=slider_example_fn) + gr.Examples(examples=example_images, inputs=[img_input2], outputs=[slider, slider_status], fn=slider_example_fn) with gr.Tab("πŸ” Method Comparison"): with gr.Row(): - img_input = gr.Image(label="Input Image") + img_input = gr.Image(label="Input Image", type="numpy") with gr.Column(): m1 = gr.Dropdown(choices=model_choices, label="Model 1", value=default1) m2 = gr.Dropdown(choices=model_choices, label="Model 2", value=default2) @@ -412,14 +513,14 @@ def create_app(): btn.click(compare_models, inputs=[img_input, m1, m2], outputs=[out_img, out_status], show_progress=True) # Examples for method comparison - if ex_imgs: + if example_images: def compare_example_fn(image): return compare_models(image, default1, default2) - examples = gr.Examples(examples=ex_imgs, inputs=[img_input], outputs=[out_img, out_status], fn=compare_example_fn) + gr.Examples(examples=example_images, inputs=[img_input], outputs=[out_img, out_status], fn=compare_example_fn) - with gr.Tab("οΏ½ Single Model"): + with gr.Tab("πŸ“· Single Model"): with gr.Row(): - img_input3 = gr.Image(label="Input 
Image") + img_input3 = gr.Image(label="Input Image", type="numpy") with gr.Column(): m_single = gr.Dropdown(choices=model_choices, label="Model", value=default1) btn3 = gr.Button("Run", variant="primary") @@ -428,16 +529,17 @@ def create_app(): btn3.click(single_inference, inputs=[img_input3, m_single], outputs=[single_slider, out_single_status], show_progress=True) # Examples for single model - if ex_imgs: + if example_images: def single_example_fn(image): return single_inference(image, default1) - examples3 = gr.Examples(examples=ex_imgs, inputs=[img_input3], outputs=[single_slider, out_single_status], fn=single_example_fn) + gr.Examples(examples=example_images, inputs=[img_input3], outputs=[single_slider, out_single_status], fn=single_example_fn) gr.Markdown(""" --- **References:** - **v1**: [Depth Anything v1](https://github.com/LiheYoung/Depth-Anything) - **v2**: [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2) + - **PPD**: [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth) **Note**: This app uses ZeroGPU for efficient GPU resource management. Models are loaded on-demand and GPU memory is automatically cleaned up after each inference. """) @@ -445,7 +547,7 @@ def create_app(): return app def main(): - logging.info("πŸš€ Starting Depth Anything Comparison App on ZeroGPU...") + logging.info("πŸš€ Starting Depth Estimation Comparison App on ZeroGPU...") app = create_app() app.queue().launch(show_error=True) diff --git a/app_local.py b/app_local.py index ec0a94b8167348ee0cf5418f186fd88c2e119c52..574564cf83376700e819f6a61b4af18a63af6561 100644 --- a/app_local.py +++ b/app_local.py @@ -1,22 +1,23 @@ """ -Depth Anything Comparison Demo (v1 vs v2) +Depth Estimation Comparison Demo (Depth Anything v1/v2 + Pixel-Perfect Depth) -Compare different Depth Anything models (v1 and v2) side-by-side or with a slider using Gradio. 
+Compare Depth Anything models (v1 and v2) and Pixel-Perfect Depth side-by-side or with a slider using Gradio. Inspired by the Stereo Matching Methods Comparison Demo. """ import os import sys import logging -import gc import tempfile +import shutil from pathlib import Path from typing import Optional, Tuple, Dict, List import numpy as np import cv2 import gradio as gr -from PIL import Image from huggingface_hub import hf_hub_download +import open3d as o3d +import trimesh # Import v1 and v2 model code sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Depth-Anything")) @@ -34,6 +35,14 @@ from depth_anything_v2.dpt import DepthAnythingV2 import matplotlib +# Pixel-Perfect Depth imports +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "Pixel-Perfect-Depth")) +from ppd.utils.set_seed import set_seed +from ppd.utils.align_depth_func import recover_metric_depth_ransac +from ppd.utils.depth2pcd import depth2pcd +from moge.model.v2 import MoGeModel +from ppd.models.ppd import PixelPerfectDepth + # Logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') @@ -159,12 +168,147 @@ def colorize_depth(depth: np.ndarray) -> np.ndarray: colored = (cmap(depth_uint8)[:, :, :3] * 255).astype(np.uint8) return colored + +# Pixel-Perfect Depth setup ------------------------------------------------- +set_seed(666) + +TORCH_DEVICE = torch.device(DEVICE) +PPD_DEFAULT_STEPS = 20 +PPD_TEMP_ROOT = Path(tempfile.gettempdir()) / "ppd" + +_ppd_model: Optional[PixelPerfectDepth] = None +_moge_model: Optional[MoGeModel] = None +_ppd_cmap = matplotlib.colormaps.get_cmap('Spectral') + + +def load_ppd_model() -> PixelPerfectDepth: + global _ppd_model + if _ppd_model is None: + model = PixelPerfectDepth(sampling_steps=PPD_DEFAULT_STEPS) + ckpt_path = hf_hub_download( + repo_id="gangweix/Pixel-Perfect-Depth", + filename="ppd.pth", + repo_type="model" + ) + state_dict = torch.load(ckpt_path, map_location="cpu") + 
model.load_state_dict(state_dict, strict=False) + model = model.to(TORCH_DEVICE).eval() + _ppd_model = model + return _ppd_model + + +def load_moge_model() -> MoGeModel: + global _moge_model + if _moge_model is None: + model = MoGeModel.from_pretrained("Ruicheng/moge-2-vitl-normal").eval() + model = model.to(TORCH_DEVICE) + _moge_model = model + return _moge_model + + +def _ensure_ppd_temp_dir(session_hash: str) -> Path: + PPD_TEMP_ROOT.mkdir(exist_ok=True) + output_path = PPD_TEMP_ROOT / session_hash + shutil.rmtree(output_path, ignore_errors=True) + output_path.mkdir(exist_ok=True, parents=True) + return output_path + + +def _normalize_depth_to_rgb(depth: np.ndarray) -> np.ndarray: + depth_vis = (depth - depth.min()) / (depth.max() - depth.min() + 1e-5) * 255.0 + depth_vis = depth_vis.astype(np.uint8) + colored = (_ppd_cmap(depth_vis)[:, :, :3] * 255).astype(np.uint8) + return colored + + +def pixel_perfect_depth_inference( + image_bgr: np.ndarray, + denoise_steps: int, + apply_filter: bool, + request: Optional[gr.Request] = None, + generate_assets: bool = True +): + if image_bgr is None: + return None, None, [] + + ppd_model = load_ppd_model() + moge_model = load_moge_model() + + H, W = image_bgr.shape[:2] + image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) + + # PixelPerfectDepth expects BGR input similar to original demo + with torch.no_grad(): + depth_rel, resize_image = ppd_model.infer_image(image_bgr, sampling_steps=denoise_steps) + resize_H, resize_W = resize_image.shape[:2] + + # MoGe expects RGB tensor + rgb_tensor = torch.tensor(cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB) / 255, dtype=torch.float32, device=TORCH_DEVICE).permute(2, 0, 1) + with torch.no_grad(): + metric_depth, mask, intrinsics = moge_model.infer(rgb_tensor) + metric_depth[~mask] = metric_depth[mask].max() + + # Align relative depth to metric using RANSAC + metric_depth_aligned = recover_metric_depth_ransac(depth_rel, metric_depth, mask) + intrinsics[0, 0] *= resize_W + 
intrinsics[1, 1] *= resize_H + intrinsics[0, 2] *= resize_W + intrinsics[1, 2] *= resize_H + + depth_full = cv2.resize(metric_depth_aligned, (W, H), interpolation=cv2.INTER_LINEAR) + colored_depth = _normalize_depth_to_rgb(depth_full) + + if not generate_assets: + return (image_rgb, colored_depth), None, [] + + pcd = depth2pcd( + metric_depth_aligned, + intrinsics, + color=cv2.cvtColor(resize_image, cv2.COLOR_BGR2RGB), + input_mask=mask, + ret_pcd=True + ) + if apply_filter: + _, ind = pcd.remove_statistical_outlier(nb_neighbors=20, std_ratio=2.0) + pcd = pcd.select_by_index(ind) + + session_hash = getattr(request, "session_hash", "default") + output_dir = _ensure_ppd_temp_dir(session_hash) + + # Save artifacts + ply_path = output_dir / "pointcloud.ply" + pcd.points = o3d.utility.Vector3dVector(np.asarray(pcd.points) * np.array([1, -1, -1], dtype=np.float32)) + o3d.io.write_point_cloud(ply_path.as_posix(), pcd) + vertices = np.asarray(pcd.points) + vertex_colors = (np.asarray(pcd.colors) * 255).astype(np.uint8) + mesh = trimesh.PointCloud(vertices=vertices, colors=vertex_colors) + glb_path = output_dir / "pointcloud.glb" + mesh.export(glb_path.as_posix()) + + raw_depth_path = output_dir / "raw_depth.npy" + np.save(raw_depth_path.as_posix(), depth_full) + + split_region = np.ones((image_bgr.shape[0], 50, 3), dtype=np.uint8) * 255 + combined_result = cv2.hconcat([image_bgr, split_region, colored_depth[:, :, ::-1]]) + vis_path = output_dir / "image_depth_vis.png" + cv2.imwrite(vis_path.as_posix(), combined_result) + + available_files = [ + path.as_posix() + for path in [vis_path, raw_depth_path, ply_path] + if path.exists() + ] + + return (image_rgb, colored_depth), glb_path.as_posix(), available_files + + def get_model_choices() -> List[Tuple[str, str]]: choices = [] for k, v in V1_MODEL_CONFIGS.items(): choices.append((v['display_name'], f'v1_{k}')) for k, v in V2_MODEL_CONFIGS.items(): choices.append((v['display_name'], f'v2_{k}')) + choices.append(("Pixel-Perfect 
Depth", "ppd")) return choices def run_model(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, str]: @@ -173,11 +317,24 @@ def run_model(model_key: str, image: np.ndarray) -> Tuple[np.ndarray, str]: model = load_v1_model(key) depth = predict_v1(model, image) label = V1_MODEL_CONFIGS[key]['display_name'] - else: + elif model_key.startswith('v2_'): key = model_key[3:] model = load_v2_model(key) depth = predict_v2(model, image) label = V2_MODEL_CONFIGS[key]['display_name'] + elif model_key == 'ppd': + slider_data, _, _ = pixel_perfect_depth_inference( + image, + denoise_steps=PPD_DEFAULT_STEPS, + apply_filter=False, + request=None, + generate_assets=False + ) + depth = slider_data[1] + label = "Pixel-Perfect Depth" + return depth, label + else: + raise ValueError(f"Unknown model key: {model_key}") colored = colorize_depth(depth) return colored, label @@ -288,7 +445,12 @@ def get_example_images() -> List[str]: # Try both v1 and v2 examples examples = [] - for ex_dir in ["assets/examples", "Depth-Anything/assets/examples", "Depth-Anything-V2/assets/examples"]: + for ex_dir in [ + "assets/examples", + "Depth-Anything/assets/examples", + "Depth-Anything-V2/assets/examples", + "Pixel-Perfect-Depth/assets/examples", + ]: ex_path = os.path.join(os.path.dirname(__file__), ex_dir) if os.path.exists(ex_path): # Get all image files and sort them naturally @@ -314,15 +476,16 @@ def create_app(): model_choices = get_model_choices() default1 = model_choices[0][1] default2 = model_choices[1][1] + example_images = get_example_images() with gr.Blocks(title="Depth Anything v1 vs v2 Comparison", theme=gr.themes.Soft()) as app: gr.Markdown(""" - # Depth Anything v1 vs v2 Comparison - Compare different Depth Anything models (v1 and v2) side-by-side or with a slider. + # Depth Estimation Comparison + Compare Depth Anything v1, Depth Anything v2, and Pixel-Perfect Depth side-by-side or with a slider. 
""") with gr.Tabs(): # Select the first tab (Slider Comparison) by default with gr.Tab("🎚️ Slider Comparison"): with gr.Row(): - img_input2 = gr.Image(label="Input Image") + img_input2 = gr.Image(label="Input Image", type="numpy") with gr.Column(): m1s = gr.Dropdown(choices=model_choices, label="Model A", value=default1) m2s = gr.Dropdown(choices=model_choices, label="Model B", value=default2) @@ -331,15 +494,13 @@ def create_app(): slider_status = gr.Markdown() btn2.click(slider_compare, inputs=[img_input2, m1s, m2s], outputs=[slider, slider_status], show_progress=True) - # Simple Examples - Tab 2 - ex_imgs = get_example_images() - if ex_imgs: + if example_images: def slider_example_fn(image): return slider_compare(image, default1, default2) - examples2 = gr.Examples(examples=ex_imgs, inputs=[img_input2], outputs=[slider, slider_status], fn=slider_example_fn) + gr.Examples(examples=example_images, inputs=[img_input2], outputs=[slider, slider_status], fn=slider_example_fn) with gr.Tab("πŸ” Method Comparison"): with gr.Row(): - img_input = gr.Image(label="Input Image") + img_input = gr.Image(label="Input Image", type="numpy") with gr.Column(): m1 = gr.Dropdown(choices=model_choices, label="Model 1", value=default1) m2 = gr.Dropdown(choices=model_choices, label="Model 2", value=default2) @@ -348,30 +509,28 @@ def create_app(): out_status = gr.Markdown() btn.click(compare_models, inputs=[img_input, m1, m2], outputs=[out_img, out_status], show_progress=True) - # Simple Examples - Clean approach - ex_imgs = get_example_images() - if ex_imgs: + if example_images: def compare_example_fn(image): return compare_models(image, default1, default2) - examples = gr.Examples(examples=ex_imgs, inputs=[img_input], outputs=[out_img, out_status], fn=compare_example_fn) + gr.Examples(examples=example_images, inputs=[img_input], outputs=[out_img, out_status], fn=compare_example_fn) with gr.Tab("πŸ“· Single Model"): with gr.Row(): - img_input3 = gr.Image(label="Input Image") + 
img_input3 = gr.Image(label="Input Image", type="numpy") m_single = gr.Dropdown(choices=model_choices, label="Model", value=default1) btn3 = gr.Button("Run", variant="primary") single_slider = gr.ImageSlider(label="Original vs Depth") out_single_status = gr.Markdown() btn3.click(single_inference, inputs=[img_input3, m_single], outputs=[single_slider, out_single_status], show_progress=True) - # Simple Examples - Tab 3 - if ex_imgs: + if example_images: def single_example_fn(image): return single_inference(image, default1) - examples3 = gr.Examples(examples=ex_imgs, inputs=[img_input3], outputs=[single_slider, out_single_status], fn=single_example_fn) + gr.Examples(examples=example_images, inputs=[img_input3], outputs=[single_slider, out_single_status], fn=single_example_fn) gr.Markdown(""" --- - **v1**: [Depth Anything v1](https://github.com/LiheYoung/Depth-Anything) - **v2**: [Depth Anything v2](https://github.com/DepthAnything/Depth-Anything-V2) + - **PPD**: [Pixel-Perfect Depth](https://github.com/gangweix/pixel-perfect-depth) """) return app diff --git a/requirements.txt b/requirements.txt index dd3622f9f13a43c70b86c512a985b4d82bf46fce..4f97e9328543fb854492fd94d81039ecc620f543 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,4 +8,10 @@ matplotlib>=3.7.0 huggingface-hub>=0.16.0 spaces>=0.25.0 transformers>=4.30.0 -timm>=0.9.0 \ No newline at end of file +timm>=0.9.0 +gradio_imageslider +open3d +scikit-learn +git+https://github.com/EasternJournalist/utils3d.git@c5daf6f6c244d251f252102d09e9b7bcef791a38 +click # ==8.1.7 +trimesh # ==4.5.1 \ No newline at end of file