Spaces:
Running on Zero
Running on Zero
Commit ·
2a36119
1
Parent(s): 7dabaaa
move threeDFixer to safe file
Browse files- app.py +59 -50
- moge/__init__.py +5 -0
- moge/model/__init__.py +23 -0
- moge/model/dinov2/__init__.py +6 -0
- moge/model/dinov2/hub/__init__.py +4 -0
- moge/model/dinov2/hub/backbones.py +156 -0
- moge/model/dinov2/hub/utils.py +39 -0
- moge/model/dinov2/layers/__init__.py +11 -0
- moge/model/dinov2/layers/attention.py +100 -0
- moge/model/dinov2/layers/block.py +259 -0
- moge/model/dinov2/layers/dino_head.py +58 -0
- moge/model/dinov2/layers/drop_path.py +34 -0
- moge/model/dinov2/layers/layer_scale.py +27 -0
- moge/model/dinov2/layers/mlp.py +40 -0
- moge/model/dinov2/layers/patch_embed.py +88 -0
- moge/model/dinov2/layers/swiglu_ffn.py +72 -0
- moge/model/dinov2/models/__init__.py +43 -0
- moge/model/dinov2/models/vision_transformer.py +407 -0
- moge/model/dinov2/utils/__init__.py +4 -0
- moge/model/dinov2/utils/cluster.py +95 -0
- moge/model/dinov2/utils/config.py +72 -0
- moge/model/dinov2/utils/dtype.py +37 -0
- moge/model/dinov2/utils/param_groups.py +103 -0
- moge/model/dinov2/utils/utils.py +95 -0
- moge/model/modules.py +259 -0
- moge/model/transforms.py +1344 -0
- moge/model/utils.py +54 -0
- moge/model/v2.py +359 -0
- moge/utils/__init__.py +5 -0
- moge/utils/download.py +60 -0
- moge/utils/geometry_numpy.py +411 -0
- moge/utils/geometry_torch.py +359 -0
- moge/utils/io.py +241 -0
- moge/utils/panorama.py +196 -0
- moge/utils/pipeline.py +508 -0
- moge/utils/tools.py +294 -0
- moge/utils/vis.py +70 -0
- moge/utils/webfile.py +78 -0
- moge/utils/webzipfile.py +133 -0
app.py
CHANGED
|
@@ -18,17 +18,7 @@ import random
|
|
| 18 |
import imageio
|
| 19 |
from einops import repeat
|
| 20 |
from huggingface_hub import snapshot_download
|
| 21 |
-
from
|
| 22 |
-
from threeDFixer.pipelines import ThreeDFixerPipeline
|
| 23 |
-
from threeDFixer.datasets.utils import (
|
| 24 |
-
edge_mask_morph_gradient,
|
| 25 |
-
process_scene_image,
|
| 26 |
-
process_instance_image,
|
| 27 |
-
transform_vertices,
|
| 28 |
-
normalize_vertices,
|
| 29 |
-
project2ply
|
| 30 |
-
)
|
| 31 |
-
from threeDFixer.utils import render_utils, postprocessing_utils
|
| 32 |
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
| 33 |
from scripts.grounding_sam import plot_segmentation, segment
|
| 34 |
import copy
|
|
@@ -192,6 +182,11 @@ def run_depth_estimation(
|
|
| 192 |
) -> Image.Image:
|
| 193 |
rgb_image = image_prompts["image"].convert("RGB")
|
| 194 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
|
| 196 |
|
| 197 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
@@ -291,45 +286,6 @@ def set_random_seed(seed):
|
|
| 291 |
if torch.cuda.is_available():
|
| 292 |
torch.cuda.manual_seed_all(seed)
|
| 293 |
|
| 294 |
-
def export_single_glb_from_outputs(
|
| 295 |
-
outputs,
|
| 296 |
-
fine_scale,
|
| 297 |
-
fine_trans,
|
| 298 |
-
coarse_scale,
|
| 299 |
-
coarse_trans,
|
| 300 |
-
trans,
|
| 301 |
-
scale,
|
| 302 |
-
rot,
|
| 303 |
-
work_space,
|
| 304 |
-
instance_name,
|
| 305 |
-
run_id
|
| 306 |
-
):
|
| 307 |
-
|
| 308 |
-
with torch.enable_grad():
|
| 309 |
-
glb = postprocessing_utils.to_glb(
|
| 310 |
-
outputs["gaussian"][0],
|
| 311 |
-
outputs["mesh"][0],
|
| 312 |
-
simplify=0.95,
|
| 313 |
-
texture_size=1024,
|
| 314 |
-
transform_fn=lambda x: transform_vertices(
|
| 315 |
-
x,
|
| 316 |
-
ops=["scale", "translation", "scale", "translation"],
|
| 317 |
-
params=[fine_scale, fine_trans[None], coarse_scale, coarse_trans[None]],
|
| 318 |
-
),
|
| 319 |
-
verbose=False
|
| 320 |
-
)
|
| 321 |
-
|
| 322 |
-
instance_glb_path = os.path.abspath(
|
| 323 |
-
os.path.join(work_space, f"{run_id}_{instance_name}.glb")
|
| 324 |
-
)
|
| 325 |
-
|
| 326 |
-
glb.apply_translation(-trans) \
|
| 327 |
-
.apply_scale(1.0 / (scale + 1e-6)) \
|
| 328 |
-
.apply_transform(rot) \
|
| 329 |
-
.export(instance_glb_path)
|
| 330 |
-
|
| 331 |
-
return instance_glb_path, glb
|
| 332 |
-
|
| 333 |
|
| 334 |
def export_scene_glb(trimeshes, work_space, scene_name):
|
| 335 |
scene_path = os.path.abspath(os.path.join(work_space, scene_name))
|
|
@@ -356,6 +312,59 @@ def run_generation(
|
|
| 356 |
cfg_interval_end: float = 1.0,
|
| 357 |
t_rescale: float = 3.0,
|
| 358 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
global dpt_pack
|
| 360 |
global work_space
|
| 361 |
global generated_object_map
|
|
|
|
| 18 |
import imageio
|
| 19 |
from einops import repeat
|
| 20 |
from huggingface_hub import snapshot_download
|
| 21 |
+
from moge.model.v2 import MoGeModel
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
from transformers import AutoModelForMaskGeneration, AutoProcessor
|
| 23 |
from scripts.grounding_sam import plot_segmentation, segment
|
| 24 |
import copy
|
|
|
|
| 182 |
) -> Image.Image:
|
| 183 |
rgb_image = image_prompts["image"].convert("RGB")
|
| 184 |
|
| 185 |
+
from threeDFixer.datasets.utils import (
|
| 186 |
+
normalize_vertices,
|
| 187 |
+
project2ply
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
|
| 191 |
|
| 192 |
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
|
|
|
| 286 |
if torch.cuda.is_available():
|
| 287 |
torch.cuda.manual_seed_all(seed)
|
| 288 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 289 |
|
| 290 |
def export_scene_glb(trimeshes, work_space, scene_name):
|
| 291 |
scene_path = os.path.abspath(os.path.join(work_space, scene_name))
|
|
|
|
| 312 |
cfg_interval_end: float = 1.0,
|
| 313 |
t_rescale: float = 3.0,
|
| 314 |
):
|
| 315 |
+
|
| 316 |
+
from threeDFixer.pipelines import ThreeDFixerPipeline
|
| 317 |
+
from threeDFixer.datasets.utils import (
|
| 318 |
+
edge_mask_morph_gradient,
|
| 319 |
+
process_scene_image,
|
| 320 |
+
process_instance_image,
|
| 321 |
+
)
|
| 322 |
+
from threeDFixer.utils import render_utils
|
| 323 |
+
|
| 324 |
+
def export_single_glb_from_outputs(
|
| 325 |
+
outputs,
|
| 326 |
+
fine_scale,
|
| 327 |
+
fine_trans,
|
| 328 |
+
coarse_scale,
|
| 329 |
+
coarse_trans,
|
| 330 |
+
trans,
|
| 331 |
+
scale,
|
| 332 |
+
rot,
|
| 333 |
+
work_space,
|
| 334 |
+
instance_name,
|
| 335 |
+
run_id
|
| 336 |
+
):
|
| 337 |
+
|
| 338 |
+
from threeDFixer.datasets.utils import (
|
| 339 |
+
transform_vertices,
|
| 340 |
+
)
|
| 341 |
+
from threeDFixer.utils import postprocessing_utils
|
| 342 |
+
|
| 343 |
+
with torch.enable_grad():
|
| 344 |
+
glb = postprocessing_utils.to_glb(
|
| 345 |
+
outputs["gaussian"][0],
|
| 346 |
+
outputs["mesh"][0],
|
| 347 |
+
simplify=0.95,
|
| 348 |
+
texture_size=1024,
|
| 349 |
+
transform_fn=lambda x: transform_vertices(
|
| 350 |
+
x,
|
| 351 |
+
ops=["scale", "translation", "scale", "translation"],
|
| 352 |
+
params=[fine_scale, fine_trans[None], coarse_scale, coarse_trans[None]],
|
| 353 |
+
),
|
| 354 |
+
verbose=False
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
instance_glb_path = os.path.abspath(
|
| 358 |
+
os.path.join(work_space, f"{run_id}_{instance_name}.glb")
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
glb.apply_translation(-trans) \
|
| 362 |
+
.apply_scale(1.0 / (scale + 1e-6)) \
|
| 363 |
+
.apply_transform(rot) \
|
| 364 |
+
.export(instance_glb_path)
|
| 365 |
+
|
| 366 |
+
return instance_glb_path, glb
|
| 367 |
+
|
| 368 |
global dpt_pack
|
| 369 |
global work_space
|
| 370 |
global generated_object_map
|
moge/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
moge/model/__init__.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
import importlib
|
| 7 |
+
from typing import *
|
| 8 |
+
|
| 9 |
+
if TYPE_CHECKING:
|
| 10 |
+
from .v1 import MoGeModel as MoGeModelV1
|
| 11 |
+
from .v2 import MoGeModel as MoGeModelV2
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
|
| 15 |
+
assert version in ['v1', 'v2'], f'Unsupported model version: {version}'
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
module = importlib.import_module(f'.{version}', __package__)
|
| 19 |
+
except ModuleNotFoundError:
|
| 20 |
+
raise ValueError(f'Model version "{version}" not found.')
|
| 21 |
+
|
| 22 |
+
cls = getattr(module, 'MoGeModel')
|
| 23 |
+
return cls
|
moge/model/dinov2/__init__.py
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
__version__ = "0.0.1"
|
moge/model/dinov2/hub/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
moge/model/dinov2/hub/backbones.py
ADDED
|
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
from typing import Union
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Weights(Enum):
|
| 15 |
+
LVD142M = "LVD142M"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _make_dinov2_model(
|
| 19 |
+
*,
|
| 20 |
+
arch_name: str = "vit_large",
|
| 21 |
+
img_size: int = 518,
|
| 22 |
+
patch_size: int = 14,
|
| 23 |
+
init_values: float = 1.0,
|
| 24 |
+
ffn_layer: str = "mlp",
|
| 25 |
+
block_chunks: int = 0,
|
| 26 |
+
num_register_tokens: int = 0,
|
| 27 |
+
interpolate_antialias: bool = False,
|
| 28 |
+
interpolate_offset: float = 0.1,
|
| 29 |
+
pretrained: bool = True,
|
| 30 |
+
weights: Union[Weights, str] = Weights.LVD142M,
|
| 31 |
+
**kwargs,
|
| 32 |
+
):
|
| 33 |
+
from ..models import vision_transformer as vits
|
| 34 |
+
|
| 35 |
+
if isinstance(weights, str):
|
| 36 |
+
try:
|
| 37 |
+
weights = Weights[weights]
|
| 38 |
+
except KeyError:
|
| 39 |
+
raise AssertionError(f"Unsupported weights: {weights}")
|
| 40 |
+
|
| 41 |
+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
| 42 |
+
vit_kwargs = dict(
|
| 43 |
+
img_size=img_size,
|
| 44 |
+
patch_size=patch_size,
|
| 45 |
+
init_values=init_values,
|
| 46 |
+
ffn_layer=ffn_layer,
|
| 47 |
+
block_chunks=block_chunks,
|
| 48 |
+
num_register_tokens=num_register_tokens,
|
| 49 |
+
interpolate_antialias=interpolate_antialias,
|
| 50 |
+
interpolate_offset=interpolate_offset,
|
| 51 |
+
)
|
| 52 |
+
vit_kwargs.update(**kwargs)
|
| 53 |
+
model = vits.__dict__[arch_name](**vit_kwargs)
|
| 54 |
+
|
| 55 |
+
if pretrained:
|
| 56 |
+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
| 57 |
+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
|
| 58 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
| 59 |
+
model.load_state_dict(state_dict, strict=True)
|
| 60 |
+
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 65 |
+
"""
|
| 66 |
+
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 67 |
+
"""
|
| 68 |
+
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 72 |
+
"""
|
| 73 |
+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 74 |
+
"""
|
| 75 |
+
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 79 |
+
"""
|
| 80 |
+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 81 |
+
"""
|
| 82 |
+
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 86 |
+
"""
|
| 87 |
+
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 88 |
+
"""
|
| 89 |
+
return _make_dinov2_model(
|
| 90 |
+
arch_name="vit_giant2",
|
| 91 |
+
ffn_layer="swiglufused",
|
| 92 |
+
weights=weights,
|
| 93 |
+
pretrained=pretrained,
|
| 94 |
+
**kwargs,
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 99 |
+
"""
|
| 100 |
+
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 101 |
+
"""
|
| 102 |
+
return _make_dinov2_model(
|
| 103 |
+
arch_name="vit_small",
|
| 104 |
+
pretrained=pretrained,
|
| 105 |
+
weights=weights,
|
| 106 |
+
num_register_tokens=4,
|
| 107 |
+
interpolate_antialias=True,
|
| 108 |
+
interpolate_offset=0.0,
|
| 109 |
+
**kwargs,
|
| 110 |
+
)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 114 |
+
"""
|
| 115 |
+
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 116 |
+
"""
|
| 117 |
+
return _make_dinov2_model(
|
| 118 |
+
arch_name="vit_base",
|
| 119 |
+
pretrained=pretrained,
|
| 120 |
+
weights=weights,
|
| 121 |
+
num_register_tokens=4,
|
| 122 |
+
interpolate_antialias=True,
|
| 123 |
+
interpolate_offset=0.0,
|
| 124 |
+
**kwargs,
|
| 125 |
+
)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 129 |
+
"""
|
| 130 |
+
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 131 |
+
"""
|
| 132 |
+
return _make_dinov2_model(
|
| 133 |
+
arch_name="vit_large",
|
| 134 |
+
pretrained=pretrained,
|
| 135 |
+
weights=weights,
|
| 136 |
+
num_register_tokens=4,
|
| 137 |
+
interpolate_antialias=True,
|
| 138 |
+
interpolate_offset=0.0,
|
| 139 |
+
**kwargs,
|
| 140 |
+
)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
| 144 |
+
"""
|
| 145 |
+
DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
| 146 |
+
"""
|
| 147 |
+
return _make_dinov2_model(
|
| 148 |
+
arch_name="vit_giant2",
|
| 149 |
+
ffn_layer="swiglufused",
|
| 150 |
+
weights=weights,
|
| 151 |
+
pretrained=pretrained,
|
| 152 |
+
num_register_tokens=4,
|
| 153 |
+
interpolate_antialias=True,
|
| 154 |
+
interpolate_offset=0.0,
|
| 155 |
+
**kwargs,
|
| 156 |
+
)
|
moge/model/dinov2/hub/utils.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import itertools
|
| 7 |
+
import math
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
| 18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
| 19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
| 20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class CenterPadding(nn.Module):
|
| 24 |
+
def __init__(self, multiple):
|
| 25 |
+
super().__init__()
|
| 26 |
+
self.multiple = multiple
|
| 27 |
+
|
| 28 |
+
def _get_pad(self, size):
|
| 29 |
+
new_size = math.ceil(size / self.multiple) * self.multiple
|
| 30 |
+
pad_size = new_size - size
|
| 31 |
+
pad_size_left = pad_size // 2
|
| 32 |
+
pad_size_right = pad_size - pad_size_left
|
| 33 |
+
return pad_size_left, pad_size_right
|
| 34 |
+
|
| 35 |
+
@torch.inference_mode()
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
|
| 38 |
+
output = F.pad(x, pads)
|
| 39 |
+
return output
|
moge/model/dinov2/layers/__init__.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from .dino_head import DINOHead
|
| 7 |
+
from .mlp import Mlp
|
| 8 |
+
from .patch_embed import PatchEmbed
|
| 9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 10 |
+
from .block import NestedTensorBlock
|
| 11 |
+
from .attention import MemEffAttention
|
moge/model/dinov2/layers/attention.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
from torch import nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
logger = logging.getLogger("dinov2")
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 23 |
+
try:
|
| 24 |
+
if XFORMERS_ENABLED:
|
| 25 |
+
from xformers.ops import memory_efficient_attention, unbind
|
| 26 |
+
|
| 27 |
+
XFORMERS_AVAILABLE = True
|
| 28 |
+
# warnings.warn("xFormers is available (Attention)")
|
| 29 |
+
else:
|
| 30 |
+
# warnings.warn("xFormers is disabled (Attention)")
|
| 31 |
+
raise ImportError
|
| 32 |
+
except ImportError:
|
| 33 |
+
XFORMERS_AVAILABLE = False
|
| 34 |
+
# warnings.warn("xFormers is not available (Attention)")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class Attention(nn.Module):
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
dim: int,
|
| 41 |
+
num_heads: int = 8,
|
| 42 |
+
qkv_bias: bool = False,
|
| 43 |
+
proj_bias: bool = True,
|
| 44 |
+
attn_drop: float = 0.0,
|
| 45 |
+
proj_drop: float = 0.0,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
self.num_heads = num_heads
|
| 49 |
+
head_dim = dim // num_heads
|
| 50 |
+
self.scale = head_dim**-0.5
|
| 51 |
+
|
| 52 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 53 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 54 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 55 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 56 |
+
|
| 57 |
+
# # Deprecated implementation, extremely slow
|
| 58 |
+
# def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 59 |
+
# B, N, C = x.shape
|
| 60 |
+
# qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
| 61 |
+
# q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 62 |
+
# attn = q @ k.transpose(-2, -1)
|
| 63 |
+
# attn = attn.softmax(dim=-1)
|
| 64 |
+
# attn = self.attn_drop(attn)
|
| 65 |
+
# x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 66 |
+
# x = self.proj(x)
|
| 67 |
+
# x = self.proj_drop(x)
|
| 68 |
+
# return x
|
| 69 |
+
|
| 70 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 71 |
+
B, N, C = x.shape
|
| 72 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
|
| 73 |
+
|
| 74 |
+
q, k, v = qkv.unbind(0) # (B, H, N, C // H)
|
| 75 |
+
|
| 76 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_bias)
|
| 77 |
+
x = x.permute(0, 2, 1, 3).reshape(B, N, C)
|
| 78 |
+
|
| 79 |
+
x = self.proj(x)
|
| 80 |
+
x = self.proj_drop(x)
|
| 81 |
+
return x
|
| 82 |
+
|
| 83 |
+
class MemEffAttention(Attention):
|
| 84 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 85 |
+
if not XFORMERS_AVAILABLE:
|
| 86 |
+
if attn_bias is not None:
|
| 87 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
| 88 |
+
return super().forward(x)
|
| 89 |
+
|
| 90 |
+
B, N, C = x.shape
|
| 91 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 92 |
+
|
| 93 |
+
q, k, v = unbind(qkv, 2)
|
| 94 |
+
|
| 95 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 96 |
+
x = x.reshape([B, N, C])
|
| 97 |
+
|
| 98 |
+
x = self.proj(x)
|
| 99 |
+
x = self.proj_drop(x)
|
| 100 |
+
return x
|
moge/model/dinov2/layers/block.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
import logging
|
| 11 |
+
import os
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
import warnings
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
from torch import nn, Tensor
|
| 17 |
+
|
| 18 |
+
from .attention import Attention, MemEffAttention
|
| 19 |
+
from .drop_path import DropPath
|
| 20 |
+
from .layer_scale import LayerScale
|
| 21 |
+
from .mlp import Mlp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger("dinov2")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
| 28 |
+
try:
|
| 29 |
+
if XFORMERS_ENABLED:
|
| 30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
| 31 |
+
|
| 32 |
+
XFORMERS_AVAILABLE = True
|
| 33 |
+
# warnings.warn("xFormers is available (Block)")
|
| 34 |
+
else:
|
| 35 |
+
# warnings.warn("xFormers is disabled (Block)")
|
| 36 |
+
raise ImportError
|
| 37 |
+
except ImportError:
|
| 38 |
+
XFORMERS_AVAILABLE = False
|
| 39 |
+
# warnings.warn("xFormers is not available (Block)")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Block(nn.Module):
|
| 43 |
+
def __init__(
|
| 44 |
+
self,
|
| 45 |
+
dim: int,
|
| 46 |
+
num_heads: int,
|
| 47 |
+
mlp_ratio: float = 4.0,
|
| 48 |
+
qkv_bias: bool = False,
|
| 49 |
+
proj_bias: bool = True,
|
| 50 |
+
ffn_bias: bool = True,
|
| 51 |
+
drop: float = 0.0,
|
| 52 |
+
attn_drop: float = 0.0,
|
| 53 |
+
init_values=None,
|
| 54 |
+
drop_path: float = 0.0,
|
| 55 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 56 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 57 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 58 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 59 |
+
) -> None:
|
| 60 |
+
super().__init__()
|
| 61 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 62 |
+
self.norm1 = norm_layer(dim)
|
| 63 |
+
self.attn = attn_class(
|
| 64 |
+
dim,
|
| 65 |
+
num_heads=num_heads,
|
| 66 |
+
qkv_bias=qkv_bias,
|
| 67 |
+
proj_bias=proj_bias,
|
| 68 |
+
attn_drop=attn_drop,
|
| 69 |
+
proj_drop=drop,
|
| 70 |
+
)
|
| 71 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 72 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 73 |
+
|
| 74 |
+
self.norm2 = norm_layer(dim)
|
| 75 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 76 |
+
self.mlp = ffn_layer(
|
| 77 |
+
in_features=dim,
|
| 78 |
+
hidden_features=mlp_hidden_dim,
|
| 79 |
+
act_layer=act_layer,
|
| 80 |
+
drop=drop,
|
| 81 |
+
bias=ffn_bias,
|
| 82 |
+
)
|
| 83 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 84 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 85 |
+
|
| 86 |
+
self.sample_drop_ratio = drop_path
|
| 87 |
+
|
| 88 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 89 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
| 90 |
+
return self.ls1(self.attn(self.norm1(x)))
|
| 91 |
+
|
| 92 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
| 93 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 94 |
+
|
| 95 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 96 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 97 |
+
x = drop_add_residual_stochastic_depth(
|
| 98 |
+
x,
|
| 99 |
+
residual_func=attn_residual_func,
|
| 100 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 101 |
+
)
|
| 102 |
+
x = drop_add_residual_stochastic_depth(
|
| 103 |
+
x,
|
| 104 |
+
residual_func=ffn_residual_func,
|
| 105 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 106 |
+
)
|
| 107 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 108 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
| 109 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 110 |
+
else:
|
| 111 |
+
x = x + attn_residual_func(x)
|
| 112 |
+
x = x + ffn_residual_func(x)
|
| 113 |
+
return x
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def drop_add_residual_stochastic_depth(
|
| 117 |
+
x: Tensor,
|
| 118 |
+
residual_func: Callable[[Tensor], Tensor],
|
| 119 |
+
sample_drop_ratio: float = 0.0,
|
| 120 |
+
) -> Tensor:
|
| 121 |
+
# 1) extract subset using permutation
|
| 122 |
+
b, n, d = x.shape
|
| 123 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 124 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 125 |
+
x_subset = x[brange]
|
| 126 |
+
|
| 127 |
+
# 2) apply residual_func to get residual
|
| 128 |
+
residual = residual_func(x_subset)
|
| 129 |
+
|
| 130 |
+
x_flat = x.flatten(1)
|
| 131 |
+
residual = residual.flatten(1)
|
| 132 |
+
|
| 133 |
+
residual_scale_factor = b / sample_subset_size
|
| 134 |
+
|
| 135 |
+
# 3) add the residual
|
| 136 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 137 |
+
return x_plus_residual.view_as(x)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 141 |
+
b, n, d = x.shape
|
| 142 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 143 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 144 |
+
residual_scale_factor = b / sample_subset_size
|
| 145 |
+
return brange, residual_scale_factor
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 149 |
+
if scaling_vector is None:
|
| 150 |
+
x_flat = x.flatten(1)
|
| 151 |
+
residual = residual.flatten(1)
|
| 152 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
| 153 |
+
else:
|
| 154 |
+
x_plus_residual = scaled_index_add(
|
| 155 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
| 156 |
+
)
|
| 157 |
+
return x_plus_residual
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
# Cache of block-diagonal attention masks keyed by the tuple of
# (batch_size, seq_len) pairs, so identical list layouts reuse one mask.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    # Effective batch size per tensor: the selected subset size when indices
    # are given, otherwise the full batch.
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        # One sequence-length entry per sample so attention stays confined to
        # each original sample (block-diagonal mask over the packed sequence).
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        # Gather the selected rows from each tensor and pack everything into a
        # single (1, total_tokens, dim) nested sequence.
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        # No subset selection: just reshape each tensor to batch 1 and concat.
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> List[Tensor]:
    """Stochastic-depth residual update applied jointly to a list of tensors.

    Each tensor keeps a random subset of its batch; the subsets are packed
    into one nested sequence with a block-diagonal attention mask, the
    residual branch runs once over the packed sequence, and the rescaled
    residuals are scattered back into the corresponding inputs.

    Returns a list of tensors, each shaped like its input.
    """
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        # add_residual returns a flattened tensor on the non-fused path,
        # hence the view_as back to the input shape.
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class NestedTensorBlock(Block):
    """Transformer block that can additionally process a *list* of tensors with
    differing sequence lengths by packing them into one nested sequence and
    attending with a block-diagonal mask (requires xFormers)."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:
            # Training with stochastic depth: LayerScale gammas (if present)
            # are fused into the scatter-add, so the residual funcs omit them.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                # BUGFIX: the FFN branch must inspect ls2 (the original code
                # checked isinstance(self.ls1, ...)) to decide whether the FFN
                # LayerScale gamma exists before reading self.ls2.gamma.
                scaling_vector=self.ls2.gamma if isinstance(self.ls2, LayerScale) else None,
            )
            return x_list
        else:
            # Inference / no sample drop: standard pre-norm residual updates
            # over the packed sequence, split back into per-input tensors.

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        # Plain tensors use the regular Block forward; lists take the
        # nested (packed-sequence) path, which needs xFormers.
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
|
moge/model/dinov2/layers/dino_head.py
ADDED
|
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.nn.init import trunc_normal_
|
| 9 |
+
from torch.nn.utils import weight_norm
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DINOHead(nn.Module):
    # DINO projection head: an MLP down to a bottleneck, L2 normalization,
    # then a weight-normalized linear layer producing the prototype logits.

    def __init__(
        self,
        in_dim,
        out_dim,
        use_bn=False,
        nlayers=3,
        hidden_dim=2048,
        bottleneck_dim=256,
        mlp_bias=True,
    ):
        super().__init__()
        nlayers = max(nlayers, 1)
        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
        self.apply(self._init_weights)
        # Weight-normalized last layer; its magnitude component (weight_g) is
        # reset to 1 so only the direction of each prototype varies at init.
        self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
        self.last_layer.weight_g.data.fill_(1)

    def _init_weights(self, m):
        # Truncated-normal init for every Linear layer, zero bias.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=0.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)

    def forward(self, x):
        x = self.mlp(x)
        # Larger eps under fp16 to avoid under/overflow in the norm.
        eps = 1e-6 if x.dtype == torch.float16 else 1e-12
        x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
        x = self.last_layer(x)
        return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
| 45 |
+
if nlayers == 1:
|
| 46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
| 47 |
+
else:
|
| 48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
| 49 |
+
if use_bn:
|
| 50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 51 |
+
layers.append(nn.GELU())
|
| 52 |
+
for _ in range(nlayers - 2):
|
| 53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
| 54 |
+
if use_bn:
|
| 55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 56 |
+
layers.append(nn.GELU())
|
| 57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
| 58 |
+
return nn.Sequential(*layers)
|
moge/model/dinov2/layers/drop_path.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from torch import nn
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero whole samples (stochastic depth) and rescale survivors.

    Identity when not training or when ``drop_prob`` is 0.
    """
    if not training or drop_prob == 0.0:
        return x
    survive = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast over all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(survive)
    if survive > 0.0:
        # Rescale so the expected value is unchanged.
        mask.div_(survive)
    return x * mask


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
|
moge/model/dinov2/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 7 |
+
|
| 8 |
+
from typing import Union
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch import Tensor
|
| 12 |
+
from torch import nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class LayerScale(nn.Module):
    """Per-channel learnable scaling (LayerScale) applied over the last dim."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        # gamma starts small so the scaled branch barely perturbs the residual.
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
|
moge/model/dinov2/layers/mlp.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional
|
| 12 |
+
|
| 13 |
+
from torch import Tensor, nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Mlp(nn.Module):
    """Standard transformer MLP: Linear -> activation -> dropout -> Linear -> dropout."""

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        # Hidden and output widths default to the input width.
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.drop(self.act(self.fc1(x)))
        return self.drop(self.fc2(hidden))
|
moge/model/dinov2/layers/patch_embed.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 9 |
+
|
| 10 |
+
from typing import Callable, Optional, Tuple, Union
|
| 11 |
+
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def make_2tuple(x):
    """Return ``x`` as a pair: ints are duplicated, 2-tuples pass through."""
    if isinstance(x, int):
        return (x, x)
    assert isinstance(x, tuple) and len(x) == 2
    return x


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer.
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        img_hw = make_2tuple(img_size)
        patch_hw = make_2tuple(patch_size)
        grid_hw = (img_hw[0] // patch_hw[0], img_hw[1] // patch_hw[1])

        self.img_size = img_hw
        self.patch_size = patch_hw
        self.patches_resolution = grid_hw
        self.num_patches = grid_hw[0] * grid_hw[1]
        self.in_chans = in_chans
        self.embed_dim = embed_dim
        self.flatten_embedding = flatten_embedding

        # A non-overlapping conv is exactly a per-patch linear projection.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_hw, stride=patch_hw)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        grid_h, grid_w = self.patches_resolution
        total = grid_h * grid_w * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            total += grid_h * grid_w * self.embed_dim
        return total
|
moge/model/dinov2/layers/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
import warnings
|
| 9 |
+
|
| 10 |
+
from torch import Tensor, nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: gated unit silu(W1 x) * (W2 x), then a projection.

    ``act_layer`` and ``drop`` are accepted only for interface parity with Mlp
    and are ignored.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        # w12 packs the gate and value projections into a single matmul.
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
# xFormers can be turned off explicitly by setting the XFORMERS_DISABLED
# environment variable; otherwise we try to import its fused SwiGLU kernel.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    # Fall back to the pure-PyTorch implementation when xFormers is missing
    # or explicitly disabled.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    # warnings.warn("xFormers is not available (SwiGLU)")
|
| 53 |
+
|
| 54 |
+
class SwiGLUFFNFused(SwiGLU):
    # SwiGLU variant whose hidden width is shrunk to ~2/3 (keeping parameter
    # count comparable to a standard 4x MLP, as in the PaLM/LLaMA convention)
    # and rounded up to a multiple of 8 for hardware efficiency. Uses the
    # fused xFormers kernel when available, else the pure-PyTorch SwiGLUFFN.
    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,  # unused; interface parity with Mlp
        drop: float = 0.0,  # unused; interface parity with Mlp
        bias: bool = True,
    ) -> None:
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        # Scale to 2/3 and round up to the next multiple of 8.
        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            bias=bias,
        )
|
moge/model/dinov2/models/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
|
| 8 |
+
from . import vision_transformer as vits
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
logger = logging.getLogger("dinov2")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate DINOv2 ViT model(s) from a config namespace.

    Returns ``(teacher, embed_dim)`` when ``only_teacher`` is True, otherwise
    ``(student, teacher, embed_dim)``. Only "vit*" architectures are handled;
    any other arch falls through with no return (returns None implicitly).
    """
    # Legacy "_memeff" arch-name suffixes map to the same architectures.
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" in args.arch:
        vit_kwargs = dict(
            img_size=img_size,
            patch_size=args.patch_size,
            init_values=args.layerscale,
            ffn_layer=args.ffn_layer,
            block_chunks=args.block_chunks,
            qkv_bias=args.qkv_bias,
            proj_bias=args.proj_bias,
            ffn_bias=args.ffn_bias,
            num_register_tokens=args.num_register_tokens,
            interpolate_offset=args.interpolate_offset,
            interpolate_antialias=args.interpolate_antialias,
        )
        teacher = vits.__dict__[args.arch](**vit_kwargs)
        if only_teacher:
            return teacher, teacher.embed_dim
        # Only the student uses stochastic depth (drop path).
        student = vits.__dict__[args.arch](
            **vit_kwargs,
            drop_path_rate=args.drop_path_rate,
            drop_path_uniform=args.drop_path_uniform,
        )
        embed_dim = student.embed_dim
    return student, teacher, embed_dim
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def build_model_from_cfg(cfg, only_teacher=False):
    """Build the model from a full training config (student section supplies
    the architecture; the global crop size supplies the image size)."""
    return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
|
moge/model/dinov2/models/vision_transformer.py
ADDED
|
@@ -0,0 +1,407 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
# References:
|
| 7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
| 8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 9 |
+
|
| 10 |
+
from functools import partial
|
| 11 |
+
import math
|
| 12 |
+
import logging
|
| 13 |
+
from typing import Sequence, Tuple, Union, Callable, Optional, List
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
import torch.utils.checkpoint
|
| 18 |
+
from torch.nn.init import trunc_normal_
|
| 19 |
+
|
| 20 |
+
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over ``module``'s tree.

    Visits children post-order when ``depth_first`` is True (default),
    pre-order otherwise; the root itself is visited only if ``include_root``.
    Returns ``module`` for chaining.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        # Children are always visited (include_root=True below the top level).
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class BlockChunk(nn.ModuleList):
    """A ModuleList that runs its contained blocks sequentially when called."""

    def forward(self, x):
        for block in self:
            x = block(x)
        return x
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class DinoVisionTransformer(nn.Module):
|
| 45 |
+
def __init__(
|
| 46 |
+
self,
|
| 47 |
+
img_size=224,
|
| 48 |
+
patch_size=16,
|
| 49 |
+
in_chans=3,
|
| 50 |
+
embed_dim=768,
|
| 51 |
+
depth=12,
|
| 52 |
+
num_heads=12,
|
| 53 |
+
mlp_ratio=4.0,
|
| 54 |
+
qkv_bias=True,
|
| 55 |
+
ffn_bias=True,
|
| 56 |
+
proj_bias=True,
|
| 57 |
+
drop_path_rate=0.0,
|
| 58 |
+
drop_path_uniform=False,
|
| 59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 60 |
+
embed_layer=PatchEmbed,
|
| 61 |
+
act_layer=nn.GELU,
|
| 62 |
+
block_fn=Block,
|
| 63 |
+
ffn_layer="mlp",
|
| 64 |
+
block_chunks=1,
|
| 65 |
+
num_register_tokens=0,
|
| 66 |
+
interpolate_antialias=False,
|
| 67 |
+
interpolate_offset=0.1,
|
| 68 |
+
):
|
| 69 |
+
"""
|
| 70 |
+
Args:
|
| 71 |
+
img_size (int, tuple): input image size
|
| 72 |
+
patch_size (int, tuple): patch size
|
| 73 |
+
in_chans (int): number of input channels
|
| 74 |
+
embed_dim (int): embedding dimension
|
| 75 |
+
depth (int): depth of transformer
|
| 76 |
+
num_heads (int): number of attention heads
|
| 77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 78 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 80 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 81 |
+
drop_path_rate (float): stochastic depth rate
|
| 82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 83 |
+
weight_init (str): weight init scheme
|
| 84 |
+
init_values (float): layer-scale init values
|
| 85 |
+
embed_layer (nn.Module): patch embedding layer
|
| 86 |
+
act_layer (nn.Module): MLP activation layer
|
| 87 |
+
block_fn (nn.Module): transformer block class
|
| 88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
| 91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
| 92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
| 93 |
+
"""
|
| 94 |
+
super().__init__()
|
| 95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 96 |
+
|
| 97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
| 98 |
+
self.num_tokens = 1
|
| 99 |
+
self.n_blocks = depth
|
| 100 |
+
self.num_heads = num_heads
|
| 101 |
+
self.patch_size = patch_size
|
| 102 |
+
self.num_register_tokens = num_register_tokens
|
| 103 |
+
self.interpolate_antialias = interpolate_antialias
|
| 104 |
+
self.interpolate_offset = interpolate_offset
|
| 105 |
+
|
| 106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
| 107 |
+
num_patches = self.patch_embed.num_patches
|
| 108 |
+
|
| 109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
| 111 |
+
assert num_register_tokens >= 0
|
| 112 |
+
self.register_tokens = (
|
| 113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
| 114 |
+
)
|
| 115 |
+
|
| 116 |
+
if drop_path_uniform is True:
|
| 117 |
+
dpr = [drop_path_rate] * depth
|
| 118 |
+
else:
|
| 119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
| 120 |
+
|
| 121 |
+
if ffn_layer == "mlp":
|
| 122 |
+
logger.info("using MLP layer as FFN")
|
| 123 |
+
ffn_layer = Mlp
|
| 124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 125 |
+
logger.info("using SwiGLU layer as FFN")
|
| 126 |
+
ffn_layer = SwiGLUFFNFused
|
| 127 |
+
elif ffn_layer == "identity":
|
| 128 |
+
logger.info("using Identity layer as FFN")
|
| 129 |
+
|
| 130 |
+
def f(*args, **kwargs):
|
| 131 |
+
return nn.Identity()
|
| 132 |
+
|
| 133 |
+
ffn_layer = f
|
| 134 |
+
else:
|
| 135 |
+
raise NotImplementedError
|
| 136 |
+
|
| 137 |
+
blocks_list = [
|
| 138 |
+
block_fn(
|
| 139 |
+
dim=embed_dim,
|
| 140 |
+
num_heads=num_heads,
|
| 141 |
+
mlp_ratio=mlp_ratio,
|
| 142 |
+
qkv_bias=qkv_bias,
|
| 143 |
+
proj_bias=proj_bias,
|
| 144 |
+
ffn_bias=ffn_bias,
|
| 145 |
+
drop_path=dpr[i],
|
| 146 |
+
norm_layer=norm_layer,
|
| 147 |
+
act_layer=act_layer,
|
| 148 |
+
ffn_layer=ffn_layer,
|
| 149 |
+
init_values=init_values,
|
| 150 |
+
)
|
| 151 |
+
for i in range(depth)
|
| 152 |
+
]
|
| 153 |
+
if block_chunks > 0:
|
| 154 |
+
self.chunked_blocks = True
|
| 155 |
+
chunked_blocks = []
|
| 156 |
+
chunksize = depth // block_chunks
|
| 157 |
+
for i in range(0, depth, chunksize):
|
| 158 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
| 160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 161 |
+
else:
|
| 162 |
+
self.chunked_blocks = False
|
| 163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 164 |
+
|
| 165 |
+
self.norm = norm_layer(embed_dim)
|
| 166 |
+
self.head = nn.Identity()
|
| 167 |
+
|
| 168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 169 |
+
|
| 170 |
+
self.init_weights()
|
| 171 |
+
|
| 172 |
+
@property
|
| 173 |
+
def onnx_compatible_mode(self):
|
| 174 |
+
return getattr(self, "_onnx_compatible_mode", False)
|
| 175 |
+
|
| 176 |
+
@onnx_compatible_mode.setter
|
| 177 |
+
def onnx_compatible_mode(self, value: bool):
|
| 178 |
+
self._onnx_compatible_mode = value
|
| 179 |
+
|
| 180 |
+
def init_weights(self):
|
| 181 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 182 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 183 |
+
if self.register_tokens is not None:
|
| 184 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 185 |
+
named_apply(init_weights_vit_timm, self)
|
| 186 |
+
|
| 187 |
+
def interpolate_pos_encoding(self, x, h, w):
    """Resize the patch position embeddings to match an input of size (h, w).

    Args:
        x: token tensor of shape (B, 1 + num_patches, dim) — cls token first.
        h, w: input image height and width in pixels.

    Returns:
        Position embeddings of shape (B, 1 + h0*w0, dim) in `x`'s dtype,
        where h0 = h // patch_size and w0 = w // patch_size.
    """
    previous_dtype = x.dtype
    npatch = x.shape[1] - 1
    batch_size = x.shape[0]
    N = self.pos_embed.shape[1] - 1
    # Fast path: no resizing needed (disabled in ONNX mode so the exported
    # graph has a single, data-independent code path).
    if not self.onnx_compatible_mode and npatch == N and w == h:
        return self.pos_embed
    # Interpolate in float32 regardless of the working dtype.
    pos_embed = self.pos_embed.float()
    class_pos_embed = pos_embed[:, 0, :]
    patch_pos_embed = pos_embed[:, 1:, :]
    dim = x.shape[-1]
    h0, w0 = h // self.patch_size, w // self.patch_size
    M = int(math.sqrt(N))  # Recover the number of patches in each dimension
    assert N == M * M  # pretrained grid is assumed square
    kwargs = {}
    if not self.onnx_compatible_mode and self.interpolate_offset > 0:
        # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
        # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
        sx = float(w0 + self.interpolate_offset) / M
        sy = float(h0 + self.interpolate_offset) / M
        kwargs["scale_factor"] = (sy, sx)
    else:
        # Simply specify an output size instead of a scale factor
        kwargs["size"] = (h0, w0)

    patch_pos_embed = nn.functional.interpolate(
        patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
        mode="bicubic",
        antialias=self.interpolate_antialias,
        **kwargs,
    )

    assert (h0, w0) == patch_pos_embed.shape[-2:]
    patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
    # Re-attach the (un-interpolated) class position embedding, broadcast over batch.
    return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)
|
| 222 |
+
|
| 223 |
+
def prepare_tokens_with_masks(self, x, masks=None):
    """Patchify the image, apply mask tokens, and prepend cls/register tokens.

    Args:
        x: image batch of shape (B, C, H, W).
        masks: optional boolean mask over patch positions; where True, the
            patch embedding is replaced by the learned mask token.

    Returns:
        Token tensor of shape (B, 1 + num_register_tokens + num_patches, dim).
    """
    B, nc, h, w = x.shape
    x = self.patch_embed(x)

    if masks is not None:
        # Substitute the (broadcast) mask token at masked patch positions.
        x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

    x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
    x = x + self.interpolate_pos_encoding(x, h, w)

    if self.register_tokens is not None:
        # Insert register tokens between the cls token and the patch tokens;
        # register tokens carry no position embedding.
        x = torch.cat(
            (
                x[:, :1],
                self.register_tokens.expand(x.shape[0], -1, -1),
                x[:, 1:],
            ),
            dim=1,
        )

    return x
|
| 244 |
+
|
| 245 |
+
def forward_features_list(self, x_list, masks_list):
    """Run the transformer on a list of (image batch, mask) pairs in one pass.

    Args:
        x_list: list of image batches (possibly of different resolutions).
        masks_list: list of per-batch masks (or None entries), parallel to `x_list`.

    Returns:
        A list of dicts (one per input batch) with the normalized cls /
        register / patch tokens, the pre-norm tokens, and the mask.
    """
    # Bug fix: the original unpacked three names (`x, masks, ar`) from the
    # 2-tuples produced by zip(x_list, masks_list), raising ValueError at runtime.
    x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
    for blk in self.blocks:
        x = blk(x)

    all_x = x
    output = []
    for x, masks in zip(all_x, masks_list):
        x_norm = self.norm(x)
        output.append(
            {
                "x_norm_clstoken": x_norm[:, 0],
                "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                "x_prenorm": x,
                "masks": masks,
            }
        )
    return output
|
| 264 |
+
|
| 265 |
+
def forward_features(self, x, masks=None):
    """Encode one image batch and return the normalized token groups.

    A list input dispatches to `forward_features_list` (multi-batch path).
    """
    if isinstance(x, list):
        return self.forward_features_list(x, masks)

    tokens = self.prepare_tokens_with_masks(x, masks)

    for block in self.blocks:
        tokens = block(tokens)

    normed = self.norm(tokens)
    reg_end = self.num_register_tokens + 1
    return {
        "x_norm_clstoken": normed[:, 0],
        "x_norm_regtokens": normed[:, 1:reg_end],
        "x_norm_patchtokens": normed[:, reg_end:],
        "x_prenorm": tokens,
        "masks": masks,
    }
|
| 282 |
+
|
| 283 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
    """Collect outputs of selected blocks.

    An int `n` selects the last n blocks; a list selects explicit indices.
    """
    hidden = self.prepare_tokens_with_masks(x)
    depth = len(self.blocks)
    wanted = range(depth - n, depth) if isinstance(n, int) else n
    collected = []
    for idx, block in enumerate(self.blocks):
        hidden = block(hidden)
        if idx in wanted:
            collected.append(hidden)
    assert len(collected) == len(wanted), f"only {len(collected)} / {len(wanted)} blocks found"
    return collected
|
| 294 |
+
|
| 295 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
    """Chunked-blocks variant of `_get_intermediate_layers_not_chunked`.

    `i` tracks the global block index across chunks; each chunk is entered at
    offset `i`, skipping the nn.Identity() padding that keeps global indices
    aligned (see the chunked-blocks construction in __init__).
    """
    x = self.prepare_tokens_with_masks(x)
    output, i, total_block_len = [], 0, len(self.blocks[-1])
    # If n is an int, take the n last blocks. If it's a list, take them
    blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
    for block_chunk in self.blocks:
        for blk in block_chunk[i:]:  # Passing the nn.Identity()
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
            i += 1
    assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
    return output
|
| 308 |
+
|
| 309 |
+
def get_intermediate_layers(
    self,
    x: torch.Tensor,
    n: Union[int, Sequence] = 1,  # Layers or n last layers to take
    reshape: bool = False,
    return_class_token: bool = False,
    norm=True,
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
    """Return patch-token outputs of selected transformer blocks.

    Args:
        x: image batch of shape (B, C, H, W).
        n: number of trailing blocks, or an explicit list of block indices.
        reshape: if True, reshape patch tokens to (B, dim, H/ps, W/ps) feature maps.
        return_class_token: if True, pair each output with its class token.
        norm: if True, apply the final LayerNorm to each collected output.

    Returns:
        Tuple of tensors, or tuple of (patch_tokens, class_token) pairs.
    """
    if self.chunked_blocks:
        outputs = self._get_intermediate_layers_chunked(x, n)
    else:
        outputs = self._get_intermediate_layers_not_chunked(x, n)
    if norm:
        outputs = [self.norm(out) for out in outputs]
    class_tokens = [out[:, 0] for out in outputs]
    # Drop cls + register tokens, keeping only patch tokens.
    outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
    if reshape:
        B, _, w, h = x.shape
        outputs = [
            out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
            for out in outputs
        ]
    if return_class_token:
        return tuple(zip(outputs, class_tokens))
    return tuple(outputs)
|
| 334 |
+
|
| 335 |
+
def forward(self, *args, is_training=False, **kwargs):
    """Full forward pass.

    Returns the raw feature dict when `is_training`, otherwise the head
    applied to the normalized class token.
    """
    features = self.forward_features(*args, **kwargs)
    if not is_training:
        return self.head(features["x_norm_clstoken"])
    return features
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 344 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 345 |
+
if isinstance(module, nn.Linear):
|
| 346 |
+
trunc_normal_(module.weight, std=0.02)
|
| 347 |
+
if module.bias is not None:
|
| 348 |
+
nn.init.zeros_(module.bias)
|
| 349 |
+
|
| 350 |
+
|
| 351 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S: embed_dim 384, depth 12, 6 heads, memory-efficient attention."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B: embed_dim 768, depth 12, 12 heads, memory-efficient attention."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L: embed_dim 1024, depth 24, 16 heads, memory-efficient attention."""
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    model = DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
    return model
|
moge/model/dinov2/utils/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
moge/model/dinov2/utils/cluster.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from enum import Enum
|
| 7 |
+
import os
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, Optional
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ClusterType(Enum):
    # Known compute clusters; values double as string identifiers.
    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _guess_cluster_type() -> ClusterType:
    """Infer the compute cluster from kernel release / hostname; default FAIR."""
    info = os.uname()
    if info.sysname == "Linux":
        # AWS instances ship kernel releases like "5.4.0-1051-aws".
        if info.release.endswith("-aws"):
            return ClusterType.AWS
        # RSC hosts use standard kernels, but hostnames begin with "rsc".
        if info.nodename.startswith("rsc"):
            return ClusterType.RSC
    return ClusterType.FAIR
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return the given cluster type, guessing from the host when None."""
    return _guess_cluster_type() if cluster_type is None else cluster_type
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for the (detected) cluster, or None."""
    resolved = get_cluster_type(cluster_type)
    if resolved is None:
        return None

    dirnames = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / dirnames[resolved]
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the per-user checkpoint directory ($USER under the cluster root)."""
    root = get_checkpoint_path(cluster_type)
    if root is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return root / username
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the default SLURM partition name for the (detected) cluster, or None."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    # Partition names are site-specific conventions.
    SLURM_PARTITIONS = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return SLURM_PARTITIONS[cluster_type]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build a SLURM executor parameter dict for nodes x num_gpus_per_node GPUs.

    Any keyword arguments override the computed defaults.
    """
    # create default parameters
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # apply cluster-specific adjustments
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type == ClusterType.AWS:
        params["cpus_per_task"] = 12
        # NOTE(review): mem_gb is dropped entirely on AWS rather than set —
        # presumably the AWS scheduler rejects the all-memory request; confirm.
        del params["mem_gb"]
    elif cluster_type == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # set additional parameters / apply overrides
    params.update(kwargs)
    return params
|
moge/model/dinov2/utils/config.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import logging
|
| 8 |
+
import os
|
| 9 |
+
|
| 10 |
+
from omegaconf import OmegaConf
|
| 11 |
+
|
| 12 |
+
import dinov2.distributed as distributed
|
| 13 |
+
from dinov2.logging import setup_logging
|
| 14 |
+
from dinov2.utils import utils
|
| 15 |
+
from dinov2.configs import dinov2_default_config
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
logger = logging.getLogger("dinov2")
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Scale cfg.optim.lr from base_lr according to the global batch size.

    Only the "sqrt_wrt_1024" rule is implemented:
    lr = base_lr * sqrt(global_batch_size / 1024).
    """
    if cfg.optim.scaling_rule == "sqrt_wrt_1024":
        base_lr = cfg.optim.base_lr
        cfg.optim.lr = base_lr
        cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
        logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    else:
        raise NotImplementedError
    return cfg
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the config and save it as YAML under output_dir; return the saved path."""
    logger.info(OmegaConf.to_yaml(cfg))
    saved_cfg_path = os.path.join(output_dir, name)
    with open(saved_cfg_path, "w") as f:
        OmegaConf.save(config=cfg, f=f)
    return saved_cfg_path
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_cfg_from_args(args):
    """Merge default config, the config file, and CLI overrides into one config.

    Also injects the absolute output_dir into the CLI override list.
    """
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    default_cfg = OmegaConf.create(dinov2_default_config)
    cfg = OmegaConf.load(args.config_file)
    # Precedence (lowest to highest): defaults < config file < CLI opts.
    cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
    return cfg
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def default_setup(args):
    """Initialize distributed mode, logging, and per-rank RNG seeding."""
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    # Offset by rank so every process seeds its RNGs differently.
    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n  {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def setup(args):
    """
    Create configs and perform basic setups.

    Loads and merges the config, creates the output directory, initializes
    distributed/logging/seeds, applies lr scaling rules, and saves the
    resolved config to disk.
    """
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
|
moge/model/dinov2/utils/dtype.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
from typing import Dict, Union
|
| 8 |
+
|
| 9 |
+
import numpy as np
|
| 10 |
+
import torch
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
# Any dtype specification accepted by `as_torch_dtype`.
TypeSpec = Union[str, np.dtype, torch.dtype]


# Mapping from numpy dtypes to their torch equivalents.
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Convert a dtype spec (string, numpy dtype, or torch dtype) to a torch.dtype.

    Raises:
        KeyError: if the numpy dtype has no torch equivalent.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # Fix: the assertion message previously misspelled "numpy" as "nunpy".
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
|
moge/model/dinov2/utils/param_groups.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
from collections import defaultdict
|
| 7 |
+
import logging
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger("dinov2")
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.
    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
    Returns:
        lr decay rate for the given parameter.
    """
    top = num_layers + 1
    layer_id = top
    if force_is_backbone or name.startswith("backbone"):
        dotted_embed_keys = (".pos_embed", ".patch_embed", ".mask_token", ".cls_token", ".register_tokens")
        bare_embed_keys = ("pos_embed", "patch_embed", "mask_token", "cls_token", "register_tokens")
        if any(key in name for key in dotted_embed_keys):
            # Embedding-level parameters decay as layer 0 (strongest decay).
            layer_id = 0
        elif force_is_backbone and any(key in name for key in bare_embed_keys):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1

    return lr_decay_rate ** (top - layer_id)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """Build one parameter group per trainable parameter with layer-wise lr decay.

    Biases and norm/gamma parameters get wd_multiplier 0; patch-embed
    parameters get an extra lr multiplier; "last_layer" parameters are flagged.
    """
    chunked_blocks = False
    # Determine the block count from whichever attribute the model exposes.
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0
    all_param_groups = []

    for name, param in model.named_parameters():
        # Normalize FSDP-wrapped parameter names before pattern matching.
        name = name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue
        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}

        if "last_layer" in name:
            d.update({"is_last_layer": True})

        # No weight decay for biases and normalization/scale parameters.
        if name.endswith(".bias") or "norm" in name or "gamma" in name:
            d.update({"wd_multiplier": 0.0})

        if "patch_embed" in name:
            d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})

        all_param_groups.append(d)
        logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")

    return all_param_groups
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """Merge parameter groups that share identical settings for `keys`.

    Returns a view of fused groups, each holding a list of the merged params.
    """
    fused = defaultdict(lambda: {"params": []})
    for group in all_params_groups:
        # Groups with identical key/value settings share one signature.
        signature = "".join(f"{key}{group[key]}_" for key in keys)
        entry = fused[signature]
        for key in keys:
            entry[key] = group[key]
        entry["params"].append(group["params"])

    return fused.values()
|
moge/model/dinov2/utils/utils.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
#
|
| 3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
| 4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
| 5 |
+
|
| 6 |
+
import logging
|
| 7 |
+
import os
|
| 8 |
+
import random
|
| 9 |
+
import subprocess
|
| 10 |
+
from urllib.parse import urlparse
|
| 11 |
+
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch
|
| 14 |
+
from torch import nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """Load a checkpoint (local path or URL) into `model` non-strictly.

    Args:
        model: the nn.Module to load into.
        pretrained_weights: filesystem path or URL of the checkpoint.
        checkpoint_key: optional top-level key selecting a sub-dict of the
            checkpoint (e.g. "teacher"); ignored if absent.
    """
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False: tolerate missing/unexpected keys; the mismatch is logged.
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def fix_random_seeds(seed=31):
    """
    Fix random seeds.
    """
    # Seed every RNG the training stack draws from: torch (CPU + all GPUs),
    # numpy, and Python's stdlib generator.
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_sha():
    """Return a one-line summary of the git state (sha, dirty/clean, branch).

    Falls back to "N/A"/"clean" placeholders if git commands fail (e.g. not
    inside a repository).
    """
    # Run git relative to this file's directory, not the process cwd.
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        # Best effort only: keep the placeholder values on any failure.
        pass
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
class CosineScheduler(object):
    """Precomputed per-iteration schedule: freeze -> linear warmup -> cosine decay.

    Indexing past total_iters returns final_value.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        # Remaining iterations follow a cosine from base_value down to final_value.
        cosine_iters = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine_part = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * cosine_iters / len(cosine_iters)))

        self.schedule = np.concatenate((
            np.zeros(freeze_iters),
            np.linspace(start_warmup_value, base_value, warmup_iters),
            cosine_part,
        ))

        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        return self.final_value if it >= self.total_iters else self.schedule[it]
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def has_batchnorms(model):
    """Return True if `model` contains any batch-norm (or sync batch-norm) submodule."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for _, module in model.named_modules():
        if isinstance(module, bn_types):
            return True
    return False
|
moge/model/modules.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
from numbers import Number
|
| 8 |
+
import importlib
|
| 9 |
+
import itertools
|
| 10 |
+
import functools
|
| 11 |
+
import sys
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
from torch import Tensor
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
|
| 18 |
+
from .dinov2.models.vision_transformer import DinoVisionTransformer
|
| 19 |
+
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
|
| 20 |
+
from ..utils.geometry_torch import normalized_view_plane_uv
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
class ResidualConvBlock(nn.Module):
    """Residual conv block: (norm -> activation -> conv) x 2 plus a skip connection.

    The skip is an identity when in/out channel counts match, otherwise a 1x1 conv.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        kernel_size: int = 3,
        padding_mode: str = 'replicate',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
        in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
        hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
    ):
        super(ResidualConvBlock, self).__init__()
        out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = in_channels if hidden_channels is None else hidden_channels

        activation_factories = {
            'relu': nn.ReLU,
            'leaky_relu': functools.partial(nn.LeakyReLU, negative_slope=0.2),
            'silu': nn.SiLU,
            'elu': nn.ELU,
        }
        if activation not in activation_factories:
            raise ValueError(f'Unsupported activation function: {activation}')
        activation_cls = activation_factories[activation]

        def _make_norm(kind: str, channels: int) -> nn.Module:
            # 'layer_norm' is realized as a single-group GroupNorm over channels;
            # any unrecognized kind falls back to Identity.
            if kind == 'group_norm':
                return nn.GroupNorm(channels // 32, channels)
            if kind == 'layer_norm':
                return nn.GroupNorm(1, channels)
            if kind == 'instance_norm':
                return nn.InstanceNorm2d(channels)
            return nn.Identity()

        conv_padding = kernel_size // 2
        self.layers = nn.Sequential(
            _make_norm(in_norm, in_channels),
            activation_cls(),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=kernel_size, padding=conv_padding, padding_mode=padding_mode),
            _make_norm(hidden_norm, hidden_channels),
            activation_cls(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=kernel_size, padding=conv_padding, padding_mode=padding_mode),
        )

        self.skip_connection = nn.Identity() if in_channels == out_channels else nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)

    def forward(self, x):
        return self.layers(x) + self.skip_connection(x)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
class DINOv2Encoder(nn.Module):
    "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
    backbone: DinoVisionTransformer
    image_mean: torch.Tensor
    image_std: torch.Tensor
    dim_features: int

    def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs):
        super(DINOv2Encoder, self).__init__()

        self.intermediate_layers = intermediate_layers

        # Resolve the hub constructor for the requested backbone and build it without weights.
        self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone)
        self.backbone_name = backbone
        self.backbone = self.hub_loader(pretrained=False)

        # Feature width is read off the first attention block's QKV projection.
        self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
        if isinstance(intermediate_layers, int):
            self.num_features = intermediate_layers
        else:
            self.num_features = len(intermediate_layers)

        # One 1x1 projection per tapped layer, mapping backbone features to dim_out.
        self.output_projections = nn.ModuleList(
            nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
            for _ in range(self.num_features)
        )

        # ImageNet normalization constants; stored as buffers so they follow .to()/.cuda().
        self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @property
    def onnx_compatible_mode(self):
        # Defaults to False until the setter is used.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        # Propagate the flag to the backbone, which adapts its ops for ONNX export.
        self._onnx_compatible_mode = value
        self.backbone.onnx_compatible_mode = value

    def init_weights(self):
        """Load pretrained backbone weights from a freshly constructed hub model."""
        pretrained_state = self.hub_loader(pretrained=True).state_dict()
        self.backbone.load_state_dict(pretrained_state)

    def enable_gradient_checkpointing(self):
        """Wrap every transformer block so its activations are recomputed in backward."""
        for block in self.backbone.blocks:
            wrap_module_with_gradient_checkpointing(block)

    def enable_pytorch_native_sdpa(self):
        """Replace each block's attention with PyTorch scaled_dot_product_attention."""
        for block in self.backbone.blocks:
            wrap_dinov2_attention_with_sdpa(block.attn)

    def forward(self, image: torch.Tensor, token_rows: Union[int, torch.LongTensor], token_cols: Union[int, torch.LongTensor], return_class_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encode an RGB image in [0, 1] into a (B, dim_out, token_rows, token_cols) feature map.

        If return_class_token is True, also returns the class token of the last tapped layer.
        """
        # Resize to a 14px-per-token grid (DINOv2 patch size) and apply ImageNet normalization.
        image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
        image_14 = (image_14 - self.image_mean) / self.image_std

        # Tap the requested intermediate layers; each entry is (patch_tokens, class_token).
        features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True)

        # Project each tapped layer onto the token grid at dim_out channels, then sum them.
        projected = [
            projection(tokens.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
            for projection, (tokens, _) in zip(self.output_projections, features)
        ]
        x = torch.stack(projected, dim=1).sum(dim=1)

        if return_class_token:
            return x, features[-1][1]
        return x
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
class Resampler(nn.Sequential):
    """Resolution-changing layer: upsamples or downsamples by `scale_factor` using the chosen scheme.

    Upsampling types: 'pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose'.
    Downsampling types: 'pixel_unshuffle', 'avg_pool', 'max_pool'.
    Raises ValueError for any other type.
    """
    def __init__(self,
        in_channels: int,
        out_channels: int,
        type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
        scale_factor: int = 2,
    ):
        group = scale_factor ** 2
        if type_ == 'pixel_shuffle':
            super().__init__(
                nn.Conv2d(in_channels, out_channels * group, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.PixelShuffle(scale_factor),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
            # Make every shuffle group start identical so the layer begins as nearest-like upsampling.
            for offset in range(1, group):
                self[0].weight.data[offset::group] = self[0].weight.data[0::group]
                self[0].bias.data[offset::group] = self[0].bias.data[0::group]
        elif type_ in ('nearest', 'bilinear'):
            align = False if type_ == 'bilinear' else None
            super().__init__(
                nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=align),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
        elif type_ == 'conv_transpose':
            super().__init__(
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
            # Broadcast the top-left kernel tap across the whole kernel (nearest-like initialization).
            self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
        elif type_ == 'pixel_unshuffle':
            super().__init__(
                nn.PixelUnshuffle(scale_factor),
                nn.Conv2d(in_channels * group, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
        elif type_ == 'avg_pool':
            super().__init__(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        elif type_ == 'max_pool':
            super().__init__(
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        else:
            raise ValueError(f'Unsupported resampler type: {type_}')
|
| 188 |
+
|
| 189 |
+
class MLP(nn.Sequential):
    """Plain multi-layer perceptron: Linear+ReLU for each hidden transition, plus a final Linear."""
    def __init__(self, dims: Sequence[int]):
        layers = []
        for dim_src, dim_dst in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(dim_src, dim_dst))
            layers.append(nn.ReLU(inplace=True))
        layers.append(nn.Linear(dims[-2], dims[-1]))
        super().__init__(*layers)
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
class ConvStack(nn.Module):
    """Multi-scale stack of residual convolution blocks connected by resamplers.

    Stage i applies: optional 1x1 input projection (merging an external feature
    map into the running state) -> a chain of ResidualConvBlock -> optional 1x1
    output projection. Between consecutive stages, a Resampler with scale
    factor 2 changes resolution and channel width.

    Args:
        dim_in: per-stage input channels; None disables that stage's input projection.
        dim_res_blocks: per-stage working channel widths.
        dim_out: per-stage output channels; None disables that stage's output projection.
        resamplers: resampler type - a single shared type or one type per transition.
        dim_times_res_block_hidden: hidden-width multiplier inside each residual block.
        num_res_blocks: residual blocks per stage (int, or a per-stage list).
        res_block_in_norm: input normalization of each residual block.
        res_block_hidden_norm: hidden normalization of each residual block.
        activation: activation function used by the residual blocks.
    """
    def __init__(self,
        dim_in: List[Optional[int]],
        dim_res_blocks: List[int],
        dim_out: List[Optional[int]],
        resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
        res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
    ):
        super().__init__()

        def _per_stage(value):
            # Broadcast a scalar setting across stages. FIX: a plain `str` is itself a
            # Sequence, so the previous `isinstance(value, Sequence)` check iterated a
            # scalar resampler type character by character; treat strings as scalars.
            if isinstance(value, str) or not isinstance(value, Sequence):
                return itertools.repeat(value)
            return value

        self.input_blocks = nn.ModuleList([
            nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
            for dim_in_, dim_res_block_ in zip(_per_stage(dim_in), dim_res_blocks)
        ])
        self.resamplers = nn.ModuleList([
            Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
            for dim_prev, dim_succ, resampler in zip(
                dim_res_blocks[:-1],
                dim_res_blocks[1:],
                _per_stage(resamplers)
            )
        ])
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                *(
                    ResidualConvBlock(
                        dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
                        activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
                    ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
                )
            ) for i, dim_res_block_ in enumerate(dim_res_blocks)
        ])
        self.output_blocks = nn.ModuleList([
            nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
            for dim_out_, dim_res_block_ in zip(_per_stage(dim_out), dim_res_blocks)
        ])

    def enable_gradient_checkpointing(self):
        """Wrap resamplers and residual blocks so activations are recomputed in backward."""
        for i in range(len(self.resamplers)):
            self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
        for i in range(len(self.res_blocks)):
            for j in range(len(self.res_blocks[i])):
                self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])

    def forward(self, in_features: List[torch.Tensor]):
        """Run all stages; returns the (optionally projected) feature map of each stage."""
        out_features = []
        for i in range(len(self.res_blocks)):
            feature = self.input_blocks[i](in_features[i])
            if i == 0:
                x = feature
            elif feature is not None:
                # A None placeholder passes through nn.Identity unchanged; skip merging then.
                x = x + feature
            x = self.res_blocks[i](x)
            out_features.append(self.output_blocks[i](x))
            if i < len(self.res_blocks) - 1:
                x = self.resamplers[i](x)
        return out_features
|
moge/model/transforms.py
ADDED
|
@@ -0,0 +1,1344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
from numbers import Number
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
import inspect
|
| 13 |
+
from functools import wraps
|
| 14 |
+
|
| 15 |
+
import warnings
|
| 16 |
+
|
| 17 |
+
def suppress_traceback(fn):
    """Decorator that trims this wrapper's own two frames from exception tracebacks.

    Exceptions raised by `fn` propagate unchanged except that their traceback is
    shortened, so error reports point at the caller's code rather than here.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            # FIX: the original `tb_next.tb_next` raised AttributeError when the
            # traceback had fewer than two frames; walk defensively instead.
            tb = e.__traceback__
            for _ in range(2):
                if tb is not None:
                    tb = tb.tb_next
            e.__traceback__ = tb
            raise
    return wrapper
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class no_warnings:
    """Suppress warnings; usable both as a decorator and as a context manager.

    Args:
        action: warnings filter action applied on entry (default 'ignore').
        **kwargs: extra keyword arguments forwarded to warnings.simplefilter.
    """
    def __init__(self, action: str = 'ignore', **kwargs):
        self.action = action
        self.filter_kwargs = kwargs

    def __call__(self, fn):
        @wraps(fn)
        def decorated(*args, **kwargs):
            # Reuse the context-manager protocol so decorator and `with` behave identically.
            with type(self)(self.action, **self.filter_kwargs):
                return fn(*args, **kwargs)
        return decorated

    def __enter__(self):
        self.warnings_manager = warnings.catch_warnings()
        self.warnings_manager.__enter__()
        warnings.simplefilter(self.action, **self.filter_kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_device(args, kwargs):
    """Return the common device of all tensor arguments, or None if there are none.

    Raises:
        ValueError: if tensor arguments live on different devices.
    """
    device = None
    for value in list(args) + list(kwargs.values()):
        if not isinstance(value, torch.Tensor):
            continue
        if device is None:
            device = value.device
        elif device != value.device:
            raise ValueError("All tensors must be on the same device.")
    return device
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def get_args_order(func, args, kwargs):
    """
    Map supplied arguments to their positional indices in `func`'s signature.

    Returns:
        args_order: for each positional argument, the index of the parameter it
            binds to (keyword-bound parameters are skipped over).
        kwargs_order: {keyword name: parameter index} for recognized keywords.
    """
    positional = inspect.getfullargspec(func).args
    index_of = {name: i for i, name in enumerate(positional)}
    kwargs_order = {}
    for name in kwargs:
        if name in index_of:
            kwargs_order[name] = index_of[name]
            # Keyword-bound parameters no longer consume positional slots.
            positional.remove(name)
    n_bindable = min(len(args), len(positional))
    args_order = [index_of[positional[i]] for i in range(n_bindable)]
    return args_order, kwargs_order
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def broadcast_args(args, kwargs, args_dim, kwargs_dim):
    """Broadcast the leading (spatial) dimensions of all tensor arguments together.

    Each entry of args_dim/kwargs_dim gives the number of trailing "core" dims of
    the corresponding argument (None = leave untouched). The remaining leading
    dims of all tensors are broadcast to a common `spatial` shape, and each
    tensor is expanded to [*spatial, *core_shape].

    Returns:
        (args, kwargs, spatial): arguments with broadcast tensors, and the
        common leading shape as a list.

    Raises:
        ValueError: if the leading shapes are not broadcast-compatible.
    """
    spatial = []
    for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
        if isinstance(arg, torch.Tensor) and arg_dim is not None:
            arg_spatial = arg.shape[:arg.ndim-arg_dim]
            if len(arg_spatial) > len(spatial):
                spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
            # FIX: align dimensions from the right. The original loop used index -j for
            # j starting at 0, so it compared index 0 (a misaligned leading dim when
            # `spatial` is longer) instead of index -len(arg_spatial).
            for j in range(1, len(arg_spatial) + 1):
                if spatial[-j] < arg_spatial[-j]:
                    if spatial[-j] == 1:
                        spatial[-j] = arg_spatial[-j]
                    else:
                        raise ValueError("Cannot broadcast arguments.")
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor) and args_dim[i] is not None:
            args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]])
    for key, arg in kwargs.items():
        if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
            kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]])
    return args, kwargs, spatial
|
| 99 |
+
|
| 100 |
+
@suppress_traceback
def batched(*dims):
    """
    Decorator that allows a function to be called with batched arguments.

    `dims[i]` is the number of trailing "core" dimensions of the i-th positional
    parameter (None means that argument is passed through untouched). All tensor
    arguments' leading dimensions are broadcast together, flattened into a single
    batch dimension before the wrapped function runs, and restored on every
    returned tensor afterwards. Numbers, lists and tuples are promoted to tensors
    on the common device (or the `device` keyword, defaulting to CPU).
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, device=torch.device('cpu'), **kwargs):
            args = list(args)
            # Map each supplied argument to its declared core-dimension count.
            args_order, kwargs_order = get_args_order(func, args, kwargs)
            args_dim = [dims[i] for i in args_order]
            kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
            # Promote plain numbers/sequences to tensors on the inferred device.
            device = get_device(args, kwargs) or device
            for i, arg in enumerate(args):
                if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
                    args[i] = torch.tensor(arg, device=device)
            for key, arg in kwargs.items():
                if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
                    kwargs[key] = torch.tensor(arg, device=device)
            # Broadcast leading (spatial) dims, then flatten them into one batch dim.
            args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
            for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
                if isinstance(arg, torch.Tensor) and arg_dim is not None:
                    args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
            for key, arg in kwargs.items():
                if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
                    kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
            # Call the wrapped function on the flattened batch.
            results = func(*args, **kwargs)
            type_results = type(results)
            results = list(results) if isinstance(results, (tuple, list)) else [results]
            # Unflatten the batch dim of every result back to the broadcast shape.
            for i, result in enumerate(results):
                results[i] = result.reshape([*spatial, *result.shape[1:]])
            # Preserve the container type of the wrapped function's return value.
            if type_results == tuple:
                results = tuple(results)
            elif type_results == list:
                results = list(results)
            else:
                results = results[0]
            return results
        return wrapper
    return decorator
|
| 145 |
+
|
| 146 |
+
# Public API of this module: camera/projection matrix constructors, FOV/intrinsics
# conversions, (un)projection helpers, and rotation/transform utilities.
__all__ = [
    'perspective',
    'perspective_from_fov',
    'perspective_from_fov_xy',
    'intrinsics_from_focal_center',
    'intrinsics_from_fov',
    'intrinsics_from_fov_xy',
    'focal_to_fov',
    'fov_to_focal',
    'intrinsics_to_fov',
    'view_look_at',
    'extrinsics_look_at',
    'perspective_to_intrinsics',
    'intrinsics_to_perspective',
    'extrinsics_to_view',
    'view_to_extrinsics',
    'normalize_intrinsics',
    'crop_intrinsics',
    'pixel_to_uv',
    'pixel_to_ndc',
    'uv_to_pixel',
    'project_depth',
    'depth_buffer_to_linear',
    'project_gl',
    'project_cv',
    'unproject_gl',
    'unproject_cv',
    'skew_symmetric',
    'rotation_matrix_from_vectors',
    'euler_axis_angle_rotation',
    'euler_angles_to_matrix',
    'matrix_to_euler_angles',
    'matrix_to_quaternion',
    'quaternion_to_matrix',
    'matrix_to_axis_angle',
    'axis_angle_to_matrix',
    'axis_angle_to_quaternion',
    'quaternion_to_axis_angle',
    'slerp',
    'interpolate_extrinsics',
    'interpolate_view',
    'extrinsics_to_essential',
    'to4x4',
    'rotation_matrix_2d',
    'rotate_2d',
    'translate_2d',
    'scale_2d',
    'apply_2d',
]
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@batched(0,0,0,0)
def perspective(
    fov_y: Union[float, torch.Tensor],
    aspect: Union[float, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenGL perspective matrix

    Args:
        fov_y (float | torch.Tensor): field of view in y axis
        aspect (float | torch.Tensor): aspect ratio
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    batch = fov_y.shape[0]
    half_tan = torch.tan(fov_y / 2)
    matrix = torch.zeros((batch, 4, 4), dtype=fov_y.dtype, device=fov_y.device)
    matrix[:, 0, 0] = 1. / (half_tan * aspect)
    matrix[:, 1, 1] = 1. / half_tan
    matrix[:, 2, 2] = (near + far) / (near - far)
    matrix[:, 2, 3] = 2. * near * far / (near - far)
    matrix[:, 3, 2] = -1.
    return matrix
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
def perspective_from_fov(
    fov: Union[float, torch.Tensor],
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenGL perspective matrix from field of view in largest dimension

    Args:
        fov (float | torch.Tensor): field of view in largest dimension
        width (int | torch.Tensor): image width
        height (int | torch.Tensor): image height
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # Convert the largest-dimension FOV into a vertical FOV, then delegate.
    larger_side = torch.maximum(width, height)
    fov_y = 2 * torch.atan(torch.tan(fov / 2) * height / larger_side)
    return perspective(fov_y, width / height, near, far)
|
| 249 |
+
|
| 250 |
+
|
| 251 |
+
def perspective_from_fov_xy(
    fov_x: Union[float, torch.Tensor],
    fov_y: Union[float, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenGL perspective matrix from field of view in x and y axis

    Args:
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # The aspect ratio is implied by the two half-angle tangents.
    tan_x = torch.tan(fov_x / 2)
    tan_y = torch.tan(fov_y / 2)
    return perspective(fov_y, tan_x / tan_y, near, far)
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
@batched(0,0,0,0)
def intrinsics_from_focal_center(
    fx: Union[float, torch.Tensor],
    fy: Union[float, torch.Tensor],
    cx: Union[float, torch.Tensor],
    cy: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenCV intrinsics matrix

    Args:
        fx (float | torch.Tensor): focal length in x axis
        fy (float | torch.Tensor): focal length in y axis
        cx (float | torch.Tensor): principal point in x axis
        cy (float | torch.Tensor): principal point in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    # FIX: removed a dead `torch.zeros((N, 3, 3))` allocation that was immediately
    # overwritten, and corrected docstring parameter names (focal_x/focal_y -> fx/fy).
    N = fx.shape[0]
    zeros = torch.zeros(N, dtype=fx.dtype, device=fx.device)
    ones = torch.ones(N, dtype=fx.dtype, device=fx.device)
    # Row-major layout: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]]
    ret = torch.stack([fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1).unflatten(-1, (3, 3))
    return ret
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
@batched(0, 0, 0, 0, 0, 0)
def intrinsics_from_fov(
    fov_max: Union[float, torch.Tensor] = None,
    fov_min: Union[float, torch.Tensor] = None,
    fov_x: Union[float, torch.Tensor] = None,
    fov_y: Union[float, torch.Tensor] = None,
    width: Union[int, torch.Tensor] = None,
    height: Union[int, torch.Tensor] = None,
) -> torch.Tensor:
    """
    Get normalized OpenCV intrinsics matrix from given field of view.
    You can provide either fov_max, fov_min, fov_x or fov_y

    Args:
        width (int | torch.Tensor): image width
        height (int | torch.Tensor): image height
        fov_max (float | torch.Tensor): field of view in largest dimension
        fov_min (float | torch.Tensor): field of view in smallest dimension
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix

    Raises:
        ValueError: if none of fov_max, fov_min, fov_x, fov_y is provided.
    """
    if fov_max is not None:
        fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2))
        fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2))
    elif fov_min is not None:
        fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2))
        fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2))
    elif fov_x is not None and fov_y is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        fy = 1 / (2 * torch.tan(fov_y / 2))
    elif fov_x is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        # Square pixels assumed: derive fy from the aspect ratio.
        fy = fx * width / height
    elif fov_y is not None:
        fy = 1 / (2 * torch.tan(fov_y / 2))
        fx = fy * height / width
    else:
        # FIX: previously fell through with fx/fy unbound, raising an opaque NameError.
        raise ValueError("One of fov_max, fov_min, fov_x or fov_y must be provided.")
    # Principal point at the image center (normalized coordinates).
    cx = 0.5
    cy = 0.5
    ret = intrinsics_from_focal_center(fx, fy, cx, cy)
    return ret
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
def intrinsics_from_fov_xy(
    fov_x: Union[float, torch.Tensor],
    fov_y: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Build a normalized OpenCV intrinsics matrix from the horizontal and vertical field of view.

    Args:
        fov_x (float | torch.Tensor): field of view in x axis (radians)
        fov_y (float | torch.Tensor): field of view in y axis (radians)

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    # Normalized focal lengths; principal point is assumed centered at (0.5, 0.5)
    fx = 0.5 / torch.tan(fov_x / 2)
    fy = 0.5 / torch.tan(fov_y / 2)
    return intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
|
| 363 |
+
|
| 364 |
+
|
| 365 |
+
def focal_to_fov(focal: torch.Tensor):
    "Convert a normalized focal length to the corresponding full field of view (radians)."
    return torch.atan(0.5 / focal) * 2
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def fov_to_focal(fov: torch.Tensor):
    "Convert a full field of view (radians) to the corresponding normalized focal length."
    return 0.5 / torch.tan(0.5 * fov)
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
def intrinsics_to_fov(intrinsics: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    "NOTE: approximate FOV by assuming centered principal point"
    # fov = 2 * atan(0.5 / focal), reading the focal lengths off the diagonal
    fov_x = 2 * torch.atan(0.5 / intrinsics[..., 0, 0])
    fov_y = 2 * torch.atan(0.5 / intrinsics[..., 1, 1])
    return fov_x, fov_y
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
@batched(1,1,1)
def view_look_at(
    eye: torch.Tensor,
    look_at: torch.Tensor,
    up: torch.Tensor
) -> torch.Tensor:
    """
    Get OpenGL view matrix looking at something

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily orthogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], view matrix
    """
    batch = eye.shape[0]
    # OpenGL camera looks down -z, so z points from the target back towards the eye
    z_axis = eye - look_at
    x_axis = torch.cross(up, z_axis, dim=-1)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    rot = torch.stack([x_axis, y_axis, z_axis], dim=-2)
    trans = -torch.matmul(rot, eye[..., None])
    view = torch.zeros((batch, 4, 4), dtype=eye.dtype, device=eye.device)
    view[:, :3, :3] = rot
    view[:, :3, 3] = trans[:, :, 0]
    view[:, 3, 3] = 1.
    return view
|
| 412 |
+
|
| 413 |
+
|
| 414 |
+
@batched(1, 1, 1)
def extrinsics_look_at(
    eye: torch.Tensor,
    look_at: torch.Tensor,
    up: torch.Tensor
) -> torch.Tensor:
    """
    Get OpenCV extrinsics matrix looking at something

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily orthogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], extrinsics matrix
    """
    batch = eye.shape[0]
    # OpenCV camera looks down +z with y pointing down, hence the negated up
    z_axis = look_at - eye
    x_axis = torch.cross(-up, z_axis, dim=-1)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    rot = torch.stack([x_axis, y_axis, z_axis], dim=-2)
    trans = -torch.matmul(rot, eye[..., None])
    extr = torch.zeros((batch, 4, 4), dtype=eye.dtype, device=eye.device)
    extr[:, :3, :3] = rot
    extr[:, :3, 3] = trans[:, :, 0]
    extr[:, 3, 3] = 1.
    return extr
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
@batched(2)
def perspective_to_intrinsics(
    perspective: torch.Tensor
) -> torch.Tensor:
    """
    OpenGL perspective matrix to OpenCV intrinsics

    Args:
        perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix

    Returns:
        (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics
    """
    assert torch.allclose(perspective[:, [0, 1, 3], 3], 0), "The perspective matrix is not a projection matrix"
    # Map NDC [-1, 1] to uv [0, 1] (y flipped), and flip the camera y/z axes (GL -> CV)
    uv_from_ndc = torch.tensor([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype, device=perspective.device)
    axis_flip = torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device))
    intr = uv_from_ndc @ perspective[:, [0, 1, 3], :3] @ axis_flip
    # Normalize so that the bottom-right entry is 1
    return intr / intr[:, 2, 2, None, None]
|
| 466 |
+
|
| 467 |
+
|
| 468 |
+
@batched(2,0,0)
def intrinsics_to_perspective(
    intrinsics: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    OpenCV intrinsics to OpenGL perspective matrix

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip
    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL perspective matrix
    """
    batch = intrinsics.shape[0]
    fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1]
    cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2]
    perspective = torch.zeros((batch, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    perspective[:, 0, 0] = 2 * fx
    perspective[:, 1, 1] = 2 * fy
    # Principal point offset in NDC
    perspective[:, 0, 2] = 1 - 2 * cx
    perspective[:, 1, 2] = 2 * cy - 1
    # Hyperbolic depth mapping between near and far clip planes
    perspective[:, 2, 2] = (near + far) / (near - far)
    perspective[:, 2, 3] = 2. * near * far / (near - far)
    perspective[:, 3, 2] = -1.
    return perspective
|
| 496 |
+
|
| 497 |
+
|
| 498 |
+
@batched(2)
def extrinsics_to_view(
    extrinsics: torch.Tensor
) -> torch.Tensor:
    """
    OpenCV camera extrinsics to OpenGL view matrix

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL view matrix
    """
    # Negate the y and z rows: OpenCV looks down +z with y down, OpenGL down -z with y up
    flip = torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)
    return extrinsics * flip[:, None]
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
@batched(2)
def view_to_extrinsics(
    view: torch.Tensor
) -> torch.Tensor:
    """
    OpenGL view matrix to OpenCV camera extrinsics

    Args:
        view (torch.Tensor): [..., 4, 4] OpenGL view matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix
    """
    # Inverse of extrinsics_to_view: the same sign flip on the y and z rows
    flip = torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)
    return view * flip[:, None]
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
@batched(2,0,0)
def normalize_intrinsics(
    intrinsics: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Normalize camera intrinsics(s) to uv space

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)
    """
    zero = torch.zeros_like(width)
    one = torch.ones_like(width)
    # Pixel -> uv transform: u = (x + 0.5) / width, v = (y + 0.5) / height
    pixel_to_uv_mat = torch.stack([
        1 / width, zero, 0.5 / width,
        zero, 1 / height, 0.5 / height,
        zero, zero, one
    ]).reshape(*zero.shape, 3, 3).to(intrinsics)
    return pixel_to_uv_mat @ intrinsics
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
@batched(2,0,0,0,0,0,0)
def crop_intrinsics(
    intrinsics: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor],
    left: Union[int, torch.Tensor],
    top: Union[int, torch.Tensor],
    crop_width: Union[int, torch.Tensor],
    crop_height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Evaluate the new intrinsics(s) after cropping the image: cropped_img = img[top:top+crop_height, left:left+crop_width]

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)
        left (int | torch.Tensor): [...] left crop boundary
        top (int | torch.Tensor): [...] top crop boundary
        crop_width (int | torch.Tensor): [...] crop width
        crop_height (int | torch.Tensor): [...] crop height

    Returns:
        (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)
    """
    zero = torch.zeros_like(width)
    one = torch.ones_like(width)
    # Rescale and shift normalized image coordinates into the crop window
    crop_transform = torch.stack([
        width / crop_width, zero, -left / crop_width,
        zero, height / crop_height, -top / crop_height,
        zero, zero, one
    ]).reshape(*zero.shape, 3, 3).to(intrinsics)
    return crop_transform @ intrinsics
|
| 591 |
+
|
| 592 |
+
|
| 593 |
+
@batched(1,0,0)
def pixel_to_uv(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] coordinates defined in uv space, the range is (0, 1)
    """
    if not torch.is_floating_point(pixel):
        pixel = pixel.float()
    # Pixel centers sit at half-integer positions, hence the + 0.5
    size = torch.stack([width, height], dim=-1).to(pixel)
    return (pixel + 0.5) / size
|
| 612 |
+
|
| 613 |
+
|
| 614 |
+
@batched(1,0,0)
def uv_to_pixel(
    uv: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Args:
        uv (torch.Tensor): [..., 2] coordinates defined in uv space, the range is (0, 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
    """
    # Inverse of pixel_to_uv: scale up and undo the half-pixel-center offset
    size = torch.stack([width, height], dim=-1).to(uv)
    return uv * size - 0.5
|
| 631 |
+
|
| 632 |
+
|
| 633 |
+
@batched(1,0,0)
def pixel_to_ndc(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates defined in ndc space, the range is (-1, 1); y axis points up
    """
    if not torch.is_floating_point(pixel):
        pixel = pixel.float()
    size = torch.stack([width, height], dim=-1).to(pixel)
    # FIX: scale by [2, -2] / size (previously divided by size * [2, -2], which
    # mapped pixels into (-1, -0.5) x (0.5, 1) instead of the full (-1, 1) range).
    # x_ndc = 2 * (x + 0.5) / W - 1; y_ndc = 1 - 2 * (y + 0.5) / H (y flipped upward)
    scale = torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device)
    offset = torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device)
    ndc = (pixel + 0.5) / size * scale + offset
    return ndc
|
| 653 |
+
|
| 654 |
+
|
| 655 |
+
@batched(0,0,0)
def project_depth(
    depth: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Project linear depth to depth value in screen space

    Args:
        depth (torch.Tensor): [...] linear depth value
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [...] depth value in screen space, value ranging in [0, 1]
    """
    # Hyperbolic depth buffer: near -> 0, far -> 1
    return (far - near * far / depth) / (far - near)
|
| 673 |
+
|
| 674 |
+
|
| 675 |
+
@batched(0,0,0)
def depth_buffer_to_linear(
    depth: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Linearize depth buffer value to linear depth

    Args:
        depth (torch.Tensor): [...] screen depth value, ranging in [0, 1]
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [...] linear depth
    """
    # Inverse of project_depth: 0 -> near, 1 -> far
    return near * far / (far - (far - near) * depth)
|
| 693 |
+
|
| 694 |
+
|
| 695 |
+
@batched(2, 2, 2, 2)
def project_gl(
    points: torch.Tensor,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    perspective: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to 2D following the OpenGL convention (except for row major matrice)

    Args:
        points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        linear_depth (torch.Tensor): [..., N] linear depth
    """
    assert perspective is not None, "perspective matrix is required"

    if points.shape[-1] == 3:
        # Promote to homogeneous coordinates
        points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
    mvp = perspective if perspective is not None else torch.eye(4).to(points)
    if view is not None:
        mvp = mvp @ view
    if model is not None:
        mvp = mvp @ model
    clip = points @ mvp.transpose(-1, -2)
    # Perspective divide, then remap NDC [-1, 1] to screen [0, 1]
    ndc = clip[..., :3] / clip[..., 3:]
    scr_coord = ndc * 0.5 + 0.5
    return scr_coord, clip[..., 3]
|
| 731 |
+
|
| 732 |
+
|
| 733 |
+
@batched(2, 2, 2)
def project_cv(
    points: torch.Tensor,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to 2D following the OpenCV convention

    Args:
        points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        linear_depth (torch.Tensor): [..., N] linear depth
    """
    assert intrinsics is not None, "intrinsics matrix is required"
    if points.shape[-1] == 3:
        # Promote to homogeneous coordinates
        points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
    if extrinsics is not None:
        points = points @ extrinsics.transpose(-1, -2)
    cam = points[..., :3] @ intrinsics.transpose(-2, -1)
    # Perspective divide by the z (depth) component
    uv_coord = cam[..., :2] / cam[..., 2:]
    return uv_coord, cam[..., 2]
|
| 762 |
+
|
| 763 |
+
|
| 764 |
+
@batched(2, 2, 2, 2)
def unproject_gl(
    screen_coord: torch.Tensor,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    perspective: torch.Tensor = None
) -> torch.Tensor:
    """
    Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice)

    Args:
        screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert perspective is not None, "perspective matrix is required"
    # Screen [0, 1] -> NDC [-1, 1], append w = 1 to get clip coordinates
    ndc = screen_coord * 2 - 1
    clip = torch.cat([ndc, torch.ones_like(ndc[..., :1])], dim=-1)
    mvp = perspective
    if view is not None:
        mvp = mvp @ view
    if model is not None:
        mvp = mvp @ model
    inv_mvp = torch.inverse(mvp)
    homog = clip @ inv_mvp.transpose(-1, -2)
    # Homogeneous divide back to 3D
    return homog[..., :3] / homog[..., 3:]
|
| 796 |
+
|
| 797 |
+
|
| 798 |
+
@batched(2, 1, 2, 2)
def unproject_cv(
    uv_coord: torch.Tensor,
    depth: torch.Tensor = None,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> torch.Tensor:
    """
    Unproject uv coordinates to 3D view space following the OpenCV convention

    Args:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        depth (torch.Tensor): [..., N] depth value
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert intrinsics is not None, "intrinsics matrix is required"
    # Back-project pixels to unit-depth rays in camera space
    rays = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
    rays = rays @ torch.inverse(intrinsics).transpose(-2, -1)
    if depth is not None:
        rays = rays * depth[..., None]
    if extrinsics is not None:
        # Transform from camera space to world space
        rays = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
        rays = (rays @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
    return rays
|
| 827 |
+
|
| 828 |
+
|
| 829 |
+
def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X" or "Y or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    c = torch.cos(angle)
    s = torch.sin(angle)
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    if axis == "X":
        entries = (one, zero, zero, zero, c, -s, zero, s, c)
    elif axis == "Y":
        entries = (c, zero, s, zero, one, zero, -s, zero, c)
    elif axis == "Z":
        entries = (c, -s, zero, s, c, zero, zero, zero, one)
    else:
        raise ValueError("letter must be either X, Y or Z.")

    # Row-major assembly of the 3x3 matrix
    return torch.stack(entries, -1).reshape(angle.shape + (3, 3))
|
| 857 |
+
|
| 858 |
+
|
| 859 |
+
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ
        convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    rotations = [
        euler_axis_angle_rotation(axis, euler_angles[..., 'XYZ'.index(axis)])
        for axis in convention
    ]
    # Compose in reverse application order: e.g. 'XYZ' yields Rz @ Ry @ Rx
    return rotations[2] @ rotations[1] @ rotations[0]
|
| 885 |
+
|
| 886 |
+
|
| 887 |
+
def skew_symmetric(v: torch.Tensor):
    "Skew symmetric matrix from a 3D vector"
    assert v.shape[-1] == 3, "v must be 3D"
    x, y, z = v.unbind(dim=-1)
    o = torch.zeros_like(x)
    # [v]_x such that [v]_x @ w == cross(v, w)
    rows = (o, -z, y,
            z, o, -x,
            -y, x, o)
    return torch.stack(rows, dim=-1).reshape(*v.shape[:-1], 3, 3)
|
| 897 |
+
|
| 898 |
+
|
| 899 |
+
def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor):
    """Rotation matrix that rotates v1 to v2 (Rodrigues' formula).

    Args:
        v1 (torch.Tensor): (..., 3) source direction (need not be unit length)
        v2 (torch.Tensor): (..., 3) target direction (need not be unit length)

    Returns:
        torch.Tensor: (..., 3, 3) rotation matrix R such that R @ v1_hat == v2_hat

    NOTE: degenerate when v1 and v2 are anti-parallel (1 + cos == 0 divides by zero).
    """
    I = torch.eye(3).to(v1)
    v1 = F.normalize(v1, dim=-1)
    v2 = F.normalize(v2, dim=-1)
    v = torch.cross(v1, v2, dim=-1)
    c = torch.sum(v1 * v2, dim=-1)
    # Cross-product (skew-symmetric) matrix of v, built inline
    vx, vy, vz = v.unbind(dim=-1)
    zeros = torch.zeros_like(vx)
    K = torch.stack([
        zeros, -vz, vy,
        vz, zeros, -vx,
        -vy, vx, zeros,
    ], dim=-1).reshape(*v.shape[:-1], 3, 3)
    # FIX: broadcast the scalar factor over the trailing 3x3 dims.
    # The previous `[None, None]` prepended leading dims and broke batched inputs.
    R = I + K + (1 / (1 + c))[..., None, None] * (K @ K)
    return R
|
| 909 |
+
|
| 910 |
+
|
| 911 |
+
def _angle_from_tan(
|
| 912 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
| 913 |
+
) -> torch.Tensor:
|
| 914 |
+
"""
|
| 915 |
+
Extract the first or third Euler angle from the two members of
|
| 916 |
+
the matrix which are positive constant times its sine and cosine.
|
| 917 |
+
|
| 918 |
+
Args:
|
| 919 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
| 920 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
| 921 |
+
convention.
|
| 922 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
| 923 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
| 924 |
+
which means the relevant entries are in the same row of the
|
| 925 |
+
rotation matrix. If not, they are in the same column.
|
| 926 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
| 927 |
+
|
| 928 |
+
Returns:
|
| 929 |
+
Euler Angles in radians for each matrix in data as a tensor
|
| 930 |
+
of shape (...).
|
| 931 |
+
"""
|
| 932 |
+
|
| 933 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
| 934 |
+
if horizontal:
|
| 935 |
+
i2, i1 = i1, i2
|
| 936 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
| 937 |
+
if horizontal == even:
|
| 938 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
| 939 |
+
if tait_bryan:
|
| 940 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
| 941 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
| 942 |
+
|
| 943 |
+
|
| 944 |
+
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.
    NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d)

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d)
    """
    if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'):
        raise ValueError(f"Invalid convention {convention}.")
    if not matrix.shape[-2:] == (3, 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    i0 = 'XYZ'.index(convention[0])
    i2 = 'XYZ'.index(convention[2])
    # Tait-Bryan conventions use three distinct axes
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0))
    else:
        central_angle = torch.acos(matrix[..., i2, i2])

    # Angles in composition order: row for the first angle, column for the third
    composed = [
        _angle_from_tan(convention[0], convention[1], matrix[..., i2, :], True, tait_bryan),
        central_angle,
        _angle_from_tan(convention[2], convention[1], matrix[..., i0], False, tait_bryan),
    ]
    # Reorder the output to fixed XYZ order regardless of the convention
    return torch.stack([composed[convention.index(c)] for c in 'XYZ'], -1)
|
| 980 |
+
|
| 981 |
+
|
| 982 |
+
def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
        eps (float): small perturbation added before taking the norm to avoid
            division by zero for near-zero rotations

    Returns:
        torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters
    """
    batch_shape = axis_angle.shape[:-1]
    device, dtype = axis_angle.device, axis_angle.dtype

    angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True)
    axis = axis_angle / angle

    cos = torch.cos(angle)[..., None, :]
    sin = torch.sin(angle)[..., None, :]

    # FIX: split into three (..., 1) components. The previous split size of 3
    # produced a single chunk, making the 3-way unpacking raise at runtime.
    rx, ry, rz = torch.split(axis, 1, dim=-1)
    zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device)
    # Skew-symmetric cross-product matrix of the rotation axis
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3))

    ident = torch.eye(3, dtype=dtype, device=device)
    # Rodrigues' formula: R = I + sin(theta) K + (1 - cos(theta)) K^2
    rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K)
    return rot_mat
|
| 1007 |
+
|
| 1008 |
+
|
| 1009 |
+
def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector)

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
        eps (float): epsilon forwarded to the quaternion-to-axis-angle conversion

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices
    """
    # Go through the quaternion representation, which is numerically stable to extract
    return quaternion_to_axis_angle(matrix_to_quaternion(rot_mat), eps=eps)
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector)

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
        eps (float): minimum vector-part norm used to avoid division by zero

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    vec_norm = torch.norm(quaternion[..., 1:], dim=-1, keepdim=True)
    axis = quaternion[..., 1:] / vec_norm.clamp(min=eps)
    # angle = 2 * atan2(|v|, w); robust for both small and large rotations
    angle = 2 * torch.atan2(vec_norm, quaternion[..., 0:1])
    return angle * axis
|
| 1037 |
+
|
| 1038 |
+
|
| 1039 |
+
def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z)

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vectors
        eps (float): epsilon used when normalizing the rotation axis

    Returns:
        torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters
    """
    axis = F.normalize(axis_angle, dim=-1, eps=eps)
    angle = torch.norm(axis_angle, dim=-1, keepdim=True)
    half = angle / 2
    # q = (cos(theta/2), sin(theta/2) * axis)
    return torch.cat([torch.cos(half), torch.sin(half) * axis], dim=-1)
|
| 1052 |
+
|
| 1053 |
+
|
| 1054 |
+
def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert 3x3 rotation matrix to quaternion (w, x, y, z)

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert
        eps (float): epsilon used when normalizing the resulting quaternion

    Returns:
        torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices
    """
    # Off-diagonal entries; the diagonal is taken separately via `diag` below.
    # (The original also unpacked m00/m11/m22 but never used them.)
    _, m01, m02, m10, _, m12, m20, m21, _ = rot_mat.flatten(-2).unbind(dim=-1)

    diag = torch.diagonal(rot_mat, dim1=-2, dim2=-1)
    # Candidate magnitudes |w|, |x|, |y|, |z| from the trace identities
    M = torch.tensor([
        [1, 1, 1],
        [1, -1, -1],
        [-1, 1, -1],
        [-1, -1, 1]
    ], dtype=rot_mat.dtype, device=rot_mat.device)
    wxyz = (1 + diag @ M.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5)
    # Resolve signs relative to the numerically largest component
    _, max_idx = wxyz.max(dim=-1)
    xw = torch.sign(m21 - m12)
    yw = torch.sign(m02 - m20)
    zw = torch.sign(m10 - m01)
    yz = torch.sign(m21 + m12)
    xz = torch.sign(m02 + m20)
    xy = torch.sign(m01 + m10)
    ones = torch.ones_like(xw)
    sign = torch.where(
        max_idx[..., None] == 0,
        torch.stack([ones, xw, yw, zw], dim=-1),
        torch.where(
            max_idx[..., None] == 1,
            torch.stack([xw, ones, xy, xz], dim=-1),
            torch.where(
                max_idx[..., None] == 2,
                torch.stack([yw, xy, ones, yz], dim=-1),
                torch.stack([zw, xz, yz, ones], dim=-1)
            )
        )
    )
    quat = sign * wxyz
    quat = F.normalize(quat, dim=-1, eps=eps)
    return quat
|
| 1098 |
+
|
| 1099 |
+
|
| 1100 |
+
def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Converts a batch of quaternions (w, x, y, z) to rotation matrices

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert
        eps (float): epsilon used when normalizing the input quaternion

    Returns:
        torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    quaternion = F.normalize(quaternion, dim=-1, eps=eps)
    w, x, y, z = quaternion.unbind(dim=-1)
    zeros = torch.zeros_like(w)
    I = torch.eye(3, dtype=quaternion.dtype, device=quaternion.device)
    xyz = quaternion[..., 1:]
    # R = I + 2 * (v v^T - |v|^2 I + w [v]_x), with v = (x, y, z):
    # A is the symmetric part, B the skew-symmetric cross-product matrix [v]_x.
    A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(dim=-1)[..., None, None]
    B = torch.stack([
        zeros, -z, y,
        z, zeros, -x,
        -y, x, zeros
    ], dim=-1).unflatten(-1, (3, 3))
    rot_mat = I + 2 * (A + w[..., None, None] * B)
    return rot_mat
|
| 1123 |
+
|
| 1124 |
+
|
| 1125 |
+
def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Interpolate between two rotation matrices.

    NOTE(review): despite the name, this interpolates linearly in axis-angle
    space (not on the quaternion sphere); for small relative rotations the two
    agree closely — confirm this is the intended behavior.

    Args:
        rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix
        rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix
    """
    assert rot_mat_1.shape[-2:] == (3, 3)
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device)
    vec_start = matrix_to_axis_angle(rot_mat_1)
    vec_end = matrix_to_axis_angle(rot_mat_2)
    weight = t[..., None]
    blended_vec = (1 - weight) * vec_start + weight * vec_end
    return axis_angle_to_matrix(blended_vec)
|
| 1144 |
+
|
| 1145 |
+
|
| 1146 |
+
def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.

    Extrinsics are inverted to camera-to-world poses, interpolated there, and
    inverted back — so the blend happens in pose space, not extrinsics space.

    Args:
        ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose
        ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated camera pose
    """
    pose1 = torch.inverse(ext1)
    pose2 = torch.inverse(ext2)
    blended_pose = interpolate_transform(pose1, pose2, t)
    return torch.inverse(blended_pose)
|
| 1158 |
+
|
| 1159 |
+
|
| 1160 |
+
def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate view matrices between two camera poses.

    Thin alias of `interpolate_extrinsics`: linear interpolation for
    translation, spherical linear interpolation for rotation.

    Args:
        view1 (torch.Tensor): shape (..., 4, 4), the first view matrix
        view2 (torch.Tensor): shape (..., 4, 4), the second view matrix
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated view matrix
    """
    return interpolate_extrinsics(view1, view2, t)
|
| 1172 |
+
|
| 1173 |
+
|
| 1174 |
+
def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate two 4x4 rigid transforms.

    Translation is interpolated linearly; rotation via `slerp`.

    Args:
        transform1 (torch.Tensor): shape (..., 4, 4), the first transform
        transform2 (torch.Tensor): shape (..., 4, 4), the second transform
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated transform
    """
    assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4)
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device)
    pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3]
    rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t)
    transform = torch.cat([rot, pos[..., None]], dim=-1)
    # Append the homogeneous bottom row [0, 0, 0, 1].
    # Fix: the original referenced an undefined name `ext` here, raising
    # NameError on every call; it must be `transform`.
    transform = torch.cat([transform, torch.tensor([0, 0, 0, 1], dtype=transform.dtype, device=transform.device).expand_as(transform[..., :1, :])], dim=-2)
    return transform
|
| 1183 |
+
|
| 1184 |
+
|
| 1185 |
+
def extrinsics_to_essential(extrinsics: torch.Tensor):
    """
    extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0`

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix

    Returns:
        (torch.Tensor): [..., 3, 3] essential matrix
    """
    assert extrinsics.shape[-2:] == (4, 4)
    R = extrinsics[..., :3, :3]
    t = extrinsics[..., :3, 3]
    # Build the skew-symmetric cross-product matrix [t]_x.
    # Fix: the original stacked a (..., 3)-shaped `zeros` together with scalar
    # components and omitted `dim=-1`, so torch.stack raised a shape-mismatch
    # error for every input. Zeros must match the per-component shape, and the
    # nine entries must be stacked along the last dim before reshaping.
    zeros = torch.zeros_like(t[..., 0])
    t_x = torch.stack([
        zeros, -t[..., 2], t[..., 1],
        t[..., 2], zeros, -t[..., 0],
        -t[..., 1], t[..., 0], zeros
    ], dim=-1).reshape(*t.shape[:-1], 3, 3)
    return R @ t_x
|
| 1205 |
+
|
| 1206 |
+
|
| 1207 |
+
def to4x4(R: torch.Tensor, t: torch.Tensor):
    """
    Compose rotation matrix and translation vector to 4x4 transformation matrix

    Args:
        R (torch.Tensor): [..., 3, 3] rotation matrix
        t (torch.Tensor): [..., 3] translation vector

    Returns:
        (torch.Tensor): [..., 4, 4] transformation matrix
    """
    assert R.shape[-2:] == (3, 3)
    assert t.shape[-1] == 3
    assert R.shape[:-2] == t.shape[:-1]
    # Upper 3x4 block: [R | t]
    upper = torch.cat([R, t[..., None]], dim=-1)
    # Homogeneous bottom row [0, 0, 0, 1], broadcast over batch dims
    bottom = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 1, 4)
    return torch.cat([upper, bottom], dim=-2)
|
| 1225 |
+
|
| 1226 |
+
|
| 1227 |
+
def rotation_matrix_2d(theta: Union[float, torch.Tensor]):
    """
    2x2 matrix for 2D rotation

    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)

    Returns:
        (torch.Tensor): (..., 2, 2) rotation matrix
    """
    if isinstance(theta, float):
        theta = torch.tensor(theta)
    cos_t = torch.cos(theta)
    sin_t = torch.sin(theta)
    # Row-major entries [[cos, -sin], [sin, cos]], folded into a 2x2 block
    flat = torch.stack([cos_t, -sin_t, sin_t, cos_t], dim=-1)
    return flat.unflatten(-1, (2, 2))
|
| 1243 |
+
|
| 1244 |
+
|
| 1245 |
+
def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    3x3 homogeneous matrix for a 2D rotation about `center`
    ```
    [[Rxx, Rxy, tx],
     [Ryx, Ryy, ty],
     [0, 0, 1]]
    ```
    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
        center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    if isinstance(theta, float):
        theta = torch.tensor(theta)
    if center is not None:
        theta = theta.to(center)
    if center is None:
        center = torch.zeros(2).to(theta).expand(*theta.shape, -1)
    R = rotation_matrix_2d(theta)
    # Rotating about `center` is equivalent to translating by c - R @ c
    translation = center[..., :, None] - R @ center[..., :, None]
    upper = torch.cat([R, translation], dim=-1)
    bottom = torch.tensor([[0, 0, 1]], dtype=center.dtype, device=center.device).expand(*center.shape[:-1], -1, -1)
    return torch.cat([upper, bottom], dim=-2)
|
| 1274 |
+
|
| 1275 |
+
|
| 1276 |
+
def translate_2d(translation: torch.Tensor):
    """
    Translation matrix for 2D translation
    ```
    [[1, 0, tx],
     [0, 1, ty],
     [0, 0, 1]]
    ```
    Args:
        translation (torch.Tensor): translation vector, arbitrary shape (..., 2)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    eye2 = torch.eye(2, dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1)
    # Upper 2x3 block: [I | t]
    upper = torch.cat([eye2, translation[..., None]], dim=-1)
    bottom = torch.tensor([[0, 0, 1]], dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1)
    return torch.cat([upper, bottom], dim=-2)
|
| 1297 |
+
|
| 1298 |
+
|
| 1299 |
+
def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    Scale matrix for 2D scaling about `center`
    ```
    [[s, 0, tx],
     [0, s, ty],
     [0, 0, 1]]
    ```
    Args:
        scale (float | torch.Tensor): scale factor, arbitrary shape (...,)
        center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    if isinstance(scale, float):
        scale = torch.tensor(scale)
    if center is not None:
        scale = scale.to(center)
    if center is None:
        center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1)
    # Fix: the original computed `scale * eye(2).expand(*scale.shape[:-1], -1, -1)`,
    # which drops the batch dim of `scale` and fails to broadcast for batched input.
    # Unsqueezing scale to (..., 1, 1) broadcasts correctly for both scalar and
    # batched scales.
    scale_block = scale[..., None, None] * torch.eye(2, dtype=scale.dtype, device=scale.device)
    return torch.cat([
        torch.cat([
            scale_block,
            # Scaling about `center` adds a translation of c - s * c
            center[..., :, None] - center[..., :, None] * scale[..., None, None],
        ], dim=-1),
        torch.tensor([[0, 0, 1]], dtype=scale.dtype, device=scale.device).expand(*center.shape[:-1], -1, -1),
    ], dim=-2)
|
| 1327 |
+
|
| 1328 |
+
|
| 1329 |
+
def apply_2d(transform: torch.Tensor, points: torch.Tensor):
    """
    Apply (3x3 or 2x3) 2D affine transformation to points
    ```
    p = R @ p + t
    ```
    Args:
        transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix
        points (torch.Tensor): (..., N, 2) points to transform

    Returns:
        (torch.Tensor): (..., N, 2) transformed points
    """
    assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3"
    assert points.shape[-1] == 2, "points must be 2D"
    # Fix: the translation must broadcast over the point axis as (..., 1, 2).
    # The original `transform[..., :2, None, 2]` produced (..., 2, 1), which
    # raises (or silently mis-adds when N == 2) against the (..., N, 2) product.
    return points @ transform[..., :2, :2].mT + transform[..., None, :2, 2]
|
moge/model/utils.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
def wrap_module_with_gradient_checkpointing(module: nn.Module):
    # Swap the module's class for a dynamically created subclass whose forward()
    # runs under torch.utils.checkpoint (non-reentrant), trading recompute for
    # activation memory. Mutates `module` in place and also returns it.
    # Reversible via `unwrap_module_with_gradient_checkpointing`.
    from torch.utils.checkpoint import checkpoint
    class _CheckpointingWrapper(module.__class__):
        _restore_cls = module.__class__  # remembered so the wrap can be undone
        def forward(self, *args, **kwargs):
            return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)

    module.__class__ = _CheckpointingWrapper
    return module
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def unwrap_module_with_gradient_checkpointing(module: nn.Module):
    # Undo `wrap_module_with_gradient_checkpointing` by restoring the original
    # class saved in `_restore_cls`. In-place; returns None.
    module.__class__ = module.__class__._restore_cls
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def wrap_dinov2_attention_with_sdpa(module: nn.Module):
    # Hot-swap a DINOv2 attention module's class so its forward() uses PyTorch's
    # fused F.scaled_dot_product_attention. Mutates `module` in place and
    # returns it. Relies on the wrapped module providing `qkv`, `num_heads`,
    # `proj`, and `proj_drop` attributes.
    # NOTE(review): comparing torch.__version__ as a string works for '2.x'
    # vs '1.x' but is not a robust version check in general.
    assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
    class _AttentionWrapper(module.__class__):
        def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            B, N, C = x.shape
            # Single projection to q, k, v, then split heads
            qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)  # (3, B, H, N, C // H)

            q, k, v = torch.unbind(qkv, 0)  # (B, H, N, C // H)

            x = F.scaled_dot_product_attention(q, k, v, attn_bias)
            x = x.permute(0, 2, 1, 3).reshape(B, N, C)

            x = self.proj(x)
            x = self.proj_drop(x)
            return x
    module.__class__ = _AttentionWrapper
    return module
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
    # DDP communication hook: averages each gradient bucket across the WORLD
    # process group (divide locally by world size, then all-reduce sum), and
    # returns an already-completed future holding the synchronized gradient.
    # `state` is unused (part of the hook signature required by DDP).
    group_to_use = torch.distributed.group.WORLD
    world_size = group_to_use.size()
    grad = bucket.buffer()
    grad.div_(world_size)
    torch.distributed.all_reduce(grad, group=group_to_use)
    fut = torch.futures.Future()
    fut.set_result(grad)
    return fut
|
moge/model/v2.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# This file is modified from MoGe:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
# Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
from typing import *
|
| 9 |
+
from numbers import Number
|
| 10 |
+
from functools import partial
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
import torch.nn.functional as F
|
| 17 |
+
import torch.utils
|
| 18 |
+
import torch.utils.checkpoint
|
| 19 |
+
import torch.amp
|
| 20 |
+
import torch.version
|
| 21 |
+
import utils3d
|
| 22 |
+
from huggingface_hub import hf_hub_download
|
| 23 |
+
|
| 24 |
+
from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3
|
| 25 |
+
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
|
| 26 |
+
from .modules import DINOv2Encoder, MLP, ConvStack
|
| 27 |
+
from . import transforms
|
| 28 |
+
|
| 29 |
+
from einops import rearrange
|
| 30 |
+
|
| 31 |
+
def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor:
    """
    Build a pixel-center UV grid in image space, with coordinates ranging in [0, 1].

    >>> image_uv(10, 10):
    [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
     [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
      ...             ...                ...
     [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]

    Args:
        height (int): image height
        width (int): image width
        left, top, right, bottom (int, optional): crop window in pixels; defaults to the full image

    Returns:
        torch.Tensor: shape (height, width, 2) — or (bottom - top, right - left, 2) when a crop window is given
    """
    left = 0 if left is None else left
    top = 0 if top is None else top
    right = width if right is None else right
    bottom = height if bottom is None else bottom
    # Sample at pixel centers: pixel i maps to (i + 0.5) / size
    u_coords = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype)
    v_coords = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype)
    grid_u, grid_v = torch.meshgrid(u_coords, v_coords, indexing='xy')
    return torch.stack([grid_u, grid_v], dim=-1)
|
| 57 |
+
|
| 58 |
+
def depth_to_points(depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None):
    # Unproject a depth map to a 3D point map.
    #   depth: (..., H, W); intrinsics: (..., 3, 3); extrinsics: optional (..., 4, 4).
    # Presumably returns points of shape (..., H, W, 3) — camera-space, or
    # world-space when extrinsics is given; confirm against transforms.unproject_cv.
    height, width = depth.shape[-2:]
    uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device)
    # Unsqueeze the per-image matrices so they broadcast over the pixel rows
    pts = transforms.unproject_cv(uv, depth, intrinsics=intrinsics[..., None, :, :], extrinsics=extrinsics[..., None, :, :] if extrinsics is not None else None)
    return pts
|
| 63 |
+
|
| 64 |
+
class MoGeModel(nn.Module):
    """MoGe v2 monocular geometry model.

    A DINOv2 encoder feeds a shared convolutional neck, followed by optional
    heads predicting an affine point map (`points_head`), a validity mask
    (`mask_head`), normals (`normal_head`), and a global metric scale
    (`scale_head`). Heads exist only when configured in `__init__`.
    """
    encoder: DINOv2Encoder
    neck: ConvStack
    points_head: ConvStack
    mask_head: ConvStack
    scale_head: MLP

    def __init__(self,
        encoder: Dict[str, Any],
        neck: Dict[str, Any],
        points_head: Dict[str, Any] = None,
        mask_head: Dict[str, Any] = None,
        normal_head: Dict[str, Any] = None,
        scale_head: Dict[str, Any] = None,
        remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
        num_tokens_range: List[int] = [1200, 3600],  # NOTE(review): mutable default — shared across instances
        **deprecated_kwargs
    ):
        # Each sub-config dict is expanded as keyword arguments of the
        # corresponding submodule constructor. Unknown kwargs are ignored
        # with a warning for checkpoint-config backward compatibility.
        super(MoGeModel, self).__init__()
        if deprecated_kwargs:
            warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")

        self.remap_output = remap_output
        self.num_tokens_range = num_tokens_range

        self.encoder = DINOv2Encoder(**encoder)
        self.neck = ConvStack(**neck)
        if points_head is not None:
            self.points_head = ConvStack(**points_head)
        if mask_head is not None:
            self.mask_head = ConvStack(**mask_head)
        if normal_head is not None:
            self.normal_head = ConvStack(**normal_head)
        if scale_head is not None:
            self.scale_head = MLP(**scale_head)

    @property
    def device(self) -> torch.device:
        # Device of the first parameter; assumes all parameters are co-located
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        return next(self.parameters()).dtype

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
        """
        Load a model from a checkpoint file.

        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.

        ### Returns:
        - A new instance of `MoGe` with the parameters loaded from the checkpoint.
        """
        if Path(pretrained_model_name_or_path).exists():
            checkpoint_path = pretrained_model_name_or_path
        else:
            # Not a local path — treat it as a Hugging Face Hub repo id
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs
            )
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

        model_config = checkpoint['model_config']
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        # strict=False: tolerate heads present in the config but absent from the
        # checkpoint (and vice versa)
        model.load_state_dict(checkpoint['model'], strict=False)

        return model

    def init_weights(self):
        # Only the encoder carries pretrained weights to (re)initialize here
        self.encoder.init_weights()

    def enable_gradient_checkpointing(self):
        self.encoder.enable_gradient_checkpointing()
        self.neck.enable_gradient_checkpointing()
        for head in ['points_head', 'normal_head', 'mask_head']:
            if hasattr(self, head):
                getattr(self, head).enable_gradient_checkpointing()

    def enable_pytorch_native_sdpa(self):
        self.encoder.enable_pytorch_native_sdpa()

    def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
        # Map the raw points-head output to final point values. The non-linear
        # modes compress very large magnitudes (see note in forward()).
        if self.remap_output == 'linear':
            pass
        elif self.remap_output =='sinh':
            points = torch.sinh(points)
        elif self.remap_output == 'exp':
            # z is made strictly positive; xy is scaled by z
            xy, z = points.split([2, 1], dim=-1)
            z = torch.exp(z)
            points = torch.cat([xy * z, z], dim=-1)
        elif self.remap_output =='sinh_exp':
            xy, z = points.split([2, 1], dim=-1)
            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
        else:
            raise ValueError(f"Invalid remap output type: {self.remap_output}")
        return points

    @torch.inference_mode()
    def infer_feature_tokens(self, image: torch.Tensor, num_tokens: int, tokens_layer: int = -1) -> torch.Tensor:
        # Encoder + neck only (no heads): returns the feature map of the
        # requested neck layer. `num_tokens` is rounded to a (base_h, base_w)
        # grid matching the image aspect ratio.
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        aspect_ratio = img_w / img_h
        base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
        num_tokens = base_h * base_w

        # Backbones encoding
        features = self.encoder(image, base_h, base_w, return_class_token=False)
        features = [features, None, None, None, None]

        # Concat UVs for aspect ratio input
        for level in range(5):
            uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
            if features[level] is None:
                features[level] = uv
            else:
                features[level] = torch.concat([features[level], uv], dim=1)

        # Shared neck
        features = self.neck(features)[tokens_layer]
        return features

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        # Full raw forward pass. Returns only the outputs whose heads are
        # configured, as a dict with keys among
        # {'points', 'normal', 'mask', 'metric_scale'}.
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        aspect_ratio = img_w / img_h
        base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
        num_tokens = base_h * base_w

        # Backbones encoding
        features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
        features = [features, None, None, None, None]

        # Concat UVs for aspect ratio input
        for level in range(5):
            uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
            if features[level] is None:
                features[level] = uv
            else:
                features[level] = torch.concat([features[level], uv], dim=1)

        # Shared neck
        features = self.neck(features)

        # Heads decoding
        points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
        metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None

        # Resize
        points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])

        # Remap output
        if points is not None:
            points = points.permute(0, 2, 3, 1)
            points = self._remap_points(points)     # slightly improves the performance in case of very large output values
        if normal is not None:
            normal = normal.permute(0, 2, 3, 1)
            normal = F.normalize(normal, dim=-1)
        if mask is not None:
            mask = mask.squeeze(1).sigmoid()
        if metric_scale is not None:
            metric_scale = metric_scale.squeeze(1).exp()

        return_dict = {
            'points': points,
            'normal': normal,
            'mask': mask,
            'metric_scale': metric_scale
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        return return_dict

    @torch.inference_mode()
    def infer(
        self,
        image: torch.Tensor,
        num_tokens: int = None,
        resolution_level: int = 9,
        force_projection: bool = True,
        apply_mask: Literal[False, True, 'blend'] = True,
        fov_x: Optional[Union[Number, torch.Tensor]] = None,
        use_fp16: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        User-friendly inference function

        ### Parameters
        - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
        - `num_tokens`: the number of base ViT tokens to use for inference. Suggested range: 1200 ~ 2500.
            More tokens will result in significantly higher accuracy and finer details, but slower inference time.
            If None, it is interpolated from `num_tokens_range` by `resolution_level` (0-9).
        - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
        - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
        - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
        - `use_fp16`: if True, use mixed precision to speed up inference. Default: True

        ### Returns

        A dictionary containing the following keys:
        - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
        - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
        - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
        """
        if image.dim() == 3:
            # Single image: add a batch dim here, strip it again before returning
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        area = original_height * original_width
        aspect_ratio = original_width / original_height

        # Determine the number of base tokens to use
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        # Forward pass
        with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
            output = self.forward(image, num_tokens=num_tokens)
        points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])

        # Always process the output in fp32 precision
        points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            if mask is not None:
                mask_binary = mask > 0.5
            else:
                mask_binary = None

            if points is not None:
                # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
                # NOTE: Focal here is the focal length relative to half the image diagonal
                if fov_x is None:
                    # Recover focal and shift from predicted point map
                    focal, shift = recover_focal_shift(points, mask_binary)
                else:
                    # Focal is known, recover shift only
                    focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
                    if focal.ndim == 0:
                        focal = focal[None].expand(points.shape[0])
                    _, shift = recover_focal_shift(points, mask_binary, focal=focal)
                fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
                intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
                points[..., 2] += shift[..., None, None]
                if mask_binary is not None:
                    mask_binary &= points[..., 2] > 0        # in case depth is contains negative values (which should never happen in practice)
                depth = points[..., 2].clone()
            else:
                depth, intrinsics = None, None

            # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
            if force_projection and depth is not None:
                points = depth_to_points(depth, intrinsics=intrinsics)

            # Apply metric scale
            if metric_scale is not None:
                if points is not None:
                    points *= metric_scale[:, None, None, None]
                if depth is not None:
                    depth *= metric_scale[:, None, None]

            # Apply mask
            if apply_mask and mask_binary is not None:
                points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
                depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
                normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None

        return_dict = {
            'points': points,
            'intrinsics': intrinsics,
            'depth': depth,
            'mask': mask_binary,
            'normal': normal,
            'metric_scale': metric_scale
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        if omit_batch_dim:
            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}

        return return_dict
|
moge/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
moge/utils/download.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
from typing import *
|
| 8 |
+
import requests
|
| 9 |
+
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
__all__ = ["download_file", "download_bytes"]
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
|
| 17 |
+
# Ensure headers is a dict if not provided
|
| 18 |
+
headers = headers or {}
|
| 19 |
+
|
| 20 |
+
# Initialize local variables
|
| 21 |
+
file_path = Path(filepath)
|
| 22 |
+
downloaded_bytes = 0
|
| 23 |
+
|
| 24 |
+
# Check if we should resume the download
|
| 25 |
+
if resume and file_path.exists():
|
| 26 |
+
downloaded_bytes = file_path.stat().st_size
|
| 27 |
+
headers['Range'] = f"bytes={downloaded_bytes}-"
|
| 28 |
+
|
| 29 |
+
# Make a GET request to fetch the file
|
| 30 |
+
with requests.get(url, stream=True, headers=headers) as response:
|
| 31 |
+
response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
|
| 32 |
+
|
| 33 |
+
# Calculate the total size to download
|
| 34 |
+
total_size = downloaded_bytes + int(response.headers.get('content-length', 0))
|
| 35 |
+
|
| 36 |
+
# Display a progress bar while downloading
|
| 37 |
+
with (
|
| 38 |
+
tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
|
| 39 |
+
open(file_path, 'ab') as file,
|
| 40 |
+
):
|
| 41 |
+
# Set the initial position of the progress bar
|
| 42 |
+
pbar.update(downloaded_bytes)
|
| 43 |
+
|
| 44 |
+
# Write the content to the file in chunks
|
| 45 |
+
for chunk in response.iter_content(chunk_size=4096):
|
| 46 |
+
file.write(chunk)
|
| 47 |
+
pbar.update(len(chunk))
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def download_bytes(url: str, headers: dict = None) -> bytes:
|
| 51 |
+
# Ensure headers is a dict if not provided
|
| 52 |
+
headers = headers or {}
|
| 53 |
+
|
| 54 |
+
# Make a GET request to fetch the file
|
| 55 |
+
with requests.get(url, stream=True, headers=headers) as response:
|
| 56 |
+
response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
|
| 57 |
+
|
| 58 |
+
# Read the content of the response
|
| 59 |
+
return response.content
|
| 60 |
+
|
moge/utils/geometry_numpy.py
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
from functools import partial
|
| 8 |
+
import math
|
| 9 |
+
|
| 10 |
+
import cv2
|
| 11 |
+
import numpy as np
|
| 12 |
+
from scipy.signal import fftconvolve
|
| 13 |
+
import numpy as np
|
| 14 |
+
import utils3d
|
| 15 |
+
|
| 16 |
+
from .tools import timeit
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
    """Mean of `x` over `axis`, optionally weighted by `w`.

    Args:
        x: Input array.
        w: Optional weights, broadcastable against `x`. Cast to `x.dtype` before use.
        axis: Axis or axes to reduce over (None reduces all).
        keepdims: If True, reduced axes are kept with size 1.
        eps: Lower clip on the mean weight to avoid division by zero.

    Returns:
        The (weighted) mean. With weights this is sum(x*w)/sum(w), computed as a
        ratio of means so the normalization cancels.

    Fix: `keepdims` was accepted but ignored in both branches; it is now forwarded
    to the reductions (matching the torch counterpart `weighted_mean`). Default
    behavior (keepdims=False) is unchanged.
    """
    if w is None:
        return np.mean(x, axis=axis, keepdims=keepdims)
    w = w.astype(x.dtype)
    return (x * w).mean(axis=axis, keepdims=keepdims) / np.clip(w.mean(axis=axis, keepdims=keepdims), eps, None)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int, ...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
    """Harmonic mean of `x` over `axis`, optionally weighted by `w`.

    Args:
        x: Input array; values are clipped/offset by `eps` to stabilize reciprocals.
        w: Optional weights, broadcastable against `x`.
        axis: Axis or axes to reduce over (None reduces all).
        keepdims: If True, reduced axes are kept with size 1.
        eps: Stabilizer for the reciprocals and the final division.

    Returns:
        1 / mean(1 / x), weighted via `weighted_mean_numpy` when `w` is given.

    Fix: `keepdims` was accepted but ignored in the unweighted branch; it is now
    forwarded to the reduction. Default behavior (keepdims=False) is unchanged.
    """
    if w is None:
        return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis, keepdims=keepdims)
    w = w.astype(x.dtype)
    return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    # Half-extents of the view plane, normalized by the unit diagonal.
    inv_diag = (1 + aspect_ratio ** 2) ** 0.5
    span_x = aspect_ratio / inv_diag
    span_y = 1 / inv_diag

    # Pixel-center samples along each axis, then broadcast to a full (H, W) grid.
    u_axis = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
    v_axis = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
    grid_u = np.broadcast_to(u_axis[None, :], (height, width))
    grid_v = np.broadcast_to(v_axis[:, None], (height, width))
    return np.stack([grid_u, grid_v], axis=-1)
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def focal_to_fov_numpy(focal: np.ndarray):
    "Convert a diagonal-normalized focal length to the corresponding field of view (radians)."
    half_fov = np.arctan(0.5 / focal)
    return half_fov * 2
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def fov_to_focal_numpy(fov: np.ndarray):
    "Convert a field of view (radians) to the corresponding diagonal-normalized focal length."
    half_fov = fov / 2
    return 0.5 / np.tan(half_fov)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    "Extract horizontal and vertical FoV (radians) from normalized intrinsics matrices (..., 3, 3)."
    fx = intrinsics[..., 0, 0]
    fy = intrinsics[..., 1, 1]
    return focal_to_fov_numpy(fx), focal_to_fov_numpy(fy)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def point_map_to_depth_legacy_numpy(points: np.ndarray):
    """Recover depth, FoV and a z-shift from an affine-invariant point map (legacy closed-form solver).

    Fits `focal` and `shift` so that projecting `points` (z shifted by `shift`,
    scaled by `focal`) best matches the normalized view-plane UV grid, via a 2x2
    linear least-squares system.

    Args:
        points: Point map of shape (..., H, W, 3).

    Returns:
        Tuple of (depth, fov_x, fov_y, shift) where depth = points[..., 2] + shift.
    """
    height, width = points.shape[-3:-1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype)  # (H, W, 2)
    _, uv = np.broadcast_arrays(points[..., :2], uv)

    # Solve least squares problem: rows pair each pixel's xy (times z) against its uv target.
    b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1) # (..., H * W * 2)
    A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2) # (..., H * W * 2, 2)

    # Normal equations with Tikhonov damping (1e-6 * I) for numerical stability.
    M = A.swapaxes(-2, -1) @ A
    solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
    # NOTE(review): this unpacks along axis 0 of `solution` (shape (..., 2)); it is only
    # correct for unbatched input where solution has shape (2,). For batched input this
    # looks wrong — confirm the intended calling convention before reuse.
    focal, shift = solution

    depth = points[..., 2] + shift[..., None, None]
    # FoV from the fitted diagonal-normalized focal.
    fov_x = np.arctan(width / diagonal / focal) * 2
    fov_y = np.arctan(height / diagonal / focal) * 2
    return depth, fov_x, fov_y, shift
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
    from scipy.optimize import least_squares
    uv_flat = uv.reshape(-1, 2)
    xy_flat = xyz[..., :2].reshape(-1, 2)
    z_flat = xyz[..., 2].reshape(-1)

    def residuals(shift: np.ndarray):
        # For a candidate shift, the optimal focal has a closed form (1-D least squares),
        # so the outer optimization only searches over shift.
        projected = xy_flat / (z_flat + shift)[:, None]
        best_focal = (projected * uv_flat).sum() / np.square(projected).sum()
        return (best_focal * projected - uv_flat).ravel()

    result = least_squares(residuals, x0=0, ftol=1e-3, method='lm')
    optim_shift = result['x'].squeeze().astype(np.float32)

    # Recompute the closed-form focal at the optimal shift.
    projected = xy_flat / (z_flat + optim_shift)[:, None]
    optim_focal = (projected * uv_flat).sum() / np.square(projected).sum()

    return optim_shift, optim_focal
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
    from scipy.optimize import least_squares
    uv_flat = uv.reshape(-1, 2)
    xy_flat = xyz[..., :2].reshape(-1, 2)
    z_flat = xyz[..., 2].reshape(-1)

    def residuals(shift: np.ndarray):
        # Reprojection error of the shifted points under the fixed focal.
        projected = xy_flat / (z_flat + shift)[:, None]
        return (focal * projected - uv_flat).ravel()

    result = least_squares(residuals, x0=0, ftol=1e-3, method='lm')
    return result['x'].squeeze().astype(np.float32)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
    """Recover the camera focal (optional) and z-shift that best explain a point map.

    Downsamples the point map (mask-aware if a mask is given) and runs the
    nonlinear solvers on the low-resolution samples for speed.

    Args:
        points: Point map of shape (H, W, 3).
        mask: Optional validity mask of shape (H, W).
        focal: If given, only the shift is solved; otherwise both focal and shift.
        downsample_size: (width, height) of the low-resolution solve.

    Returns:
        (focal, shift); falls back to (1., 0.) when too few samples are available.
    """
    # NOTE(review): cv2 is already imported at module level; this local import is redundant.
    import cv2
    assert points.shape[-1] == 3, "Points should (H, W, 3)"

    height, width = points.shape[-3], points.shape[-2]
    # NOTE(review): `diagonal` is computed but never used in this function.
    diagonal = (height ** 2 + width ** 2) ** 0.5

    uv = normalized_view_plane_uv_numpy(width=width, height=height)

    if mask is None:
        points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
        uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
    else:
        # NOTE(review): `mask_lr` is returned but not used to filter `points_lr`/`uv_lr`;
        # presumably the nearest-resize only samples valid pixels — confirm against
        # mask_aware_nearest_resize_numpy's contract.
        (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)

    # Degenerate case: not enough samples to fit anything.
    if points_lr.size < 2:
        return 1., 0.

    if focal is None:
        # NOTE(review): solve_optimal_focal_shift returns (shift, focal) per its own
        # return statement, but is unpacked here as (focal, shift) — verify which
        # order is intended; one of the two sides looks swapped.
        focal, shift = solve_optimal_focal_shift(uv_lr, points_lr)
    else:
        shift = solve_optimal_shift(uv_lr, points_lr, focal)

    return focal, shift
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def mask_aware_nearest_resize_numpy(
    inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
    mask: np.ndarray,
    size: Tuple[int, int],
    return_index: bool = False
) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
    """
    Resize 2D map by nearest interpolation, restricted to pixels where `mask` is True.
    Return the resized map(s), the resized mask, and optionally the nearest-neighbor index.

    ### Parameters
    - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
    - `mask`: input 2D mask of shape (..., H, W)
    - `size`: target size (width, height)

    ### Returns
    - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
    - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
    - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
    """
    height, width = mask.shape[-2:]
    target_width, target_height = size
    # Fractional footprint of one target pixel in source pixels; integer window covers it.
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv. Padding lets every target window index safely;
    # padded regions have mask=False so they can never be selected as the nearest valid pixel.
    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    # Sliding-window views (no copy): trailing window dims of size (filter_h_i, filter_w_i).
    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))

    # Gather the target pixels's local window: each target pixel center mapped back to
    # source coordinates, then its covering window selected from the padded views.
    target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
    target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)

    target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)

    # Compute nearest neighbor in the local window for each pixel: squared distance,
    # with invalid (masked-out) candidates pushed to +inf so argmin skips them.
    dist = np.square(target_window_centers - target_centers[..., None])
    dist = dist[..., 0, :] + dist[..., 1, :]
    dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
    nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
    nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
    # A target pixel is valid iff its window contains at least one valid source pixel.
    target_mask = np.any(target_window_mask, axis=-1)
    # Broadcastable index arrays for the leading batch dimensions of `mask`.
    batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]

    index = (*batch_indices, nearest_i, nearest_j)

    if inputs is None:
        outputs = None
    elif isinstance(inputs, np.ndarray):
        outputs = inputs[index]
    elif isinstance(inputs, Sequence):
        outputs = tuple(x[index] for x in inputs)
    else:
        raise ValueError(f'Invalid input type: {type(inputs)}')

    if return_index:
        return outputs, target_mask, index
    else:
        return outputs, target_mask
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
    """
    Downsample a 2D map by area (box) averaging, weighting each source pixel by its
    overlap area with the target pixel and ignoring masked-out pixels.

    ### Parameters
    - `image`: Input 2D image of shape (..., H, W, C) or (..., H, W)
    - `mask`: Input 2D mask of shape (..., H, W)
    - `target_width`: target width of the resized map
    - `target_height`: target height of the resized map

    ### Returns
    - `target_image`: Area-averaged image of shape (..., target_height, target_width[, C]).
    - `target_mask`: Mask of the resized map of shape (..., target_height, target_width);
      True where the covered valid area is at least 0.25 of a source pixel.
    """
    height, width = mask.shape[-2:]

    # NOTE(review): this channel-dim heuristic misfires when the trailing image dims
    # coincidentally equal (H, W) for a channelled image — confirm callers' shapes.
    if image.shape[-2:] == (height, width):
        omit_channel_dim = True
    else:
        omit_channel_dim = False
    if omit_channel_dim:
        image = image[..., None]

    # Zero out invalid pixels so they contribute nothing to the weighted sums.
    image = np.where(mask[..., None], image, 0)

    # Fractional footprint of one target pixel; +1 so the integer window always covers it.
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv (non-copy)
    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))

    # Gather the target pixels's local window
    target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
    target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)

    target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)

    # Compute pixel area in the local windows: per-axis overlap of each (unit) source
    # pixel with the target pixel's footprint, clipped at 0, multiplied into an area.
    target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None])
    target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None])
    target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None)
    target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0)

    # Weighted sum by area
    target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1)
    target_mask = np.sum(target_window_area, axis=-1) >= 0.25
    target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1)

    if omit_channel_dim:
        target_image = target_image[..., 0]

    return target_image, target_mask
|
| 289 |
+
|
| 290 |
+
|
| 291 |
+
def norm3d(x: np.ndarray) -> np.ndarray:
    "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
    vx, vy, vz = x[..., 0], x[..., 1], x[..., 2]
    return np.sqrt(vx * vx + vy * vy + vz * vz)
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1):
    """Detect occlusion edges in a depth map as bands where disparity jumps.

    A pixel is a foreground edge when its disparity exceeds its local (masked) mean
    by more than `tol` relative; background edges are the symmetric case. The final
    edge mask is the intersection of the two dilated bands, i.e. the boundary region
    where foreground and background meet.

    Args:
        depth: Depth map; presumably shape (H, W) — `np.pad` below pads every axis,
            which would also pad batch dims, so batched input looks unsupported.
            TODO confirm.
        mask: Validity mask, same spatial shape as `depth`.
        thickness: Half-width of the comparison window and dilation iterations.
        tol: Relative disparity-jump threshold.

    Returns:
        Boolean edge mask of the same spatial shape.
    """
    # Work in disparity (1/depth); invalid pixels contribute 0.
    disp = np.where(mask, 1 / depth, 0)
    disp_pad = np.pad(disp, (thickness, thickness), constant_values=0)
    mask_pad = np.pad(mask, (thickness, thickness), constant_values=False)
    kernel_size = 2 * thickness + 1
    # Sliding windows over the padded maps; windows appear as trailing dims.
    disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1)) # [..., H, W, kernel_size ** 2]

    # Local mean disparity over valid neighbors only.
    disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
    fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
    bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)

    # Dilate both bands and keep only where they overlap — the actual occlusion boundary.
    edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \
        & (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0)

    return edge_mask
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def disk_kernel(radius: int) -> np.ndarray:
    """
    Build a normalized disk-shaped averaging kernel.

    Args:
        radius (int): Radius of the disk (in pixels).

    Returns:
        np.ndarray: (2*radius+1, 2*radius+1) normalized convolution kernel.
    """
    # Coordinate offsets centered on the kernel middle.
    offsets = np.arange(-radius, radius + 1)
    yy, xx = np.meshgrid(offsets, offsets, indexing='ij')
    # 1 inside the circle of the given radius, 0 outside.
    kernel = ((xx ** 2 + yy ** 2) <= radius ** 2).astype(np.float32)
    # Normalize so the kernel sums to one.
    return kernel / kernel.sum()
|
| 332 |
+
|
| 333 |
+
|
| 334 |
+
def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
    """
    Blur an image with a disk kernel using FFT convolution.

    Args:
        image (np.ndarray): Input image, grayscale (H, W) or multi-channel (H, W, C).
        radius (int): Blur radius (in pixels); 0 returns the input unchanged.

    Returns:
        np.ndarray: Blurred image of the same shape.
    """
    if radius == 0:
        return image
    kernel = disk_kernel(radius)
    if image.ndim == 2:
        return fftconvolve(image, kernel, mode='same')
    if image.ndim == 3:
        # Convolve each channel independently and restack.
        per_channel = [fftconvolve(image[..., c], kernel, mode='same') for c in range(image.shape[2])]
        return np.stack(per_channel, axis=-1)
    raise ValueError("Image must be 2D or 3D.")
|
| 359 |
+
|
| 360 |
+
|
| 361 |
+
def depth_of_field(
    img: np.ndarray,
    disp: np.ndarray,
    focus_disp : float,
    max_blur_radius : int = 10,
) -> np.ndarray:
    """
    Apply a depth-of-field effect to an image.

    Each pixel's blur radius grows with its disparity distance from the focus
    plane; foreground (higher-disparity) content is dilated into the background
    so blurred foreground edges bleed over, approximating real lens bokeh.

    Args:
        img (numpy.ndarray): (H, W, 3) input image.
        disp (numpy.ndarray): (H, W) disparity map of the scene (normalized internally).
        focus_disp (float): Disparity value that stays in focus (same scale as `disp`).
        max_blur_radius (int): Maximum blur radius (in pixels).

    Returns:
        numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
    """
    # Precalculate dilated disparity maps, one per candidate blur radius.
    max_disp = np.max(disp)
    disp = disp / max_disp
    focus_disp = focus_disp / max_disp
    dilated_disp = []
    for radius in range(max_blur_radius + 1):
        dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))

    # Determine the blur radius for each pixel based on the disparity map;
    # let nearer (dilated) content enlarge the blur of pixels it occludes.
    blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
    for radius in range(max_blur_radius + 1):
        dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
        mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
        blur_radii[mask] = dialted_blur_radii[mask]
    blur_radii = np.clip(blur_radii, 0, max_blur_radius)
    # Smooth the radius map to avoid visible seams between blur levels.
    blur_radii = cv2.blur(blur_radii, (5, 5))

    # Precalculate the blurred image for each radius actually present.
    unique_radii = np.unique(blur_radii)
    precomputed = {}
    for radius in range(max_blur_radius + 1):
        if radius not in unique_radii:
            continue
        precomputed[radius] = disk_blur(img, radius)

    # Composite: each pixel takes its value from the image blurred at its radius.
    output = np.zeros_like(img)
    for r in unique_radii:
        mask = blur_radii == r
        output[mask] = precomputed[r][mask]

    return output
|
moge/utils/geometry_torch.py
ADDED
|
@@ -0,0 +1,359 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
import math
|
| 8 |
+
from collections import namedtuple
|
| 9 |
+
|
| 10 |
+
import numpy as np
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torch.types
|
| 15 |
+
import utils3d
|
| 16 |
+
|
| 17 |
+
from .tools import timeit
|
| 18 |
+
from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    "Mean of `x` over `dim`, optionally weighted by `w`; `eps` guards the division."
    if w is None:
        return x.mean(dim=dim, keepdim=keepdim)
    w = w.to(x.dtype)
    numerator = (x * w).mean(dim=dim, keepdim=keepdim)
    denominator = w.mean(dim=dim, keepdim=keepdim).add(eps)
    return numerator / denominator
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    "Harmonic mean of `x` over `dim`, optionally weighted by `w`; `eps` stabilizes the reciprocals."
    if w is None:
        return 1.0 / (1.0 / (x + eps)).mean(dim=dim, keepdim=keepdim)
    w = w.to(x.dtype)
    mean_of_reciprocals = weighted_mean(1.0 / (x + eps), w, dim=dim, keepdim=keepdim, eps=eps)
    return 1.0 / (mean_of_reciprocals + eps)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    """Geometric mean of `x` computed as exp(mean(log(x + eps))), optionally weighted.

    Fix: the unweighted branch previously called `.mean(dim=dim)` without
    forwarding `keepdim`, so `keepdim=True` was silently ignored there while
    the weighted branch honored it.
    """
    if w is None:
        return x.add(eps).log().mean(dim=dim, keepdim=keepdim).exp()
    else:
        w = w.to(x.dtype)
        return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    # Half-extent of the view plane along each axis, normalized by the diagonal
    diag = (1 + aspect_ratio ** 2) ** 0.5
    span_x = aspect_ratio / diag
    span_y = 1 / diag

    # Sample at pixel centers: the outermost samples sit half a pixel inside the span
    u_max = span_x * (width - 1) / width
    v_max = span_y * (height - 1) / height
    u = torch.linspace(-u_max, u_max, width, dtype=dtype, device=device)
    v = torch.linspace(-v_max, v_max, height, dtype=dtype, device=device)
    grid_u, grid_v = torch.meshgrid(u, v, indexing='xy')
    return torch.stack([grid_u, grid_v], dim=-1)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    """Depthwise Gaussian blur for a batched image tensor of shape (B, C, H, W).

    Fix: the separable kernel was reshaped to (1, 1, k, k) while being used
    with `groups=input.shape[1]`; `F.conv2d` requires a (C, 1, k, k) weight
    for a depthwise convolution, so the original raised for any input with
    more than one channel. The weight is now expanded to one copy per channel.
    """
    channels = input.shape[1]
    # 1D Gaussian, normalized to sum to 1, then outer-product into a 2D kernel
    kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()
    kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
    kernel = kernel.expand(channels, 1, kernel_size, kernel_size)  # depthwise weight: (C, 1, k, k)
    pad = kernel_size // 2
    # Replicate padding keeps the output the same spatial size as the input
    input = F.pad(input, (pad, pad, pad, pad), mode='replicate')
    return F.conv2d(input, kernel, groups=channels)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def focal_to_fov(focal: torch.Tensor):
    "Convert a normalized focal length to a full field of view in radians."
    half_fov = torch.atan(0.5 / focal)
    return half_fov * 2
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
def fov_to_focal(fov: torch.Tensor):
|
| 74 |
+
return 0.5 / torch.tan(fov / 2)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
    "Numerically stable angle between 3D vectors via atan2(|v1 x v2| + eps, v1 . v2)."
    cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1)
    dot = (v1 * v2).sum(dim=-1)
    return torch.atan2(cross_norm + eps, dot)
|
| 79 |
+
|
| 80 |
+
def intrinsics_to_fov(intrinsics: torch.Tensor):
    """
    Returns field of view in radians from normalized intrinsics matrix.
    ### Parameters:
    - intrinsics: torch.Tensor of shape (..., 3, 3)

    ### Returns:
    - fov_x: torch.Tensor of shape (...)
    - fov_y: torch.Tensor of shape (...)
    """
    fx = intrinsics[..., 0, 0]
    fy = intrinsics[..., 1, 1]
    fov_x = 2 * torch.atan(0.5 / fx)
    fov_y = 2 * torch.atan(0.5 / fy)
    return fov_x, fov_y
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def point_map_to_depth_legacy(points: torch.Tensor):
    """Recover depth, FoV and z-shift from a point map via linear least squares.

    NOTE(review): "legacy" in the name and the presence of `recover_focal_shift`
    below suggest this is the superseded closed-form variant — confirm before
    using in new code.

    ### Parameters:
    - `points`: point map of shape (..., H, W, 3)

    ### Returns:
    - `depth`: (..., H, W) shifted z values
    - `fov_x`, `fov_y`: per-axis field of view (radians)
    - `shift`: (...) recovered z-shift
    """
    height, width = points.shape[-3:-1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)

    # Solve least squares problem
    b = (uv * points[..., 2:]).flatten(-3, -1)  # (..., H * W * 2)
    A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2)  # (..., H * W * 2, 2)

    M = A.transpose(-2, -1) @ A
    # Damped normal equations: (A^T A + 1e-6 I)^-1 A^T b for numerical stability
    solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
    focal, shift = solution.unbind(-1)

    depth = points[..., 2] + shift[..., None, None]
    # `focal` is relative to the half diagonal; convert to per-axis FoV
    fov_x = torch.atan(width / diagonal / focal) * 2
    fov_y = torch.atan(height / diagonal / focal) * 2
    return depth, fov_x, fov_y, shift
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def view_plane_uv_to_focal(uv: torch.Tensor):
    "Least-squares focal estimate: the scalar f minimizing ||f * uv - uv_normalized||^2."
    reference = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
    return (uv * reference).sum() / uv.square().sum().add(1e-12)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
    """
    Recover the depth map and FoV from a point map with unknown z shift and focal.

    Note that it assumes:
    - the optical center is at the center of the map
    - the map is undistorted
    - the map is isometric in the x and y directions

    ### Parameters:
    - `points: torch.Tensor` of shape (..., H, W, 3)
    - `mask: torch.Tensor` optional validity mask of shape (..., H, W)
    - `focal: torch.Tensor` optional known focal of shape (...); when given only the shift is solved
    - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.

    ### Returns:
    - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
    - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
    """
    shape = points.shape
    height, width = points.shape[-3], points.shape[-2]
    diagonal = (height ** 2 + width ** 2) ** 0.5

    # Flatten all batch dimensions so the per-item numpy solver loop below is simple
    points = points.reshape(-1, *shape[-3:])
    mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
    focal = focal.reshape(-1) if focal is not None else None
    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)

    # Nearest-neighbor downsample for an approximate but fast solve
    points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
    uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
    mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0

    # The optimization itself runs on CPU/numpy (see geometry_numpy solvers)
    uv_lr_np = uv_lr.cpu().numpy()
    points_lr_np = points_lr.detach().cpu().numpy()
    focal_np = focal.cpu().numpy() if focal is not None else None
    mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
    optim_shift, optim_focal = [], []
    for i in range(points.shape[0]):
        points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
        uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
        if uv_lr_i_np.shape[0] < 2:
            # Degenerate: fewer than two valid pixels — fall back to defaults
            optim_focal.append(1)
            optim_shift.append(0)
            continue
        if focal is None:
            optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
            optim_focal.append(float(optim_focal_i))
        else:
            optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
        optim_shift.append(float(optim_shift_i))
    optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])

    if focal is None:
        optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
    else:
        optim_focal = focal.reshape(shape[:-3])

    return optim_focal, optim_shift
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def mask_aware_nearest_resize(
    inputs: Union[torch.Tensor, Sequence[torch.Tensor], None],
    mask: torch.BoolTensor,
    size: Tuple[int, int],
    return_index: bool = False
) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]:
    """
    Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.

    Unlike plain nearest resize, each target pixel picks the nearest *valid*
    (masked-in) source pixel within its local footprint window.

    ### Parameters
    - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
    - `mask`: input 2D mask of shape (..., H, W)
    - `size`: target size (target_width, target_height)

    ### Returns
    - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
    - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
    - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, .
    """
    height, width = mask.shape[-2:]
    target_width, target_height = size
    device = mask.device
    # Footprint of one target pixel in source pixels (fractional and integer ceil)
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv
    uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
    indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
    padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
    windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
    windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))

    # Gather the target pixels's local window
    target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
    target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
    target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)

    target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)
    target_window_indices = target_window_indices.expand_as(target_window_mask)

    # Compute nearest neighbor in the local window for each pixel
    dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf)  # (..., target_height, tgt_width, filter_size)
    nearest = torch.argmin(dist, dim=-1, keepdim=True)  # (..., target_height, tgt_width, 1)
    nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1)  # (..., target_height, tgt_width)
    target_mask = torch.any(target_window_mask, dim=-1)
    # Decode flat source index back into (row, column)
    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
    batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]

    index = (*batch_indices, nearest_i, nearest_j)

    if inputs is None:
        outputs = None
    elif isinstance(inputs, torch.Tensor):
        outputs = inputs[index]
    elif isinstance(inputs, Sequence):
        outputs = tuple(x[index] for x in inputs)
    else:
        raise ValueError(f'Invalid input type: {type(inputs)}')

    if return_index:
        return outputs, target_mask, index
    else:
        return outputs, target_mask
|
| 251 |
+
|
| 252 |
+
|
| 253 |
+
def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
    """Flag pixels whose local depth extremum exceeds them by more than `rtol` (relative).

    NOTE(review): "theshold" looks like a typo for "threshold"; the name is
    kept unchanged for compatibility with existing callers.
    """
    *leading, h, w = depth.shape
    flat_depth = depth.reshape(-1, 1, h, w)
    flat_mask = mask.reshape(-1, 1, h, w)
    pad = kernel_size // 2
    if pooler == 'max':
        # Local max over the window (invalid pixels contribute -inf)
        neighborhood = F.max_pool2d(torch.where(flat_mask, flat_depth, -torch.inf), kernel_size, stride=1, padding=pad)
        changed = neighborhood > flat_depth * (1 + rtol)
    elif pooler == 'min':
        # Local min via negated max-pool (invalid pixels contribute +inf)
        neighborhood = -F.max_pool2d(torch.where(flat_mask, flat_depth, torch.inf).neg(), kernel_size, stride=1, padding=pad)
        changed = neighborhood < flat_depth * (1 - rtol)
    else:
        raise ValueError(f'Unsupported pooler: {pooler}')
    return changed.reshape(*leading, h, w)
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
    """Split occlusion-edge pixels into foreground/background via a local affine disparity fit.

    NOTE(review): this definition is shadowed by a second `depth_occlusion_edge`
    defined later in this module; Python keeps only the later binding, so this
    version is effectively dead code. Confirm which variant is intended.

    Returns `(fg_edge_mask, bg_edge_mask)`, both of the same shape as `mask`.
    """
    device, dtype = depth.device, depth.dtype

    # Work in disparity (1/depth); invalid pixels are zeroed
    disp = torch.where(mask, 1 / depth, 0)
    disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
    mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
    disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2)  # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2)  # [..., H, W, kernel_size ** 2]

    # Design matrix for a per-window affine (plane) fit: columns [x, y, 1]
    x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype)
    A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3)  # [kernel_size ** 2, 3]
    A = mask_window[..., None] * A
    I = torch.eye(3, device=device, dtype=dtype)

    # Project the window disparities onto the affine model (damped normal equations)
    affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :]  # [..., H, W, kernel_size ** 2]
    # Relative deviation between observed and fitted disparity
    diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0)

    edge_mask = mask & (diff > tol).any(dim=-1)

    # Pixels with disparity above the local mean are the near (foreground) side
    disp_mean = weighted_mean(disp_window, mask_window, dim=-1)
    fg_edge_mask = edge_mask & (disp > disp_mean)
    # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size)
    bg_edge_mask = edge_mask & ~fg_edge_mask
    return fg_edge_mask, bg_edge_mask
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
    """Detect occlusion-edge pixels and classify them as foreground or background.

    This is the *second* definition of `depth_occlusion_edge` in this module and
    is the one callers actually get; it shadows the affine-fit variant above.

    A pixel is a foreground edge when its disparity exceeds the local mean
    disparity by more than `tol` (relative), and a background edge in the
    opposite case. Each side is then kept only where the other side occurs
    within a slightly larger neighborhood, so the two masks come in pairs
    across a depth discontinuity.

    Returns `(fg_edge_mask, bg_edge_mask)`.
    """
    device, dtype = depth.device, depth.dtype

    # Work in disparity (1/depth); invalid pixels are zeroed
    disp = torch.where(mask, 1 / depth, 0)
    disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
    mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
    disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1))  # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1))  # [..., H, W, kernel_size ** 2]

    disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1))
    fg_edge_mask = mask & (disp / disp_mean > 1 + tol)
    bg_edge_mask = mask & (disp_mean / disp > 1 + tol)

    # Keep an edge pixel only if the opposite-side edge exists nearby
    fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
    bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()

    return fg_edge_mask, bg_edge_mask
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> torch.Tensor:
    """Grow `input` into masked-out pixels from their 4-neighborhood, `iterations` times.

    Invalid pixels (where `mask` is False) are filled from the valid pixels in a
    plus-shaped 3x3 neighborhood using the chosen reduction; the mask is dilated
    by the same structuring element each iteration.

    Fix: `Tensor.min`/`Tensor.max` do not accept a tuple `dim`, so the
    'min'/'max' branches raised a TypeError; they now use `amin`/`amax`.
    `mask_window.any` over a tuple of dims is likewise replaced with two
    single-dim reductions for broad torch-version compatibility.

    Returns `(dilated_input, dilated_mask)`.
    """
    # 4-connected (plus-shaped) structuring element
    kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
    for _ in range(iterations):
        input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
        mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
        if filter == 'min':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).amin(dim=(-2, -1)))
        elif filter == 'max':
            input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).amax(dim=(-2, -1)))
        elif filter == 'mean':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
        elif filter == 'median':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
        mask = mask_window.any(dim=-1).any(dim=-1)
    return input, mask
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor:
    """Iteratively refine a depth map so its log-depth gradients agree with a normal map.

    Solves a damped, weighted Poisson-style smoothing in log-depth space via
    Jacobi-like relaxation (`iterations` sweeps).

    NOTE(review): `clamp_min_` below is in-place and therefore mutates the
    caller's `depth` tensor — confirm whether that is intended.

    ### Parameters:
    - `depth`: (..., H, W) depth map
    - `normal`: (..., H, W, 3) unit normal map (camera space)
    - `intrinsics`: (..., 3, 3) normalized camera intrinsics
    - `iterations`: number of relaxation sweeps
    - `damp`: damping toward the original log-depth
    - `kernel_size`: neighborhood size of the local stencil

    ### Returns:
    - refined depth map of the same shape as `depth`
    """
    device, dtype = depth.device, depth.dtype
    height, width = depth.shape[-2:]
    radius = kernel_size // 2

    # Per-stencil-offset displacement in normalized image coordinates
    duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device)

    log_depth = depth.clamp_min_(eps).log()
    log_depth_diff = utils3d.torch.sliding_window_2d(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None]

    # Edge-aware weights: large log-depth jumps relative to pixel distance get low weight
    weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square())
    tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps)

    uv = utils3d.torch.image_uv(height=height, width=width, device=device, dtype=dtype)
    K_inv = torch.inverse(intrinsics)

    # Log-depth gradient implied by the normal map under the pinhole model
    grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \
        / (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + K_inv[..., None, None, :2, 2:])).squeeze(-2)
    laplacian = (weight * ((utils3d.torch.sliding_window_2d(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1))

    # Clamp to keep the relaxation stable against outlier normals
    laplacian = laplacian.clamp(-0.1, 0.1)
    log_depth_refine = log_depth.clone()

    # Damped Jacobi-style sweeps (0.1/0.9 under-relaxation)
    for _ in range(iterations):
        log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.torch.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp)

    depth_refine = log_depth_refine.exp()

    return depth_refine
|
moge/utils/io.py
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 8 |
+
from typing import IO
|
| 9 |
+
import zipfile
|
| 10 |
+
import json
|
| 11 |
+
import io
|
| 12 |
+
from typing import *
|
| 13 |
+
from pathlib import Path
|
| 14 |
+
import re
|
| 15 |
+
from PIL import Image, PngImagePlugin
|
| 16 |
+
|
| 17 |
+
import numpy as np
|
| 18 |
+
import cv2
|
| 19 |
+
|
| 20 |
+
from .tools import timeit
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def save_glb(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_uvs: np.ndarray,
    texture: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """Export a textured triangle mesh as a GLB file via trimesh."""
    import trimesh
    import trimesh.visual
    from PIL import Image

    material = trimesh.visual.material.PBRMaterial(
        baseColorTexture=Image.fromarray(texture),
        metallicFactor=0.5,
        roughnessFactor=1.0
    )
    visual = trimesh.visual.texture.TextureVisuals(uv=vertex_uvs, material=material)
    mesh = trimesh.Trimesh(
        vertices=vertices,
        vertex_normals=vertex_normals,
        faces=faces,
        visual=visual,
        process=False  # keep vertex order and topology untouched
    )
    mesh.export(save_path)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def save_ply(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_colors: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """Export a vertex-colored triangle mesh as a PLY file via trimesh."""
    import trimesh
    import trimesh.visual
    from PIL import Image

    mesh = trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors,
        vertex_normals=vertex_normals,
        process=False  # keep vertex order and topology untouched
    )
    mesh.export(save_path)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a image, return uint8 RGB array of shape (H, W, 3).
    """
    if isinstance(path, (str, os.PathLike)):
        raw = Path(path).read_bytes()
    else:
        raw = path.read()
    bgr = cv2.imdecode(np.frombuffer(raw, np.uint8), cv2.IMREAD_COLOR)
    return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
    """
    Write a image, input uint8 RGB array of shape (H, W, 3).
    """
    bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
    _, encoded = cv2.imencode('.jpg', bgr, [cv2.IMWRITE_JPEG_QUALITY, quality])
    data = encoded.tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]:
    """
    Read a depth image, return float32 depth array of shape (H, W).

    Decodes the 16-bit PNG format produced by `write_depth`:
    - 0 -> NaN (unknown), 65535 -> +inf, 1..65534 -> log-scale depth
      between the `near`/`far` values stored as PNG text metadata.

    Returns `(depth, unit)` where `unit` is the optional unit metadata
    (None when absent).
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    pil_image = Image.open(io.BytesIO(data))
    near = float(pil_image.info.get('near'))
    far = float(pil_image.info.get('far'))
    unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None
    depth = np.array(pil_image)
    # Sentinel codes must be captured before converting to float
    mask_nan, mask_inf = depth == 0, depth == 65535
    depth = (depth.astype(np.float32) - 1) / 65533
    # Invert the logarithmic encoding: depth = near^(1-t) * far^t
    depth = near ** (1 - depth) * far ** depth
    depth[mask_nan] = np.nan
    depth[mask_inf] = np.inf
    return depth, unit
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
def write_depth(
    path: Union[str, os.PathLike, IO],
    depth: np.ndarray,
    unit: float = None,
    max_range: float = 1e5,
    compression_level: int = 7,
):
    """
    Encode and write a depth image as 16-bit PNG format.
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `depth: np.ndarray`
        The depth array, float32 array of shape (H, W).
        May contain `NaN` for invalid values and `Inf` for infinite values.
    - `unit: float = None`
        The unit of the depth values.

    Depth values are encoded as follows:
    - 0: unknown
    - 1 ~ 65534: depth values in logarithmic
    - 65535: infinity

    metadata is stored in the PNG file as text fields:
    - `near`: the minimum depth value
    - `far`: the maximum depth value
    - `unit`: the unit of the depth values (optional)
    """
    # Fix: removed a stray, unused `mask_finite = depth` assignment
    mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth), np.isinf(depth)

    depth = depth.astype(np.float32)
    near = max(depth[mask_values].min(), 1e-5)
    far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
    # Logarithmic quantization of [near, far] into the 1..65534 code range
    depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16)  # 1~65534
    depth[mask_nan] = 0
    depth[mask_inf] = 65535

    pil_image = Image.fromarray(depth)
    pnginfo = PngImagePlugin.PngInfo()
    pnginfo.add_text('near', str(near))
    pnginfo.add_text('far', str(far))
    if unit is not None:
        pnginfo.add_text('unit', str(unit))
    pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Read a segmentation mask
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to read from.
    ### Returns:
    - `Tuple[np.ndarray, Dict[str, int]]`
        A tuple containing:
        - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
        - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id}.
    """
    if isinstance(path, (str, os.PathLike)):
        raw = Path(path).read_bytes()
    else:
        raw = path.read()
    pil_image = Image.open(io.BytesIO(raw))
    # Label mapping is stored as a JSON string in the PNG text metadata
    labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
    return np.array(pil_image), labels
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
    """
    Write a segmentation mask and label mapping, as PNG format.
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `mask: np.ndarray`
        The segmentation mask, uint8 or uint16 array of shape (H, W).
    - `labels: Dict[str, int] = None`
        The label mapping, a dictionary of {label_name: label_id}.
    - `compression_level: int = 7`
        The compression level for PNG compression.
    """
    assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
    pnginfo = PngImagePlugin.PngInfo()
    if labels is not None:
        # Store the mapping compactly as a JSON text chunk
        pnginfo.add_text('labels', json.dumps(labels, ensure_ascii=True, separators=(',', ':')))
    Image.fromarray(mask).save(path, pnginfo=pnginfo, compress_level=compression_level)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
|
| 207 |
+
def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a normal image, return float32 normal array of shape (H, W, 3).

    Inverse of `write_normal`: decodes a 16-bit RGB PNG where channels map
    [0, 65535] -> [-1, 1] with Y and Z sign-flipped; all-zero pixels decode
    to NaN. The result is re-normalized to unit length.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
    # All-zero pixels are the "invalid" sentinel written by `write_normal`
    mask_nan = np.all(normal == 0, axis=-1)
    # Decode to [-1, 1], flipping the sign of Y and Z
    normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
    normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
    normal[mask_nan] = np.nan
    return normal
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> None:
    """
    Write a normal image as a 16-bit PNG, input float32 normal array of shape (H, W, 3).

    NaN normals are encoded as all-zero pixels (the inverse of `read_normal`).

    ### Parameters:
    - `path`: file path or binary file object to write to.
    - `normal`: float32 array of shape (H, W, 3) with components in [-1, 1]; NaNs allowed.
    - `compression_level`: PNG compression level.
    """
    # FIX: return annotation corrected from `-> np.ndarray` to `-> None` —
    # the function only writes to `path` and returns nothing.
    mask_nan = np.isnan(normal).any(axis=-1)
    # Map [-1, 1] -> [0, 65535]; Y and Z are sign-flipped to match `read_normal`.
    normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
    normal[mask_nan] = 0
    data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]:
    """Load a JSON metadata file and return it as a dictionary."""
    content = Path(path).read_text()
    return json.loads(content)
|
| 239 |
+
|
| 240 |
+
def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]):
    """Serialize `meta` as JSON and write it to `path`."""
    serialized = json.dumps(meta)
    Path(path).write_text(serialized)
|
moge/utils/panorama.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import *
|
| 10 |
+
import itertools
|
| 11 |
+
import json
|
| 12 |
+
import warnings
|
| 13 |
+
|
| 14 |
+
import cv2
|
| 15 |
+
import numpy as np
|
| 16 |
+
from numpy import ndarray
|
| 17 |
+
from tqdm import tqdm, trange
|
| 18 |
+
from scipy.sparse import csr_array, hstack, vstack
|
| 19 |
+
from scipy.ndimage import convolve
|
| 20 |
+
from scipy.sparse.linalg import lsmr
|
| 21 |
+
|
| 22 |
+
import utils3d
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def get_panorama_cameras():
    """
    Build a set of virtual pinhole cameras covering the sphere.

    Camera centers look from the origin toward the vertices of an icosahedron,
    each with a 90-degree field of view in both axes.

    ### Returns:
    - `extrinsics`: float32 array of look-at extrinsics, one per icosahedron vertex.
    - `intrinsics`: list with the shared 90-degree-FoV intrinsics, same length.
    """
    look_targets, _ = utils3d.numpy.icosahedron()
    shared_intrinsics = utils3d.numpy.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90))
    extrinsics = utils3d.numpy.extrinsics_look_at([0, 0, 0], look_targets, [0, 0, 1]).astype(np.float32)
    return extrinsics, [shared_intrinsics] * len(look_targets)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def spherical_uv_to_directions(uv: np.ndarray):
    """
    Map equirectangular UV coordinates in [0, 1]^2 to unit direction vectors.

    u runs opposite to azimuth (u=1 is theta=0); v maps linearly to the polar
    angle (v=0 is +Z, v=1 is -Z).
    """
    theta = (1 - uv[..., 0]) * (2 * np.pi)
    phi = uv[..., 1] * np.pi
    sin_phi = np.sin(phi)
    directions = np.stack([sin_phi * np.cos(theta), sin_phi * np.sin(theta), np.cos(phi)], axis=-1)
    return directions
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def directions_to_spherical_uv(directions: np.ndarray):
    """
    Map direction vectors to equirectangular UV coordinates (inverse of
    `spherical_uv_to_directions`). Input need not be normalized.
    """
    unit = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    # Fraction of a full turn of the azimuth; wrapped into [0, 1) before the flip
    # (note `%` binds tighter than `-`, matching theta = (1 - u) * 2*pi).
    azimuth_turns = np.arctan2(unit[..., 1], unit[..., 0]) / (2 * np.pi)
    u = 1 - azimuth_turns % 1.0
    v = np.arccos(unit[..., 2]) / np.pi
    return np.stack([u, v], axis=-1)
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int):
    """
    Resample an equirectangular panorama into one square perspective view per camera.

    ### Parameters:
    - `image`: panorama of shape (H, W, ...).
    - `extrinsics`, `intrinsics`: per-view camera parameters (same length).
    - `resolution`: side length in pixels of each output view.

    ### Returns:
    - list of remapped views, one per camera.
    """
    height, width = image.shape[:2]
    face_uv = utils3d.numpy.image_uv(width=resolution, height=resolution)
    views = []
    for extr, intr in zip(extrinsics, intrinsics):
        # Cast a ray through every output pixel and find where it lands on the panorama.
        rays = utils3d.numpy.unproject_cv(face_uv, extrinsics=extr, intrinsics=intr)
        pano_uv = directions_to_spherical_uv(rays)
        sample_pixels = utils3d.numpy.uv_to_pixel(pano_uv, width=width, height=height).astype(np.float32)

        views.append(cv2.remap(image, sample_pixels[..., 0], sample_pixels[..., 1], interpolation=cv2.INTER_LINEAR))
    return views
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> csr_array:
    """
    Build the sparse 5-point Laplacian operator for a (height, width) grid.

    Each row encodes -4 * center + up + down + left + right for one pixel.
    Border neighbors are edge-replicated, or wrapped around when `wrap_x` /
    `wrap_y` is set (duplicate column indices are summed by the CSR format).

    ### Parameters:
    - `width`, `height`: grid dimensions.
    - `wrap_x`, `wrap_y`: treat the corresponding axis as periodic.

    ### Returns:
    - `A`: float32 `csr_array` of shape (height * width, height * width).

    FIX: return annotation corrected from `Tuple[csr_array, ndarray]` — the
    function returns only the operator matrix.
    """
    grid_index = np.arange(height * width).reshape(height, width)
    # Pad with neighbor indices so border pixels see their wrapped or replicated neighbor.
    grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge')
    grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge')

    # Every row carries the same five stencil coefficients.
    data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1)
    indices = np.stack([
        grid_index[1:-1, 1:-1],
        grid_index[:-2, 1:-1],  # up
        grid_index[2:, 1:-1],   # down
        grid_index[1:-1, :-2],  # left
        grid_index[1:-1, 2:]    # right
    ], axis=-1).reshape(-1)
    indptr = np.arange(0, height * width * 5 + 1, 5)
    A = csr_array((data, indices, indptr), shape=(height * width, height * width))

    return A
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> csr_array:
    """
    Build the sparse first-difference (gradient) operator for a (height, width) grid.

    The first block of rows computes x[i, j] - x[i, j+1] for each horizontally
    adjacent pair, the second block x[i, j] - x[i+1, j] for each vertically
    adjacent pair. With `wrap_x` / `wrap_y` an extra wrapped pair per row/column
    is appended, making the corresponding axis periodic.

    ### Parameters:
    - `width`, `height`: grid dimensions.
    - `wrap_x`, `wrap_y`: include seam differences across the wrapped axis.

    ### Returns:
    - `A`: float32 `csr_array` with one row per adjacent pair and
      height * width columns.

    FIX: return annotation corrected from `Tuple[csr_array, np.ndarray]` — the
    function returns only the operator matrix.
    """
    grid_index = np.arange(width * height).reshape(height, width)
    if wrap_x:
        grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap')
    if wrap_y:
        grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap')

    data = np.concatenate([
        np.concatenate([
            np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),   # x[i,j]
            -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),  # x[i,j+1]
        ], axis=1).reshape(-1),
        np.concatenate([
            np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),   # x[i,j]
            -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),  # x[i+1,j]
        ], axis=1).reshape(-1),
    ])
    indices = np.concatenate([
        np.concatenate([
            grid_index[:, :-1].reshape(-1, 1),
            grid_index[:, 1:].reshape(-1, 1),
        ], axis=1).reshape(-1),
        np.concatenate([
            grid_index[:-1, :].reshape(-1, 1),
            grid_index[1:, :].reshape(-1, 1),
        ], axis=1).reshape(-1),
    ])
    # Two non-zeros per row: horizontal pair rows followed by vertical pair rows.
    indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2)
    A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width))

    return A
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def merge_panorama_depth(width: int, height: int, distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], extrinsics: List[np.ndarray], intrinsics: List[np.ndarray]):
    """
    Fuse per-view distance maps into a single panoramic depth map by solving a
    Poisson-style least-squares system in log-distance space.

    ### Parameters:
    - `width`, `height`: output panorama resolution.
    - `distance_maps`: per-view distance maps (one per camera).
    - `pred_masks`: per-view boolean validity masks, same length.
    - `extrinsics`, `intrinsics`: per-view camera parameters, same length.

    ### Returns:
    - `panorama_depth`: float32 (height, width) fused depth.
    - `panorama_mask`: boolean (height, width), True where any view contributed.
    """
    # Coarse-to-fine: solve a half-resolution panorama first and use its log as
    # the initial guess for the full-resolution solver.
    if max(width, height) > 256:
        panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics)
        panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR)
    else:
        panorama_depth_init = None

    uv = utils3d.numpy.image_uv(width=width, height=height)
    spherical_directions = spherical_uv_to_directions(uv)

    # Warp each view to the panorama
    panorama_log_distance_grad_maps, panorama_grad_masks = [], []
    panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], []
    panorama_pred_masks = []
    for i in range(len(distance_maps)):
        # Project every panorama direction into view i; only directions landing
        # strictly inside the view frustum (and in front of the camera) are valid.
        projected_uv, projected_depth = utils3d.numpy.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i])
        projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1)

        projected_pixels = utils3d.numpy.uv_to_pixel(np.clip(projected_uv, 0, 1), width=distance_maps[i].shape[1], height=distance_maps[i].shape[0]).astype(np.float32)

        # Work in log space so the fusion is invariant to per-view scale.
        log_splitted_distance = np.log(distance_maps[i])
        panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0)
        panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0)

        # calculate gradient map (horizontal axis wraps around the panorama seam)
        padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap')
        grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]

        padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap')
        mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :]

        panorama_log_distance_grad_maps.append((grad_x, grad_y))
        panorama_grad_masks.append((mask_x, mask_y))

        # calculate laplacian map (edge-replicate vertically, wrap horizontally)
        padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge')
        padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
        laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1]

        # A laplacian sample is trusted only where the full 5-point stencil is valid (== 5).
        padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge')
        padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
        mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5

        panorama_log_distance_laplacian_maps.append(laplacian)
        panorama_laplacian_masks.append(mask)

        panorama_pred_masks.append(panorama_pred_mask)

    panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0)
    panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0)
    panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0)
    panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0)

    # Average each derivative over the views that observed it (clip avoids 0/0).
    panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3)
    panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3)

    panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0)
    panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0)
    panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3)

    grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1)
    grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1)
    grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
    laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1)

    # Solve overdetermined system: observed gradients and laplacians of the
    # log-distance field, restricted to the rows where at least one view contributed.
    A = vstack([
        grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
        poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask],
    ])
    b = np.concatenate([
        panorama_log_distance_grad_x.reshape(-1)[grad_x_mask],
        panorama_log_distance_grad_y.reshape(-1)[grad_y_mask],
        panorama_laplacian_map.reshape(-1)[laplacian_mask]
    ])
    x, *_ = lsmr(
        A, b,
        atol=1e-5, btol=1e-5,
        x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None,
        show=False,
    )

    panorama_depth = np.exp(x).reshape(height, width).astype(np.float32)
    panorama_mask = np.any(panorama_pred_masks, axis=0)

    return panorama_depth, panorama_mask
|
| 196 |
+
|
moge/utils/pipeline.py
ADDED
|
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import *
|
| 2 |
+
from abc import abstractmethod
|
| 3 |
+
from queue import Empty, Full
|
| 4 |
+
from threading import Thread
|
| 5 |
+
from queue import Queue
|
| 6 |
+
from multiprocessing import Process
|
| 7 |
+
from threading import Thread, Event
|
| 8 |
+
import multiprocessing
|
| 9 |
+
import threading
|
| 10 |
+
import inspect
|
| 11 |
+
import time
|
| 12 |
+
import uuid
|
| 13 |
+
from copy import deepcopy
|
| 14 |
+
import itertools
|
| 15 |
+
import functools
|
| 16 |
+
|
| 17 |
+
# Copied from the MoGe project:
|
| 18 |
+
# https://github.com/microsoft/MoGe
|
| 19 |
+
# Original license: MIT
|
| 20 |
+
# Copyright (c) the MoGe authors
|
| 21 |
+
|
| 22 |
+
__all__ = [
|
| 23 |
+
'Node',
|
| 24 |
+
'Link',
|
| 25 |
+
'ConcurrentNode',
|
| 26 |
+
'Worker',
|
| 27 |
+
'WorkerFunction',
|
| 28 |
+
'Provider',
|
| 29 |
+
'ProviderFunction',
|
| 30 |
+
'Sequential',
|
| 31 |
+
'Batch',
|
| 32 |
+
'Unbatch',
|
| 33 |
+
'Parallel',
|
| 34 |
+
'Graph',
|
| 35 |
+
'Buffer',
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
TERMINATE_CHECK_INTERVAL = 0.5
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
class _ItemWrapper:
    """Internal envelope pairing a queue payload with an optional tracking id."""
    def __init__(self, data: Any, id: Union[int, List[int]] = None):
        # `id` intentionally mirrors the field name used across the pipeline
        # (it shadows the builtin, but is part of the established interface).
        self.data = data
        self.id = id
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class Terminate(Exception):
    """Internal control-flow exception raised inside worker loops once the terminate flag is set."""
    pass
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper:
    """
    Blocking `queue.get` that polls `terminate_flag` every
    TERMINATE_CHECK_INTERVAL seconds.

    Raises `Terminate` as soon as the flag is observed (even if an item was
    just retrieved), and `Empty` once the optional `timeout` budget runs out.
    """
    remaining = timeout
    while True:
        wait = TERMINATE_CHECK_INTERVAL if remaining is None else min(remaining, TERMINATE_CHECK_INTERVAL)
        try:
            item: _ItemWrapper = queue.get(block=True, timeout=wait)
        except Empty:
            if terminate_flag.is_set():
                raise Terminate()
            if remaining is not None:
                # Budget is decremented by the poll interval, matching the wait granularity.
                remaining -= TERMINATE_CHECK_INTERVAL
                if remaining <= 0:
                    raise Empty()
        else:
            if terminate_flag.is_set():
                raise Terminate()
            return item
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event):
    """
    Blocking `queue.put` that polls `terminate_flag` every
    TERMINATE_CHECK_INTERVAL seconds, raising `Terminate` once the flag is set
    (even right after a successful put).
    """
    while True:
        try:
            queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL)
        except Full:
            if terminate_flag.is_set():
                raise Terminate()
        else:
            if terminate_flag.is_set():
                raise Terminate()
            return
|
| 78 |
+
|
| 79 |
+
class Node:
    """
    Base class for pipeline nodes: owns an input queue and an output queue and
    defines the lifecycle (start / terminate / join) plus blocking put/get
    helpers. Also usable as a context manager (start on enter, stop on exit).
    """
    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        self.in_buffer_size = in_buffer_size
        self.out_buffer_size = out_buffer_size
        self.input: Queue = Queue(maxsize=in_buffer_size)
        self.output: Queue = Queue(maxsize=out_buffer_size)

    @abstractmethod
    def start(self):
        """Begin processing; implemented by subclasses."""
        pass

    @abstractmethod
    def terminate(self):
        """Signal the node to stop; implemented by subclasses."""
        pass

    def stop(self):
        """Terminate and then wait for completion."""
        self.terminate()
        self.join()

    @abstractmethod
    def join(self):
        """Block until the node has fully stopped; implemented by subclasses."""
        pass

    def put(self, data: Any, key: str = None, block: bool = True) -> None:
        """Feed one item into the node's input queue."""
        self.input.put(_ItemWrapper(data), block=block)

    def get(self, key: str = None, block: bool = True) -> Any:
        """Take one processed item's payload from the node's output queue."""
        wrapped: _ItemWrapper = self.output.get(block=block)
        return wrapped.data

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.terminate()
        self.join()
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class ConcurrentNode(Node):
    """
    A node whose work loop runs concurrently in a thread or a process.
    Subclasses implement `_loop_fn`, which runs until the terminate flag is set.
    """
    job: Union[Thread, Process]

    def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        super().__init__(in_buffer_size, out_buffer_size)
        self.running_as = running_as

    @abstractmethod
    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        """Processing loop body; executed in the worker thread/process."""
        pass

    def start(self):
        """Spawn the worker thread/process running `_loop_fn`."""
        if self.running_as == 'thread':
            terminate_flag = threading.Event()
            job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
        elif self.running_as == 'process':
            # NOTE(review): the queues created in Node.__init__ are `queue.Queue`,
            # which is not shared across processes — confirm process mode before use.
            terminate_flag = multiprocessing.Event()
            job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
        else:
            # BUG FIX: an unknown value previously fell through to an
            # UnboundLocalError on `job`; fail with an explicit message instead.
            raise ValueError(f"Invalid running_as value: {self.running_as!r}")
        job.start()
        self.job = job
        self.terminate_flag = terminate_flag

    def terminate(self):
        """Ask the worker loop to exit at its next terminate-flag check."""
        self.terminate_flag.set()

    def join(self):
        """Block until the worker thread/process has exited."""
        self.job.join()
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
class Worker(ConcurrentNode):
    """Concurrent node that transforms each input item via `work`."""

    def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None:
        super().__init__(running_as, in_buffer_size, out_buffer_size)

    def init(self) -> None:
        """
        Called once when the thread/process is started, to initialize any
        resources that are only held in the worker thread/process.
        """
        pass

    @abstractmethod
    def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]:
        """
        Process one input item and return the result.
        Each item taken from the input queue is passed to this method, and the
        result is placed on the output queue. Runs concurrently with other nodes.
        """
        pass

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        self.init()
        try:
            while True:
                wrapped = _get_queue_item(input, terminate_flag)
                outcome = self.work(wrapped.data)
                # Propagate the item id so downstream consumers can match results to inputs.
                _put_queue_item(output, _ItemWrapper(outcome, wrapped.id), terminate_flag)
        except Terminate:
            pass
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class Provider(ConcurrentNode):
    """
    A node that provides data to successive nodes. It takes no input and fills
    its output queue from the `provide` generator.
    """
    def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None:
        # Providers have no input; the input buffer size is fixed to 0.
        super().__init__(running_as, 0, out_buffer_size)

    def init(self) -> None:
        """
        Called once when the thread or process is started, to initialize any
        resources that are only held in that thread or process.
        """
        pass

    @abstractmethod
    def provide(self) -> Generator[Any, None, None]:
        """Yield the items to emit; runs concurrently with other nodes."""
        pass

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        self.init()
        try:
            for produced in self.provide():
                _put_queue_item(output, _ItemWrapper(produced), terminate_flag)
        except Terminate:
            pass
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
class WorkerFunction(Worker):
    """Worker that applies a plain callable to each input item."""

    def __init__(self, fn: Callable, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        # BUG FIX: `running_as: 'thread'` was a bare (string) type annotation with
        # no default value, which made the argument required. Give it the intended
        # default and a correct Literal annotation; positional callers are unaffected.
        super().__init__(running_as, in_buffer_size, out_buffer_size)
        self.fn = fn

    def work(self, *args, **kwargs):
        """Delegate to the wrapped callable."""
        return self.fn(*args, **kwargs)
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class ProviderFunction(Provider):
    """Provider that draws items from a wrapped generator function."""

    def __init__(self, fn: Callable, running_as: Literal['thread', 'process'] = 'thread', out_buffer_size: int = 1) -> None:
        # BUG FIX: `running_as: 'thread'` was a bare (string) type annotation with
        # no default value, which made the argument required. Give it the intended
        # default and a correct Literal annotation; positional callers are unaffected.
        super().__init__(running_as, out_buffer_size)
        self.fn = fn

    def provide(self):
        """Yield every item produced by the wrapped generator function."""
        yield from self.fn()
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
class Link:
    """A background thread that forwards items from one queue to another."""

    def __init__(self, src: Queue, dst: Queue):
        self.src = src
        self.dst = dst

    def _thread_fn(self):
        # Pump items until the terminate flag is observed by the queue helpers.
        try:
            while True:
                forwarded = _get_queue_item(self.src, self.terminate_flag)
                _put_queue_item(self.dst, forwarded, self.terminate_flag)
        except Terminate:
            pass

    def start(self):
        """Create the terminate flag and spawn the forwarding thread."""
        self.terminate_flag = threading.Event()
        self.thread = Thread(target=self._thread_fn)
        self.thread.start()

    def terminate(self):
        """Ask the forwarding thread to exit at its next flag check."""
        self.terminate_flag.set()

    def join(self):
        """Block until the forwarding thread has exited."""
        self.thread.join()
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class Graph(Node):
    """
    Graph pipeline of nodes and links.

    Nodes are registered with `add`, wired together with `link`/`chain`, and the
    whole graph is started/terminated/joined as one unit (inherited context-manager
    protocol starts on enter and stops on exit).
    """
    nodes: List[Node]
    links: List[Link]

    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
        super().__init__(in_buffer_size, out_buffer_size)
        self.nodes = []
        self.links = []

    def add(self, node: Node):
        # Registered nodes share the graph's lifecycle (start/terminate/join).
        self.nodes.append(node)

    def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]):
        """
        Links the output of the source node to the input of the destination node.
        If the source or destination node is None, the pipeline's input or output is used.
        """
        # NOTE(review): the (Node, key) tuple form declared in the annotation is
        # not handled here — only a Node or None is supported; confirm before
        # passing tuples.
        src_queue = self.input if src is None else src.output
        dst_queue = self.output if dst is None else dst.input
        self.links.append(Link(src_queue, dst_queue))

    def chain(self, nodes: Iterable[Node]):
        """
        Link the output of each node to the input of the next node.
        """
        nodes = list(nodes)
        for i in range(len(nodes) - 1):
            self.link(nodes[i], nodes[i + 1])

    def start(self):
        # Start the processing nodes first, then the queue-forwarding links.
        for node in self.nodes:
            node.start()
        for link in self.links:
            link.start()

    def terminate(self):
        for node in self.nodes:
            node.terminate()
        for link in self.links:
            link.terminate()

    def join(self):
        for node in self.nodes:
            node.join()
        for link in self.links:
            link.join()

    def __iter__(self):
        # Iteration only makes sense when the graph is driven by at least one Provider.
        providers = [node for node in self.nodes if isinstance(node, Provider)]
        if len(providers) == 0:
            raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.")
        with self:
            # while all(provider.job.is_alive() for provider in providers):
            # NOTE(review): this loop has no stop condition of its own and blocks
            # indefinitely once the providers are exhausted — confirm intended usage.
            while True:
                yield self.get()

    def __call__(self, data: Any) -> Any:
        """
        Submit data to the pipeline's input queue, and return the output data asynchronously.
        NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work.
        """
        # TODO
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class Sequential(Graph):
    """
    Pipeline of nodes in sequential order, where each node takes the output of
    the previous node as input. The order of input and output items is
    preserved (FIFO).
    """
    def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
        """
        Initialize the pipeline with a list of nodes to execute sequentially.

        ### Parameters:
        - nodes: nodes or functions to execute in order. Generator functions are
          wrapped in Provider nodes, other callables in Worker nodes.
        - function_running_as: whether wrapped functions run as a thread or
          process worker. Defaults to 'thread'.
        - in_buffer_size: maximum size of the pipeline's input queue. Defaults to 1.
        - out_buffer_size: maximum size of the pipeline's output queue. Defaults to 1.
        """
        super().__init__(in_buffer_size, out_buffer_size)
        for candidate in nodes:
            if isinstance(candidate, Node):
                wrapped = candidate
            elif isinstance(candidate, Callable):
                wrapper_cls = ProviderFunction if inspect.isgeneratorfunction(candidate) else WorkerFunction
                wrapped = wrapper_cls(candidate, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(candidate)}")
            self.add(wrapped)
        # None endpoints wire the first node to the graph input and the last to the graph output.
        self.chain([None, *self.nodes, None])
|
| 343 |
+
|
| 344 |
+
|
| 345 |
+
class Parallel(Node):
|
| 346 |
+
"""
|
| 347 |
+
A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available.
|
| 348 |
+
NOTE: It is FIFO if and only if all the nested nodes are FIFO.
|
| 349 |
+
"""
|
| 350 |
+
nodes: List[Node]
|
| 351 |
+
|
| 352 |
+
def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'):
|
| 353 |
+
super().__init__(in_buffer_size, out_buffer_size)
|
| 354 |
+
self.nodes = []
|
| 355 |
+
for node in nodes:
|
| 356 |
+
if isinstance(node, Node):
|
| 357 |
+
pass
|
| 358 |
+
elif isinstance(node, Callable):
|
| 359 |
+
if inspect.isgeneratorfunction(node):
|
| 360 |
+
node = ProviderFunction(node, function_running_as)
|
| 361 |
+
else:
|
| 362 |
+
node = WorkerFunction(node, function_running_as)
|
| 363 |
+
else:
|
| 364 |
+
raise ValueError(f"Invalid node type: {type(node)}")
|
| 365 |
+
self.nodes.append(node)
|
| 366 |
+
self.output_order = Queue()
|
| 367 |
+
self.lock = threading.Lock()
|
| 368 |
+
|
| 369 |
+
def _in_thread_fn(self, node: Node):
|
| 370 |
+
try:
|
| 371 |
+
while True:
|
| 372 |
+
with self.lock:
|
| 373 |
+
# A better idea: first make sure its node is vacant, then get it a new item.
|
| 374 |
+
# Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node.
|
| 375 |
+
# This could lead to suboptimal scheduling.
|
| 376 |
+
item = _get_queue_item(self.input, self.terminate_flag)
|
| 377 |
+
self.output_order.put(node.output)
|
| 378 |
+
_put_queue_item(node.input, item, self.terminate_flag)
|
| 379 |
+
except Terminate:
|
| 380 |
+
return
|
| 381 |
+
|
| 382 |
+
def _out_thread_fn(self):
|
| 383 |
+
try:
|
| 384 |
+
while True:
|
| 385 |
+
queue = _get_queue_item(self.output_order, self.terminate_flag)
|
| 386 |
+
item = _get_queue_item(queue, self.terminate_flag)
|
| 387 |
+
_put_queue_item(self.output, item, self.terminate_flag)
|
| 388 |
+
except Terminate:
|
| 389 |
+
return
|
| 390 |
+
|
| 391 |
+
def start(self):
|
| 392 |
+
self.terminate_flag = threading.Event()
|
| 393 |
+
self.in_threads = []
|
| 394 |
+
for node in self.nodes:
|
| 395 |
+
thread = Thread(target=self._in_thread_fn, args=(node,))
|
| 396 |
+
thread.start()
|
| 397 |
+
self.in_threads.append(thread)
|
| 398 |
+
thread = Thread(target=self._out_thread_fn)
|
| 399 |
+
thread.start()
|
| 400 |
+
self.out_thread = thread
|
| 401 |
+
for node in self.nodes:
|
| 402 |
+
node.start()
|
| 403 |
+
|
| 404 |
+
def terminate(self):
    """Signal all dispatcher/collector threads to stop and terminate every child node.

    Non-blocking; use `join()` to wait for the threads to finish.
    """
    self.terminate_flag.set()
    for node in self.nodes:
        node.terminate()
|
| 408 |
+
|
| 409 |
+
def join(self):
    """Block until all dispatcher threads and the collector thread have exited."""
    for thread in self.in_threads:
        thread.join()
    self.out_thread.join()
|
| 413 |
+
|
| 414 |
+
|
| 415 |
+
class UnorderedParallel(Graph):
    """
    Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available.
    NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input.
    """
    def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
        """
        Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node.
        ### Parameters:
        - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
        - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'.
        - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 1.
        - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 1.
        """
        super().__init__(in_buffer_size, out_buffer_size)
        for node in nodes:
            if isinstance(node, Node):
                pass  # already a node; used as-is
            elif isinstance(node, Callable):
                # Generator functions produce items (providers); plain functions map items (workers).
                if inspect.isgeneratorfunction(node):
                    node = ProviderFunction(node, function_running_as)
                else:
                    node = WorkerFunction(node, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(node)}")
            self.add(node)
        # Wire each node directly from graph input to graph output (no inter-node links),
        # so whichever node is free picks up the next item — hence the unordered output.
        for i in range(len(nodes)):
            self.chain([None, self.nodes[i], None])
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
class Batch(ConcurrentNode):
    """
    Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes.
    The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node,
    i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size.
    """
    def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1):
        """
        Parameters:
        - batch_size: maximum number of items per emitted batch (must be > 0).
        - patience: max seconds to wait (measured from the first item of the batch)
          before emitting a partial batch; None waits indefinitely for a full batch.
        """
        assert batch_size > 0, "Batch size must be greater than 0."
        super().__init__('thread', in_buffer_size, out_buffer_size)
        self.batch_size = batch_size
        self.patience = patience

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        # Worker-thread loop: repeatedly assemble one batch and emit it.
        try:
            while True:
                batch_id, batch_data = [], []
                # Try to fill the batch
                for i in range(self.batch_size):
                    # The first item (i == 0) is always waited for indefinitely;
                    # patience only bounds the wait for subsequent items.
                    if i == 0 or self.patience is None:
                        timeout = None
                    else:
                        # Remaining budget relative to when the first item arrived.
                        timeout = self.patience - (time.time() - earliest_time)
                        if timeout < 0:
                            break
                    try:
                        item = _get_queue_item(input, terminate_flag, timeout)
                    except Empty:
                        # Patience expired while waiting: emit the partial batch.
                        break

                    if i == 0:
                        # Start the patience clock at the first item of this batch.
                        earliest_time = time.time()
                    batch_data.append(item.data)
                    batch_id.append(item.id)

                batch = _ItemWrapper(batch_data, batch_id)
                _put_queue_item(output, batch, terminate_flag)
        except Terminate:
            return
|
| 483 |
+
|
| 484 |
+
|
| 485 |
+
class Unbatch(ConcurrentNode):
    """
    Ungroups every batch (a list of items) into individual items and passes them to successive nodes.
    """
    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
        super().__init__('thread', in_buffer_size, out_buffer_size)

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        # Worker-thread loop: split each incoming batch back into single wrapped items.
        try:
            while True:
                batch = _get_queue_item(input, terminate_flag)
                # If the batch carries no ids, pair each data element with a None id.
                for id, data in zip(batch.id or itertools.repeat(None), batch.data):
                    item = _ItemWrapper(data, id)
                    _put_queue_item(output, item, terminate_flag)
        except Terminate:
            return
|
| 501 |
+
|
| 502 |
+
|
| 503 |
+
class Buffer(Node):
    """A FIFO node that buffers items in a bounded queue.

    Useful to achieve better temporal balance when the successor node has a
    variable processing time: the buffer absorbs bursts from the producer.
    """
    def __init__(self, size: int):
        super().__init__(size, size)
        self.size = size
        # A single shared queue serves as both input and output, so items pass
        # straight through while up to `size` of them may be held in flight.
        shared_queue = Queue(maxsize=size)
        self.input = shared_queue
        self.output = shared_queue
|
moge/utils/tools.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
import time
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from numbers import Number
|
| 10 |
+
from functools import wraps
|
| 11 |
+
import warnings
|
| 12 |
+
import math
|
| 13 |
+
import json
|
| 14 |
+
import os
|
| 15 |
+
import importlib
|
| 16 |
+
import importlib.util
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def catch_exception(fn):
    """Decorator: run `fn`, and on any exception print a short traceback and return None.

    Intended for best-effort worker functions where a single failure should not
    abort the whole run. The brief sleep gives interleaved console output from
    other threads a chance to flush.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            import traceback
            # Fixed: end='\r' (carriage return) — the original end='r' appended a literal 'r'.
            print(f"Exception in {fn.__name__}", end='\r')
            traceback.print_exc(chain=False)
            time.sleep(0.1)
            return None
    return wrapper
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CallbackOnException:
    """Context manager that swallows a given exception type and invokes a callback instead.

    Any other exception propagates unchanged; on a clean exit the callback is not called.
    """

    def __init__(self, callback: Callable, exception: type):
        self.exception = exception
        self.callback = callback

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning True from __exit__ suppresses the exception.
        matched = isinstance(exc_val, self.exception)
        if matched:
            self.callback()
        return matched
|
| 47 |
+
|
| 48 |
+
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
    """Yield the key path (as a tuple) of every leaf value in a nested dictionary.

    Sub-dictionaries are recursed into; any non-dict value is a leaf.
    """
    for key, value in d.items():
        if isinstance(value, dict):
            yield from ((key,) + path for path in traverse_nested_dict_keys(value))
        else:
            yield (key,)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
    """Walk `keys` into the nested dict `d`, substituting `default` at a missing level.

    Stops early (returning None) as soon as a lookup yields None.
    """
    current = d
    for key in keys:
        current = current.get(key, default)
        if current is None:
            break
    return current
|
| 63 |
+
|
| 64 |
+
def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
    """Set `value` at key path `keys` in nested dict `d`, creating intermediate dicts as needed."""
    node = d
    for key in keys[:-1]:
        node = node.setdefault(key, {})
    node[keys[-1]] = value
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def key_average(list_of_dicts: list) -> Dict[str, Any]:
    """
    Returns a dictionary with the average value of each key in the input list of dictionaries.

    None values and NaNs are skipped; a key with no usable values averages to NaN.
    """
    seen_paths = set()
    for entry in list_of_dicts:
        seen_paths.update(traverse_nested_dict_keys(entry))
    averaged = {}
    for path in sorted(seen_paths):
        usable = [
            v for v in (get_nested_dict(entry, path) for entry in list_of_dicts)
            if v is not None and not math.isnan(v)
        ]
        mean = sum(usable) / len(usable) if usable else float('nan')
        set_nested_dict(averaged, path, mean)
    return averaged
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """
    Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
    """
    prefix = () if parent_key is None else parent_key
    flat = {}
    for key, value in d.items():
        path = prefix + (key,)
        if isinstance(value, MutableMapping):
            flat.update(flatten_nested_dict(value, path))
        else:
            flat[path] = value
    return flat
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Unflattens a single-level dictionary (whose keys are tuples) back into a nested dictionary.
    """
    nested = {}
    for path, value in d.items():
        node = nested
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return nested
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
def read_jsonl(file):
    """Read a JSON-lines file and return the list of parsed objects, one per line."""
    import json
    with open(file, 'r') as f:
        return [json.loads(line) for line in f]
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def write_jsonl(data: List[dict], file):
    """Write each item of `data` as one JSON object per line (JSON-lines format)."""
    import json
    with open(file, 'w') as f:
        f.writelines(json.dumps(item) + '\n' for item in data)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """Build a pandas DataFrame with hierarchical (MultiIndex) columns from nested dicts.

    Each input dict is flattened to tuple keys; the tuples become the column MultiIndex.
    NOTE: the name keeps the original spelling ("hierachical") for caller compatibility.
    """
    import pandas as pd
    flattened = [flatten_nested_dict(entry) for entry in data]
    frame = pd.DataFrame(flattened).sort_index(axis=1)
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    return frame
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """Apply every old→new substring replacement in `mapping` to all strings in `d`.

    Lists and dicts are modified in place (and also returned); other types pass through.
    """
    if isinstance(d, str):
        for old, new in mapping.items():
            d = d.replace(old, new)
        return d
    if isinstance(d, list):
        for index in range(len(d)):
            d[index] = recursive_replace(d[index], mapping)
        return d
    if isinstance(d, dict):
        for key in d:
            d[key] = recursive_replace(d[key], mapping)
        return d
    return d
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
class timeit:
    """Measure wall-clock time of a code block (context manager) or a function (decorator).

    With average=True, every completed timing under the same name is recorded in the
    class-level `_history`, so `average_time` reports the mean across runs.
    """
    # Shared across all instances: name -> list of completed timeit instances.
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.average = average
        if average and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        """Use as a decorator; supports both sync and async functions."""
        import inspect
        # NOTE(review): the wrappers below do not apply functools.wraps, so the decorated
        # function loses its __name__/__doc__ — confirm whether any caller relies on them.
        if inspect.iscoroutinefunction(func):
            async def wrapper(*args, **kwargs):
                # A fresh timeit is created per call, named after the function by default.
                with timeit(self.name or func.__qualname__):
                    ret = await func(*args, **kwargs)
                return ret
            return wrapper
        else:
            def wrapper(*args, **kwargs):
                with timeit(self.name or func.__qualname__):
                    ret = func(*args, **kwargs)
                return ret
            return wrapper

    def __enter__(self):
        self.start = time.time()
        return self

    @property
    def time(self) -> float:
        """Elapsed seconds of the completed measurement."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    @property
    def average_time(self) -> float:
        """Mean elapsed time over all recorded runs with this name (requires average=True)."""
        assert self.average, "Average time not available."
        return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])

    @property
    def history(self) -> List['timeit']:
        """All recorded timeit instances under this name (empty if none)."""
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        if self.average:
            timeit._history[self.name].append(self)
        if self.verbose:
            if self.average:
                avg = self.average_time
                print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """Strip the longest common prefix and longest common suffix from every string.

    Fixes edge cases of the original index-scanning version, which mis-sliced (or
    raised IndexError) when all strings were identical or when one string was a
    prefix of another: the scan loops could fall through without breaking, leaving
    `start`/`end` off by one.
    """
    prefix_len = len(os.path.commonprefix(strings))
    # The common suffix is the common prefix of the reversed strings.
    suffix_len = len(os.path.commonprefix([s[::-1] for s in strings]))
    # Slicing handles overlap gracefully: s[a:b] with b <= a yields ''.
    return [s[prefix_len:len(s) - suffix_len] for s in strings]
|
| 227 |
+
|
| 228 |
+
|
| 229 |
+
def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    """Decorator factory: run the decorated function over all `inputs` with a thread pool.

    NOTE: the name keeps the original spelling ("multithead") because callers may rely on it.

    Usage:
        @multithead_execute(items, num_workers=8)
        def process(item): ...

    The decorated function is executed immediately for every input (the decorator
    returns None, so the decorated name is consumed, not reusable). Per-item
    exceptions are printed and swallowed via `catch_exception`; a tqdm progress
    bar is created unless one is supplied.
    """
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import nullcontext
    from tqdm import tqdm

    if pbar is not None:
        # Reuse the caller's progress bar, just set its total.
        pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
    else:
        pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)

    def decorator(fn: Callable):
        with (
            ThreadPoolExecutor(max_workers=num_workers) as executor,
            pbar
        ):
            pbar.refresh()
            @catch_exception
            @suppress_traceback
            def _fn(input):
                ret = fn(input)
                pbar.update()
                return ret
            # Executor.map submits every input up front; exiting the executor
            # context (plus the explicit shutdown) waits for all of them.
            executor.map(_fn, inputs)
            executor.shutdown(wait=True)

    return decorator
|
| 255 |
+
|
| 256 |
+
|
| 257 |
+
def suppress_traceback(fn):
    """Decorator that trims the two wrapper frames from the traceback of any
    exception raised by `fn`, then re-raises it, so the reported traceback
    starts at user code.

    Robustness fix: if the traceback has fewer frames than expected, it is left
    untouched instead of raising AttributeError on None.
    """
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            tb = e.__traceback__
            if tb is not None and tb.tb_next is not None:
                e.__traceback__ = tb.tb_next.tb_next
            raise
    return wrapper
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class no_warnings:
    """Filter Python warnings with the given action, usable as a decorator or a context manager.

    Extra keyword arguments are forwarded to `warnings.simplefilter` (e.g. category=...).
    """

    def __init__(self, action: str = 'ignore', **kwargs):
        self.action = action
        self.filter_kwargs = kwargs

    def __call__(self, fn):
        """Decorator form: apply the warning filter for the duration of each call."""
        @wraps(fn)
        def wrapper(*args, **kwargs):
            # Reuse the context-manager form so both entry points behave identically.
            with no_warnings(self.action, **self.filter_kwargs):
                return fn(*args, **kwargs)
        return wrapper

    def __enter__(self):
        # catch_warnings saves/restores the global filter state around the block.
        self.warnings_manager = warnings.catch_warnings()
        self.warnings_manager.__enter__()
        warnings.simplefilter(self.action, **self.filter_kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
    """Load a Python source file as a module object registered under `module_name`."""
    module_spec = importlib.util.spec_from_file_location(module_name, file_path)
    loaded = importlib.util.module_from_spec(module_spec)
    module_spec.loader.exec_module(loaded)
    return loaded
|
moge/utils/vis.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import matplotlib
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Colorize a depth map through its disparity (1/depth) with a matplotlib colormap.

    Parameters:
    - depth: depth values; non-positive entries (and masked-out ones) are treated as invalid.
    - mask: optional boolean validity mask, same spatial shape as `depth`.
    - normalize: rescale disparity to its [0.001, 0.99] quantile range before mapping.
    - cmap: matplotlib colormap name.
    Returns a contiguous uint8 RGB array; invalid pixels come out black.
    """
    if mask is None:
        depth = np.where(depth > 0, depth, np.nan)
    else:
        depth = np.where((depth > 0) & mask, depth, np.nan)
    disp = 1 / depth
    if normalize:
        # NOTE(review): upper quantile is 0.99 here but 0.999 in the sibling functions — confirm intent.
        min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
        disp = (disp - min_disp) / (max_disp - min_disp)
    # Fixed: pass nan=0 by keyword — the second positional argument of np.nan_to_num
    # is `copy`, not `nan`; the original only worked because `nan` defaults to 0.0.
    colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp)[..., :3], nan=0)
    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
    return colored
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
    """Colorize a depth map directly (affine-invariant: rescaled to its quantile range).

    Masked-out pixels become NaN and render black. Returns a contiguous uint8 RGB array.
    """
    if mask is not None:
        depth = np.where(mask, depth, np.nan)

    min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
    depth = (depth - min_depth) / (max_depth - min_depth)
    # Fixed: pass nan=0 by keyword — the second positional argument of np.nan_to_num
    # is `copy`, not `nan`; the original only worked because `nan` defaults to 0.0.
    colored = np.nan_to_num(matplotlib.colormaps[cmap](depth)[..., :3], nan=0)
    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
    return colored
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Colorize a disparity map with a matplotlib colormap (higher disparity = nearer).

    Masked-out pixels become NaN and render black. With `normalize`, disparity is
    rescaled to its [0.001, 0.999] quantile range. Returns a contiguous uint8 RGB array.
    """
    if mask is not None:
        disparity = np.where(mask, disparity, np.nan)

    if normalize:
        min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
        disparity = (disparity - min_disp) / (max_disp - min_disp)
    # Fixed: pass nan=0 by keyword — the second positional argument of np.nan_to_num
    # is `copy`, not `nan`; the original only worked because `nan` defaults to 0.0.
    colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], nan=0)
    colored = np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
    return colored
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray:
    """Map integer segmentation labels to uint8 RGB colors, cycling the palette every 20 labels."""
    normalized_labels = (segmentation % 20) / 20
    rgb = matplotlib.colormaps[cmap](normalized_labels)[..., :3]
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
    """Visualize a normal map as uint8 RGB: (x, -y, -z) remapped from [-1, 1] to [0, 255].

    Masked-out pixels are zeroed (rendering as mid-gray after remapping of 0).
    """
    if mask is not None:
        normal = np.where(mask[..., None], normal, 0)
    remapped = normal * [0.5, -0.5, -0.5] + 0.5
    return (remapped.clip(0, 1) * 255).astype(np.uint8)
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
    """Colorize an error map with a matplotlib colormap; masked-out pixels become black.

    `value_range` fixes (vmin, vmax); otherwise the map's own nan-aware min/max are used.
    Returns a contiguous uint8 RGB array.
    """
    if value_range is not None:
        vmin, vmax = value_range
    else:
        vmin, vmax = np.nanmin(error_map), np.nanmax(error_map)
    colormap = matplotlib.colormaps[cmap]
    colorized_error_map = colormap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3]
    if mask is not None:
        colorized_error_map = np.where(mask[..., None], colorized_error_map, 0)
    return np.ascontiguousarray((colorized_error_map.clip(0, 1) * 255).astype(np.uint8))
|
moge/utils/webfile.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
import requests
|
| 7 |
+
from typing import *
|
| 8 |
+
|
| 9 |
+
__all__ = ["WebFile"]
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class WebFile:
    """A seekable, read-only, file-like object over an HTTP URL using Range requests.

    Only the requested byte ranges are downloaded, so large remote files can be
    read piecewise (e.g. by WebZipFile).
    """
    def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None):
        """
        Parameters:
        - url: the remote file URL.
        - session: optional requests.Session to reuse (connections/cookies); created if omitted.
        - headers: extra headers merged into the session (e.g. auth tokens).
        - size: total file size in bytes; fetched from the server when not given.
        """
        self.url = url
        self.session = session or requests.Session()
        self.session.headers.update(headers or {})
        self._offset = 0  # current read position in bytes
        self.size = size if size is not None else self._fetch_size()

    def _fetch_size(self):
        """Fetch the total size via Content-Length, without downloading the body.

        Uses a streamed GET rather than HEAD — presumably because some servers
        answer HEAD without Content-Length; confirm before changing.
        """
        with self.session.get(self.url, stream=True) as response:
            response.raise_for_status()
            content_length = response.headers.get("Content-Length")
            if content_length is None:
                raise ValueError("Missing Content-Length in header")
            return int(content_length)

    def _fetch_data(self, offset: int, n: int) -> bytes:
        """Download `n` bytes starting at `offset` with an HTTP Range request."""
        # Fixed: HTTP Range is inclusive on both ends and the last valid byte index
        # is size - 1; the original requested up to `self.size`, one past the end.
        headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size - 1)}"}
        response = self.session.get(self.url, headers=headers)
        response.raise_for_status()
        return response.content

    def seekable(self) -> bool:
        """This stream supports random access."""
        return True

    def tell(self) -> int:
        """Return the current read position."""
        return self._offset

    def available(self) -> int:
        """Number of bytes remaining from the current position to end of file."""
        return self.size - self._offset

    def seek(self, offset: int, whence: int = 0) -> None:
        """Move the read position (whence: 0 = absolute, 1 = relative, 2 = from end).

        The resulting position is clamped to [0, size].
        """
        if whence == 0:
            new_offset = offset
        elif whence == 1:
            new_offset = self._offset + offset
        elif whence == 2:
            new_offset = self.size + offset
        else:
            raise ValueError("Invalid value for whence")

        self._offset = max(0, min(new_offset, self.size))

    def read(self, n: Optional[int] = None) -> bytes:
        """Read up to `n` bytes from the current position (all remaining if n is None or negative)."""
        if n is None or n < 0:
            n = self.available()
        else:
            n = min(n, self.available())

        if n == 0:
            return b''

        data = self._fetch_data(self._offset, n)
        # Advance by what was actually returned, in case the server sent less.
        self._offset += len(data)

        return data

    def close(self) -> None:
        """No-op: the underlying session is shared and managed by the caller."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass
|
| 77 |
+
|
| 78 |
+
|
moge/utils/webzipfile.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copied from the MoGe project:
|
| 2 |
+
# https://github.com/microsoft/MoGe
|
| 3 |
+
# Original license: MIT
|
| 4 |
+
# Copyright (c) the MoGe authors
|
| 5 |
+
|
| 6 |
+
from typing import *
|
| 7 |
+
import io
|
| 8 |
+
import os
|
| 9 |
+
from zipfile import (
|
| 10 |
+
ZipInfo, BadZipFile, ZipFile, ZipExtFile,
|
| 11 |
+
sizeFileHeader, structFileHeader, stringFileHeader,
|
| 12 |
+
_FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS,
|
| 13 |
+
_MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED
|
| 14 |
+
)
|
| 15 |
+
import struct
|
| 16 |
+
from requests import Session
|
| 17 |
+
|
| 18 |
+
from .webfile import WebFile
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class _SharedWebFile(WebFile):
    """A WebFile view that shares the parent's URL, session and cached size but keeps
    its own independent read offset.

    Used by WebZipFile.open so several ZIP members can be read concurrently from
    one archive without contending on a single file position.
    """
    def __init__(self, webfile: WebFile, pos: int):
        # Passing size= skips the extra network round-trip to fetch Content-Length.
        super().__init__(webfile.url, webfile.session, size=webfile.size)
        self.seek(pos)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class WebZipFile(ZipFile):
    "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads."
    def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None):
        """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x',
        or append 'a'."""
        webf = WebFile(url, session=session, headers=headers)
        super().__init__(webf, mode='r')

    def open(self, name, mode="r", pwd=None, *, force_zip64=False):
        """Return file-like object for 'name'.

        name is a string for the file name within the ZIP file, or a ZipInfo
        object.

        mode should be 'r' to read a file already in the ZIP file, or 'w' to
        write to a file newly added to the archive.

        pwd is the password to decrypt files (only used for reading).

        When writing, if the file size is not known in advance but may exceed
        2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large
        files. If the size is known in advance, it is best to pass a ZipInfo
        instance for name, with zinfo.file_size set.

        NOTE: This is CPython's ZipFile.open with the shared-handle locking
        replaced by a per-call _SharedWebFile, so concurrent reads don't
        serialize on one file position. Only mode 'r' is actually supported.
        """
        if mode not in {"r", "w"}:
            raise ValueError('open() requires mode "r" or "w"')
        if pwd and (mode == "w"):
            raise ValueError("pwd is only supported for reading files")
        if not self.fp:
            raise ValueError(
                "Attempt to use ZIP archive that was already closed")

        assert mode == "r", "Only read mode is supported for now"

        # Make sure we have an info object
        if isinstance(name, ZipInfo):
            # 'name' is already an info object
            zinfo = name
        elif mode == 'w':
            zinfo = ZipInfo(name)
            zinfo.compress_type = self.compression
            zinfo._compresslevel = self.compresslevel
        else:
            # Get info object for name
            zinfo = self.getinfo(name)

        if mode == 'w':
            return self._open_to_write(zinfo, force_zip64=force_zip64)

        if self._writing:
            raise ValueError("Can't read from the ZIP file while there "
                    "is an open writing handle on it. "
                    "Close the writing handle before trying to read.")

        # Open for reading:
        self._fileRefCnt += 1
        # Fresh positioned view per member — this replaces the _SharedFile lock
        # used by the stdlib implementation.
        zef_file = _SharedWebFile(self.fp, zinfo.header_offset)

        try:
            # Skip the file header:
            fheader = zef_file.read(sizeFileHeader)
            if len(fheader) != sizeFileHeader:
                raise BadZipFile("Truncated file header")
            fheader = struct.unpack(structFileHeader, fheader)
            if fheader[_FH_SIGNATURE] != stringFileHeader:
                raise BadZipFile("Bad magic number for file header")

            fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
            if fheader[_FH_EXTRA_FIELD_LENGTH]:
                # Extra field is not needed for reading the data; skip it in place.
                zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1)

            if zinfo.flag_bits & _MASK_COMPRESSED_PATCH:
                # Zip 2.7: compressed patched data
                raise NotImplementedError("compressed patched data (flag bit 5)")

            if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION:
                # strong encryption
                raise NotImplementedError("strong encryption (flag bit 6)")

            if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME:
                # UTF-8 filename
                fname_str = fname.decode("utf-8")
            else:
                fname_str = fname.decode(self.metadata_encoding or "cp437")

            if fname_str != zinfo.orig_filename:
                raise BadZipFile(
                    'File name in directory %r and header %r differ.'
                    % (zinfo.orig_filename, fname))

            # check for encrypted flag & handle password
            is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED
            if is_encrypted:
                if not pwd:
                    pwd = self.pwd
                if pwd and not isinstance(pwd, bytes):
                    raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__)
                if not pwd:
                    raise RuntimeError("File %r is encrypted, password "
                                       "required for extraction" % name)
            else:
                pwd = None

            return ZipExtFile(zef_file, mode, zinfo, pwd, True)
        except:
            # On any failure release the per-member handle before propagating.
            zef_file.close()
            raise
|