JasonYinnnn committed on
Commit
2a36119
·
1 Parent(s): 7dabaaa

Move threeDFixer to a safe file

Browse files
app.py CHANGED
@@ -18,17 +18,7 @@ import random
18
  import imageio
19
  from einops import repeat
20
  from huggingface_hub import snapshot_download
21
- from threeDFixer.moge.model.v2 import MoGeModel
22
- from threeDFixer.pipelines import ThreeDFixerPipeline
23
- from threeDFixer.datasets.utils import (
24
- edge_mask_morph_gradient,
25
- process_scene_image,
26
- process_instance_image,
27
- transform_vertices,
28
- normalize_vertices,
29
- project2ply
30
- )
31
- from threeDFixer.utils import render_utils, postprocessing_utils
32
  from transformers import AutoModelForMaskGeneration, AutoProcessor
33
  from scripts.grounding_sam import plot_segmentation, segment
34
  import copy
@@ -192,6 +182,11 @@ def run_depth_estimation(
192
  ) -> Image.Image:
193
  rgb_image = image_prompts["image"].convert("RGB")
194
 
 
 
 
 
 
195
  rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
196
 
197
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
@@ -291,45 +286,6 @@ def set_random_seed(seed):
291
  if torch.cuda.is_available():
292
  torch.cuda.manual_seed_all(seed)
293
 
294
- def export_single_glb_from_outputs(
295
- outputs,
296
- fine_scale,
297
- fine_trans,
298
- coarse_scale,
299
- coarse_trans,
300
- trans,
301
- scale,
302
- rot,
303
- work_space,
304
- instance_name,
305
- run_id
306
- ):
307
-
308
- with torch.enable_grad():
309
- glb = postprocessing_utils.to_glb(
310
- outputs["gaussian"][0],
311
- outputs["mesh"][0],
312
- simplify=0.95,
313
- texture_size=1024,
314
- transform_fn=lambda x: transform_vertices(
315
- x,
316
- ops=["scale", "translation", "scale", "translation"],
317
- params=[fine_scale, fine_trans[None], coarse_scale, coarse_trans[None]],
318
- ),
319
- verbose=False
320
- )
321
-
322
- instance_glb_path = os.path.abspath(
323
- os.path.join(work_space, f"{run_id}_{instance_name}.glb")
324
- )
325
-
326
- glb.apply_translation(-trans) \
327
- .apply_scale(1.0 / (scale + 1e-6)) \
328
- .apply_transform(rot) \
329
- .export(instance_glb_path)
330
-
331
- return instance_glb_path, glb
332
-
333
 
334
  def export_scene_glb(trimeshes, work_space, scene_name):
335
  scene_path = os.path.abspath(os.path.join(work_space, scene_name))
@@ -356,6 +312,59 @@ def run_generation(
356
  cfg_interval_end: float = 1.0,
357
  t_rescale: float = 3.0,
358
  ):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
359
  global dpt_pack
360
  global work_space
361
  global generated_object_map
 
18
  import imageio
19
  from einops import repeat
20
  from huggingface_hub import snapshot_download
21
+ from moge.model.v2 import MoGeModel
 
 
 
 
 
 
 
 
 
 
22
  from transformers import AutoModelForMaskGeneration, AutoProcessor
23
  from scripts.grounding_sam import plot_segmentation, segment
24
  import copy
 
182
  ) -> Image.Image:
183
  rgb_image = image_prompts["image"].convert("RGB")
184
 
185
+ from threeDFixer.datasets.utils import (
186
+ normalize_vertices,
187
+ project2ply
188
+ )
189
+
190
  rgb_image = rgb_image.resize((1024, 1024), Image.Resampling.LANCZOS)
191
 
192
  device = 'cuda' if torch.cuda.is_available() else 'cpu'
 
286
  if torch.cuda.is_available():
287
  torch.cuda.manual_seed_all(seed)
288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
289
 
290
  def export_scene_glb(trimeshes, work_space, scene_name):
291
  scene_path = os.path.abspath(os.path.join(work_space, scene_name))
 
312
  cfg_interval_end: float = 1.0,
313
  t_rescale: float = 3.0,
314
  ):
315
+
316
+ from threeDFixer.pipelines import ThreeDFixerPipeline
317
+ from threeDFixer.datasets.utils import (
318
+ edge_mask_morph_gradient,
319
+ process_scene_image,
320
+ process_instance_image,
321
+ )
322
+ from threeDFixer.utils import render_utils
323
+
324
+ def export_single_glb_from_outputs(
325
+ outputs,
326
+ fine_scale,
327
+ fine_trans,
328
+ coarse_scale,
329
+ coarse_trans,
330
+ trans,
331
+ scale,
332
+ rot,
333
+ work_space,
334
+ instance_name,
335
+ run_id
336
+ ):
337
+
338
+ from threeDFixer.datasets.utils import (
339
+ transform_vertices,
340
+ )
341
+ from threeDFixer.utils import postprocessing_utils
342
+
343
+ with torch.enable_grad():
344
+ glb = postprocessing_utils.to_glb(
345
+ outputs["gaussian"][0],
346
+ outputs["mesh"][0],
347
+ simplify=0.95,
348
+ texture_size=1024,
349
+ transform_fn=lambda x: transform_vertices(
350
+ x,
351
+ ops=["scale", "translation", "scale", "translation"],
352
+ params=[fine_scale, fine_trans[None], coarse_scale, coarse_trans[None]],
353
+ ),
354
+ verbose=False
355
+ )
356
+
357
+ instance_glb_path = os.path.abspath(
358
+ os.path.join(work_space, f"{run_id}_{instance_name}.glb")
359
+ )
360
+
361
+ glb.apply_translation(-trans) \
362
+ .apply_scale(1.0 / (scale + 1e-6)) \
363
+ .apply_transform(rot) \
364
+ .export(instance_glb_path)
365
+
366
+ return instance_glb_path, glb
367
+
368
  global dpt_pack
369
  global work_space
370
  global generated_object_map
moge/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
moge/model/__init__.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copied from the MoGe project:
# https://github.com/microsoft/MoGe
# Original license: MIT
# Copyright (c) the MoGe authors

import importlib
from typing import *

if TYPE_CHECKING:
    from .v1 import MoGeModel as MoGeModelV1
    from .v2 import MoGeModel as MoGeModelV2


def import_model_class_by_version(version: str) -> Type[Union['MoGeModelV1', 'MoGeModelV2']]:
    """Dynamically load the ``MoGeModel`` class for a given model version.

    Args:
        version: Either ``'v1'`` or ``'v2'``.

    Returns:
        The ``MoGeModel`` class defined in the matching versioned submodule.

    Raises:
        AssertionError: If ``version`` is not one of the supported values.
        ValueError: If the submodule for ``version`` cannot be imported.
    """
    assert version in ['v1', 'v2'], f'Unsupported model version: {version}'

    try:
        version_module = importlib.import_module(f'.{version}', __package__)
    except ModuleNotFoundError:
        raise ValueError(f'Model version "{version}" not found.')

    return getattr(version_module, 'MoGeModel')
moge/model/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Version of the vendored DINOv2 package.
__version__ = "0.0.1"
moge/model/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
moge/model/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

from enum import Enum
from typing import Union

import torch

from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name


class Weights(Enum):
    """Available pretrained weight sets."""

    LVD142M = "LVD142M"


def _make_dinov2_model(
    *,
    arch_name: str = "vit_large",
    img_size: int = 518,
    patch_size: int = 14,
    init_values: float = 1.0,
    ffn_layer: str = "mlp",
    block_chunks: int = 0,
    num_register_tokens: int = 0,
    interpolate_antialias: bool = False,
    interpolate_offset: float = 0.1,
    pretrained: bool = True,
    weights: Union[Weights, str] = Weights.LVD142M,
    **kwargs,
):
    """Instantiate a DINOv2 ViT backbone and optionally load pretrained weights.

    Raises:
        AssertionError: If ``weights`` is a string that does not name a
            member of :class:`Weights`.
    """
    from ..models import vision_transformer as vits

    # Accept the weight set either as an enum member or as its string name.
    if isinstance(weights, str):
        try:
            weights = Weights[weights]
        except KeyError:
            raise AssertionError(f"Unsupported weights: {weights}")

    model_base_name = _make_dinov2_model_name(arch_name, patch_size)
    vit_kwargs = dict(
        img_size=img_size,
        patch_size=patch_size,
        init_values=init_values,
        ffn_layer=ffn_layer,
        block_chunks=block_chunks,
        num_register_tokens=num_register_tokens,
        interpolate_antialias=interpolate_antialias,
        interpolate_offset=interpolate_offset,
    )
    vit_kwargs.update(**kwargs)
    model = vits.__dict__[arch_name](**vit_kwargs)

    if pretrained:
        # The register-token count appears only in the checkpoint file name,
        # not in the directory name.
        model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
        url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
        state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
        model.load_state_dict(state_dict, strict=True)

    return model


def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14, optionally pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14, optionally pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14, optionally pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)


def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 (SwiGLU FFN), optionally pretrained on the LVD-142M dataset."""
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        **kwargs,
    )


def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-S/14 with 4 register tokens, optionally pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_small",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-B/14 with 4 register tokens, optionally pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_base",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-L/14 with 4 register tokens, optionally pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_large",
        pretrained=pretrained,
        weights=weights,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )


def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
    """DINOv2 ViT-g/14 (SwiGLU FFN) with 4 register tokens, optionally pretrained on LVD-142M."""
    return _make_dinov2_model(
        arch_name="vit_giant2",
        ffn_layer="swiglufused",
        weights=weights,
        pretrained=pretrained,
        num_register_tokens=4,
        interpolate_antialias=True,
        interpolate_offset=0.0,
        **kwargs,
    )
moge/model/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import itertools
7
+ import math
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+
14
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
15
+
16
+
17
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
18
+ compact_arch_name = arch_name.replace("_", "")[:4]
19
+ registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
20
+ return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
21
+
22
+
23
class CenterPadding(nn.Module):
    """Zero-pad every spatial dimension of the input up to a multiple of ``multiple``.

    Padding is split as evenly as possible between the two sides of each
    trailing dimension (all dims after the first two are treated as spatial).
    """

    def __init__(self, multiple):
        super().__init__()
        self.multiple = multiple

    def _get_pad(self, size):
        # Round up to the next multiple, then split the slack left/right.
        target = math.ceil(size / self.multiple) * self.multiple
        total = target - size
        left = total // 2
        return left, total - left

    @torch.inference_mode()
    def forward(self, x):
        # F.pad expects pad amounts ordered from the last dimension backwards,
        # hence the reversed slice over x's trailing dimensions.
        pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
        return F.pad(x, pads)
moge/model/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from .dino_head import DINOHead
7
+ from .mlp import Mlp
8
+ from .patch_embed import PatchEmbed
9
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
10
+ from .block import NestedTensorBlock
11
+ from .attention import MemEffAttention
moge/model/dinov2/layers/attention.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py

import logging
import os
import warnings

from torch import Tensor
from torch import nn
import torch.nn.functional as F


logger = logging.getLogger("dinov2")


# xFormers can be force-disabled through the XFORMERS_DISABLED env var.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if not XFORMERS_ENABLED:
        raise ImportError
    from xformers.ops import memory_efficient_attention, unbind

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False


class Attention(nn.Module):
    """Multi-head self-attention backed by torch's fused SDPA kernel."""

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        attn_drop: float = 0.0,
        proj_drop: float = 0.0,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim**-0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim, bias=proj_bias)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        batch, seq_len, channels = x.shape
        # One matmul for q/k/v, then split heads: (3, B, H, N, C // H).
        qkv = (
            self.qkv(x)
            .reshape(batch, seq_len, 3, self.num_heads, channels // self.num_heads)
            .permute(2, 0, 3, 1, 4)
        )
        q, k, v = qkv.unbind(0)  # each (B, H, N, C // H)

        # attn_bias is forwarded as SDPA's attention mask.
        out = F.scaled_dot_product_attention(q, k, v, attn_bias)
        out = out.permute(0, 2, 1, 3).reshape(batch, seq_len, channels)

        out = self.proj(out)
        return self.proj_drop(out)


class MemEffAttention(Attention):
    """Attention variant using xFormers' memory-efficient kernel when available."""

    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
        if not XFORMERS_AVAILABLE:
            # Fall back to the SDPA implementation; nested-tensor attn_bias
            # is only supported through xFormers.
            if attn_bias is not None:
                raise AssertionError("xFormers is required for using nested tensors")
            return super().forward(x)

        batch, seq_len, channels = x.shape
        qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, channels // self.num_heads)

        q, k, v = unbind(qkv, 2)

        out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
        out = out.reshape([batch, seq_len, channels])

        out = self.proj(out)
        return self.proj_drop(out)
moge/model/dinov2/layers/block.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py

import logging
import os
from typing import Callable, List, Any, Tuple, Dict
import warnings

import torch
from torch import nn, Tensor

from .attention import Attention, MemEffAttention
from .drop_path import DropPath
from .layer_scale import LayerScale
from .mlp import Mlp


logger = logging.getLogger("dinov2")


# xFormers can be force-disabled through the XFORMERS_DISABLED env var.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if not XFORMERS_ENABLED:
        raise ImportError
    from xformers.ops import fmha, scaled_index_add, index_select_cat

    XFORMERS_AVAILABLE = True
except ImportError:
    XFORMERS_AVAILABLE = False


class Block(nn.Module):
    """Pre-norm transformer block: attention and MLP residual branches,
    each with optional LayerScale and stochastic depth (drop path)."""

    def __init__(
        self,
        dim: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qkv_bias: bool = False,
        proj_bias: bool = True,
        ffn_bias: bool = True,
        drop: float = 0.0,
        attn_drop: float = 0.0,
        init_values=None,
        drop_path: float = 0.0,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
        attn_class: Callable[..., nn.Module] = Attention,
        ffn_layer: Callable[..., nn.Module] = Mlp,
    ) -> None:
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = attn_class(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            proj_bias=proj_bias,
            attn_drop=attn_drop,
            proj_drop=drop,
        )
        # LayerScale only when init_values is truthy, otherwise identity.
        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = ffn_layer(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=drop,
            bias=ffn_bias,
        )
        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()

        self.sample_drop_ratio = drop_path

    def forward(self, x: Tensor) -> Tensor:
        def attn_residual_func(x: Tensor) -> Tensor:
            return self.ls1(self.attn(self.norm1(x)))

        def ffn_residual_func(x: Tensor) -> Tensor:
            return self.ls2(self.mlp(self.norm2(x)))

        if self.training and self.sample_drop_ratio > 0.1:
            # Batch-subset stochastic depth; its overhead only pays off
            # for drop-path rates above 0.1.
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
            x = drop_add_residual_stochastic_depth(
                x,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
            )
        elif self.training and self.sample_drop_ratio > 0.0:
            x = x + self.drop_path1(attn_residual_func(x))
            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
        else:
            x = x + attn_residual_func(x)
            x = x + ffn_residual_func(x)
        return x


def drop_add_residual_stochastic_depth(
    x: Tensor,
    residual_func: Callable[[Tensor], Tensor],
    sample_drop_ratio: float = 0.0,
) -> Tensor:
    """Compute the residual branch only for a random subset of the batch and
    scatter-add the rescaled result back onto the full batch."""
    # 1) extract subset using permutation
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    x_subset = x[brange]

    # 2) apply residual_func to get residual
    residual = residual_func(x_subset)

    x_flat = x.flatten(1)
    residual = residual.flatten(1)

    # Rescale so the expected residual magnitude matches full-batch training.
    residual_scale_factor = b / sample_subset_size

    # 3) add the residual
    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    return x_plus_residual.view_as(x)


def get_branges_scales(x, sample_drop_ratio=0.0):
    """Pick the kept-sample indices and the matching residual rescale factor."""
    b, n, d = x.shape
    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
    residual_scale_factor = b / sample_subset_size
    return brange, residual_scale_factor


def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
    """Scatter-add a (possibly LayerScale-scaled) residual into rows ``brange`` of ``x``."""
    if scaling_vector is None:
        x_flat = x.flatten(1)
        residual = residual.flatten(1)
        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
    else:
        # Fused xFormers path: scale-and-index-add in one kernel.
        x_plus_residual = scaled_index_add(
            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
        )
    return x_plus_residual


# Cache of BlockDiagonalMask objects keyed by the (batch, seqlen) shape tuple.
attn_bias_cache: Dict[Tuple, Any] = {}


def get_attn_bias_and_cat(x_list, branges=None):
    """
    this will perform the index select, cat the tensors, and provide the attn_bias from cache
    """
    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
    if all_shapes not in attn_bias_cache.keys():
        seqlens = []
        for b, x in zip(batch_sizes, x_list):
            for _ in range(b):
                seqlens.append(x.shape[1])
        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
        attn_bias._batch_sizes = batch_sizes
        attn_bias_cache[all_shapes] = attn_bias

    if branges is not None:
        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
    else:
        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
        cat_tensors = torch.cat(tensors_bs1, dim=1)

    return attn_bias_cache[all_shapes], cat_tensors


def drop_add_residual_stochastic_depth_list(
    x_list: List[Tensor],
    residual_func: Callable[[Tensor, Any], Tensor],
    sample_drop_ratio: float = 0.0,
    scaling_vector=None,
) -> Tensor:
    """Stochastic-depth residual add over a list of nested tensors (xFormers path)."""
    # 1) generate random set of indices for dropping samples in the batch
    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
    branges = [s[0] for s in branges_scales]
    residual_scale_factors = [s[1] for s in branges_scales]

    # 2) get attention bias and index+concat the tensors
    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)

    # 3) apply residual_func to get residual, and split the result
    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore

    outputs = []
    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
    return outputs


class NestedTensorBlock(Block):
    """Block variant that can run a list of tensors as one nested batch
    through xFormers' block-diagonal attention."""

    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
        """
        x_list contains a list of tensors to nest together and run
        """
        assert isinstance(self.attn, MemEffAttention)

        if self.training and self.sample_drop_ratio > 0.0:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.attn(self.norm1(x), attn_bias=attn_bias)

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.mlp(self.norm2(x))

            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=attn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            x_list = drop_add_residual_stochastic_depth_list(
                x_list,
                residual_func=ffn_residual_func,
                sample_drop_ratio=self.sample_drop_ratio,
                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
            )
            return x_list
        else:

            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))

            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
                return self.ls2(self.mlp(self.norm2(x)))

            attn_bias, x = get_attn_bias_and_cat(x_list)
            x = x + attn_residual_func(x, attn_bias=attn_bias)
            x = x + ffn_residual_func(x)
            return attn_bias.split(x)

    def forward(self, x_or_x_list):
        if isinstance(x_or_x_list, Tensor):
            return super().forward(x_or_x_list)
        elif isinstance(x_or_x_list, list):
            if not XFORMERS_AVAILABLE:
                raise AssertionError("xFormers is required for using nested tensors")
            return self.forward_nested(x_or_x_list)
        else:
            raise AssertionError
moge/model/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.nn.init import trunc_normal_
9
+ from torch.nn.utils import weight_norm
10
+
11
+
12
+ class DINOHead(nn.Module):
13
+ def __init__(
14
+ self,
15
+ in_dim,
16
+ out_dim,
17
+ use_bn=False,
18
+ nlayers=3,
19
+ hidden_dim=2048,
20
+ bottleneck_dim=256,
21
+ mlp_bias=True,
22
+ ):
23
+ super().__init__()
24
+ nlayers = max(nlayers, 1)
25
+ self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
26
+ self.apply(self._init_weights)
27
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
28
+ self.last_layer.weight_g.data.fill_(1)
29
+
30
+ def _init_weights(self, m):
31
+ if isinstance(m, nn.Linear):
32
+ trunc_normal_(m.weight, std=0.02)
33
+ if isinstance(m, nn.Linear) and m.bias is not None:
34
+ nn.init.constant_(m.bias, 0)
35
+
36
+ def forward(self, x):
37
+ x = self.mlp(x)
38
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
39
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
40
+ x = self.last_layer(x)
41
+ return x
42
+
43
+
44
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
45
+ if nlayers == 1:
46
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
47
+ else:
48
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
49
+ if use_bn:
50
+ layers.append(nn.BatchNorm1d(hidden_dim))
51
+ layers.append(nn.GELU())
52
+ for _ in range(nlayers - 2):
53
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
54
+ if use_bn:
55
+ layers.append(nn.BatchNorm1d(hidden_dim))
56
+ layers.append(nn.GELU())
57
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
58
+ return nn.Sequential(*layers)
moge/model/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# References:
#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py


from torch import nn


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """Randomly zero entire samples of ``x`` (stochastic depth), rescaling survivors.

    A no-op at eval time or when ``drop_prob`` is zero.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # One Bernoulli draw per sample, broadcast across all remaining dims.
    mask_shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
    if keep_prob > 0.0:
        # Rescale survivors so the expected output matches the input.
        mask.div_(keep_prob)
    return x * mask


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)
moge/model/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the Apache License, Version 2.0
# found in the LICENSE file in the root directory of this source tree.

# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110

from typing import Union

import torch
from torch import Tensor
from torch import nn


class LayerScale(nn.Module):
    """Learnable per-channel scaling, initialized to a small constant."""

    def __init__(
        self,
        dim: int,
        init_values: Union[float, Tensor] = 1e-5,
        inplace: bool = False,
    ) -> None:
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        # The in-place multiply saves memory when x is not reused elsewhere.
        if self.inplace:
            return x.mul_(self.gamma)
        return x * self.gamma
moge/model/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
9
+
10
+
11
+ from typing import Callable, Optional
12
+
13
+ from torch import Tensor, nn
14
+
15
+
16
class Mlp(nn.Module):
    """Two-layer transformer feed-forward block: Linear -> act -> Linear.

    The same dropout module is applied after the activation and again after
    the second projection. Hidden and output widths default to the input
    width when not given.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = nn.GELU,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop = nn.Dropout(drop)

    def forward(self, x: Tensor) -> Tensor:
        # fc1 -> act -> drop -> fc2 -> drop, in order.
        for layer in (self.fc1, self.act, self.drop, self.fc2, self.drop):
            x = layer(x)
        return x
moge/model/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
9
+
10
+ from typing import Callable, Optional, Tuple, Union
11
+
12
+ from torch import Tensor
13
+ import torch.nn as nn
14
+
15
+
16
def make_2tuple(x):
    """Normalise an int or 2-tuple to a 2-tuple, e.g. ``16 -> (16, 16)``."""
    if isinstance(x, tuple):
        assert len(x) == 2
        return x

    assert isinstance(x, int)
    return (x, x)


class PatchEmbed(nn.Module):
    """
    2D image to patch embedding: (B,C,H,W) -> (B,N,D)

    Args:
        img_size: Image size.
        patch_size: Patch token size.
        in_chans: Number of input image channels.
        embed_dim: Number of linear projection output channels.
        norm_layer: Normalization layer (``None`` means no normalisation;
            an ``nn.Identity`` placeholder is stored internally).
        flatten_embedding: When False, ``forward`` returns (B,H',W',D)
            instead of the flattened (B,N,D).
    """

    def __init__(
        self,
        img_size: Union[int, Tuple[int, int]] = 224,
        patch_size: Union[int, Tuple[int, int]] = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        norm_layer: Optional[Callable] = None,
        flatten_embedding: bool = True,
    ) -> None:
        super().__init__()

        image_HW = make_2tuple(img_size)
        patch_HW = make_2tuple(patch_size)
        patch_grid_size = (
            image_HW[0] // patch_HW[0],
            image_HW[1] // patch_HW[1],
        )

        self.img_size = image_HW
        self.patch_size = patch_HW
        self.patches_resolution = patch_grid_size
        self.num_patches = patch_grid_size[0] * patch_grid_size[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.flatten_embedding = flatten_embedding

        # Non-overlapping patchification: a conv with kernel == stride == patch size.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """Embed an image batch; H and W must be multiples of the patch size."""
        _, _, H, W = x.shape
        patch_H, patch_W = self.patch_size

        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"

        x = self.proj(x)  # B C H W
        H, W = x.size(2), x.size(3)
        x = x.flatten(2).transpose(1, 2)  # B HW C
        x = self.norm(x)
        if not self.flatten_embedding:
            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
        return x

    def flops(self) -> float:
        """Approximate multiply-accumulate count of a forward pass."""
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        # BUGFIX: self.norm is never None (it is nn.Identity() when norm_layer
        # is not given), so the original `if self.norm is not None` always
        # counted norm flops even without a real norm layer.
        if not isinstance(self.norm, nn.Identity):
            flops += Ho * Wo * self.embed_dim
        return flops
moge/model/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ from typing import Callable, Optional
8
+ import warnings
9
+
10
+ from torch import Tensor, nn
11
+ import torch.nn.functional as F
12
+
13
+
14
class SwiGLUFFN(nn.Module):
    """SwiGLU feed-forward: one fused projection producing gate and value.

    ``w12`` projects to twice the hidden width; the first half passes through
    SiLU and gates the second half before the output projection ``w3``.
    ``act_layer`` and ``drop`` are accepted only for interface parity and are
    unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        super().__init__()
        hidden_features = hidden_features or in_features
        out_features = out_features or in_features
        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)

    def forward(self, x: Tensor) -> Tensor:
        gate, value = self.w12(x).chunk(2, dim=-1)
        return self.w3(F.silu(gate) * value)
35
+
36
+
37
# Prefer the fused xFormers SwiGLU implementation when xFormers is installed
# and not explicitly disabled via the XFORMERS_DISABLED environment variable.
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
try:
    if XFORMERS_ENABLED:
        from xformers.ops import SwiGLU

        XFORMERS_AVAILABLE = True
        # warnings.warn("xFormers is available (SwiGLU)")
    else:
        # warnings.warn("xFormers is disabled (SwiGLU)")
        raise ImportError
except ImportError:
    # xFormers missing or disabled: fall back to the pure-PyTorch SwiGLUFFN
    # defined above under the same name, so downstream code is unaffected.
    SwiGLU = SwiGLUFFN
    XFORMERS_AVAILABLE = False

    # warnings.warn("xFormers is not available (SwiGLU)")
52
+
53
+
54
class SwiGLUFFNFused(SwiGLU):
    """SwiGLU FFN with a 2/3-scaled, 8-aligned hidden width.

    The requested hidden width is scaled by 2/3 (keeping the parameter count
    comparable to a plain MLP of the same nominal width) and rounded up to a
    multiple of 8 for hardware efficiency. ``act_layer`` and ``drop`` are
    accepted only for interface parity and are unused.
    """

    def __init__(
        self,
        in_features: int,
        hidden_features: Optional[int] = None,
        out_features: Optional[int] = None,
        act_layer: Callable[..., nn.Module] = None,
        drop: float = 0.0,
        bias: bool = True,
    ) -> None:
        requested_hidden = hidden_features or in_features
        # 2/3 scaling, then round up to the next multiple of 8.
        aligned_hidden = (int(requested_hidden * 2 / 3) + 7) // 8 * 8
        super().__init__(
            in_features=in_features,
            hidden_features=aligned_hidden,
            out_features=out_features or in_features,
            bias=bias,
        )
moge/model/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+
8
+ from . import vision_transformer as vits
9
+
10
+
11
+ logger = logging.getLogger("dinov2")
12
+
13
+
14
def build_model(args, only_teacher=False, img_size=224):
    """Instantiate DINOv2 ViT backbone(s) from a config namespace.

    Returns ``(teacher, embed_dim)`` when ``only_teacher`` is set, otherwise
    ``(student, teacher, embed_dim)``; the student additionally receives the
    configured stochastic-depth settings. Non-ViT architectures fall through
    and yield ``None``, as in the original.
    """
    args.arch = args.arch.removesuffix("_memeff")
    if "vit" not in args.arch:
        return None
    shared_kwargs = dict(
        img_size=img_size,
        patch_size=args.patch_size,
        init_values=args.layerscale,
        ffn_layer=args.ffn_layer,
        block_chunks=args.block_chunks,
        qkv_bias=args.qkv_bias,
        proj_bias=args.proj_bias,
        ffn_bias=args.ffn_bias,
        num_register_tokens=args.num_register_tokens,
        interpolate_offset=args.interpolate_offset,
        interpolate_antialias=args.interpolate_antialias,
    )
    model_cls = vits.__dict__[args.arch]
    teacher = model_cls(**shared_kwargs)
    if only_teacher:
        return teacher, teacher.embed_dim
    student = model_cls(
        **shared_kwargs,
        drop_path_rate=args.drop_path_rate,
        drop_path_uniform=args.drop_path_uniform,
    )
    return student, teacher, student.embed_dim
40
+
41
+
42
def build_model_from_cfg(cfg, only_teacher=False):
    """Build model(s) from an OmegaConf config (student section + global crop size)."""
    return build_model(
        cfg.student,
        only_teacher=only_teacher,
        img_size=cfg.crops.global_crops_size,
    )
moge/model/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ # References:
7
+ # https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
8
+ # https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
9
+
10
+ from functools import partial
11
+ import math
12
+ import logging
13
+ from typing import Sequence, Tuple, Union, Callable, Optional, List
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.utils.checkpoint
18
+ from torch.nn.init import trunc_normal_
19
+
20
+ from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
21
+
22
+
23
+ logger = logging.getLogger("dinov2")
24
+
25
+
26
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
    """Recursively apply ``fn(module=..., name=...)`` over a module tree.

    Children are always visited (with dotted qualified names); the root
    itself is visited only when ``include_root`` is true. ``depth_first``
    selects post-order (children before parent) versus pre-order.
    Returns ``module`` for chaining.
    """
    if include_root and not depth_first:
        fn(module=module, name=name)
    for child_name, child in module.named_children():
        qualified = f"{name}.{child_name}" if name else child_name
        named_apply(fn=fn, module=child, name=qualified, depth_first=depth_first, include_root=True)
    if include_root and depth_first:
        fn(module=module, name=name)
    return module
35
+
36
+
37
class BlockChunk(nn.ModuleList):
    """A callable ModuleList: applies its member blocks sequentially."""

    def forward(self, x):
        out = x
        for block in self:
            out = block(out)
        return out
42
+
43
+
44
class DinoVisionTransformer(nn.Module):
    """DINOv2 Vision Transformer backbone (optionally with register tokens)."""

    def __init__(
        self,
        img_size=224,
        patch_size=16,
        in_chans=3,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4.0,
        qkv_bias=True,
        ffn_bias=True,
        proj_bias=True,
        drop_path_rate=0.0,
        drop_path_uniform=False,
        init_values=None,  # for layerscale: None or 0 => no layerscale
        embed_layer=PatchEmbed,
        act_layer=nn.GELU,
        block_fn=Block,
        ffn_layer="mlp",
        block_chunks=1,
        num_register_tokens=0,
        interpolate_antialias=False,
        interpolate_offset=0.1,
    ):
        """
        Args:
            img_size (int, tuple): input image size
            patch_size (int, tuple): patch size
            in_chans (int): number of input channels
            embed_dim (int): embedding dimension
            depth (int): depth of transformer
            num_heads (int): number of attention heads
            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
            qkv_bias (bool): enable bias for qkv if True
            proj_bias (bool): enable bias for proj in attn if True
            ffn_bias (bool): enable bias for ffn if True
            drop_path_rate (float): stochastic depth rate
            drop_path_uniform (bool): apply uniform drop rate across blocks
            init_values (float): layer-scale init values
            embed_layer (nn.Module): patch embedding layer
            act_layer (nn.Module): MLP activation layer
            block_fn (nn.Module): transformer block class
            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
        """
        super().__init__()
        norm_layer = partial(nn.LayerNorm, eps=1e-6)

        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
        self.num_tokens = 1
        self.n_blocks = depth
        self.num_heads = num_heads
        self.patch_size = patch_size
        self.num_register_tokens = num_register_tokens
        self.interpolate_antialias = interpolate_antialias
        self.interpolate_offset = interpolate_offset

        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
        num_patches = self.patch_embed.num_patches

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
        assert num_register_tokens >= 0
        self.register_tokens = (
            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
        )

        if drop_path_uniform is True:
            dpr = [drop_path_rate] * depth
        else:
            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule

        if ffn_layer == "mlp":
            logger.info("using MLP layer as FFN")
            ffn_layer = Mlp
        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
            logger.info("using SwiGLU layer as FFN")
            ffn_layer = SwiGLUFFNFused
        elif ffn_layer == "identity":
            logger.info("using Identity layer as FFN")

            def f(*args, **kwargs):
                return nn.Identity()

            ffn_layer = f
        else:
            raise NotImplementedError

        blocks_list = [
            block_fn(
                dim=embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                qkv_bias=qkv_bias,
                proj_bias=proj_bias,
                ffn_bias=ffn_bias,
                drop_path=dpr[i],
                norm_layer=norm_layer,
                act_layer=act_layer,
                ffn_layer=ffn_layer,
                init_values=init_values,
            )
            for i in range(depth)
        ]
        if block_chunks > 0:
            self.chunked_blocks = True
            chunked_blocks = []
            chunksize = depth // block_chunks
            for i in range(0, depth, chunksize):
                # this is to keep the block index consistent if we chunk the block list
                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
        else:
            self.chunked_blocks = False
            self.blocks = nn.ModuleList(blocks_list)

        self.norm = norm_layer(embed_dim)
        self.head = nn.Identity()

        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))

        self.init_weights()

    @property
    def onnx_compatible_mode(self):
        # When enabled, interpolate_pos_encoding avoids data-dependent
        # shortcuts/scale factors so the graph can be exported to ONNX.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value

    def init_weights(self):
        """Initialise position/cls/register tokens and all Linear layers (timm scheme)."""
        trunc_normal_(self.pos_embed, std=0.02)
        nn.init.normal_(self.cls_token, std=1e-6)
        if self.register_tokens is not None:
            nn.init.normal_(self.register_tokens, std=1e-6)
        named_apply(init_weights_vit_timm, self)

    def interpolate_pos_encoding(self, x, h, w):
        """Bicubically resample the patch position embedding to an h x w input."""
        previous_dtype = x.dtype
        npatch = x.shape[1] - 1
        N = self.pos_embed.shape[1] - 1
        if not self.onnx_compatible_mode and npatch == N and w == h:
            return self.pos_embed
        pos_embed = self.pos_embed.float()
        class_pos_embed = pos_embed[:, 0, :]
        patch_pos_embed = pos_embed[:, 1:, :]
        dim = x.shape[-1]
        h0, w0 = h // self.patch_size, w // self.patch_size
        M = int(math.sqrt(N))  # Recover the number of patches in each dimension
        assert N == M * M
        kwargs = {}
        if not self.onnx_compatible_mode and self.interpolate_offset > 0:
            # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
            # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
            sx = float(w0 + self.interpolate_offset) / M
            sy = float(h0 + self.interpolate_offset) / M
            kwargs["scale_factor"] = (sy, sx)
        else:
            # Simply specify an output size instead of a scale factor
            kwargs["size"] = (h0, w0)

        patch_pos_embed = nn.functional.interpolate(
            patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
            mode="bicubic",
            antialias=self.interpolate_antialias,
            **kwargs,
        )

        assert (h0, w0) == patch_pos_embed.shape[-2:]
        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).flatten(1, 2)
        return torch.cat((class_pos_embed[:, None, :].expand(patch_pos_embed.shape[0], -1, -1), patch_pos_embed), dim=1).to(previous_dtype)

    def prepare_tokens_with_masks(self, x, masks=None):
        """Patchify, optionally mask, prepend cls (and register) tokens, add pos embed."""
        B, nc, h, w = x.shape
        x = self.patch_embed(x)

        if masks is not None:
            # Replace masked patches with the learned mask token.
            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)

        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
        x = x + self.interpolate_pos_encoding(x, h, w)

        if self.register_tokens is not None:
            # Registers sit between the cls token and the patch tokens.
            x = torch.cat(
                (
                    x[:, :1],
                    self.register_tokens.expand(x.shape[0], -1, -1),
                    x[:, 1:],
                ),
                dim=1,
            )

        return x

    def forward_features_list(self, x_list, masks_list):
        """Run the backbone on a list of (input, mask) pairs (multi-crop training)."""
        # BUGFIX: the original unpacked three names from a two-sequence zip
        # ("for x, masks, ar in zip(x_list, masks_list)"), which raises
        # ValueError on the first iteration.
        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
        for blk in self.blocks:
            x = blk(x)

        all_x = x
        output = []
        for x, masks in zip(all_x, masks_list):
            x_norm = self.norm(x)
            output.append(
                {
                    "x_norm_clstoken": x_norm[:, 0],
                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
                    "x_prenorm": x,
                    "masks": masks,
                }
            )
        return output

    def forward_features(self, x, masks=None):
        """Run the backbone and return the normalised token dict."""
        if isinstance(x, list):
            return self.forward_features_list(x, masks)

        x = self.prepare_tokens_with_masks(x, masks)

        for blk in self.blocks:
            x = blk(x)

        x_norm = self.norm(x)
        return {
            "x_norm_clstoken": x_norm[:, 0],
            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
            "x_prenorm": x,
            "masks": masks,
        }

    def _get_intermediate_layers_not_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        # If n is an int, take the n last blocks. If it's a list, take them
        output, total_block_len = [], len(self.blocks)
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            if i in blocks_to_take:
                output.append(x)
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def _get_intermediate_layers_chunked(self, x, n=1):
        x = self.prepare_tokens_with_masks(x)
        output, i, total_block_len = [], 0, len(self.blocks[-1])
        # If n is an int, take the n last blocks. If it's a list, take them
        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
        for block_chunk in self.blocks:
            for blk in block_chunk[i:]:  # Passing the nn.Identity()
                x = blk(x)
                if i in blocks_to_take:
                    output.append(x)
                i += 1
        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
        return output

    def get_intermediate_layers(
        self,
        x: torch.Tensor,
        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
        reshape: bool = False,
        return_class_token: bool = False,
        norm=True,
    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
        """Return features from selected blocks, optionally normalised / reshaped to maps."""
        if self.chunked_blocks:
            outputs = self._get_intermediate_layers_chunked(x, n)
        else:
            outputs = self._get_intermediate_layers_not_chunked(x, n)
        if norm:
            outputs = [self.norm(out) for out in outputs]
        class_tokens = [out[:, 0] for out in outputs]
        outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
        if reshape:
            B, _, w, h = x.shape
            outputs = [
                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                for out in outputs
            ]
        if return_class_token:
            return tuple(zip(outputs, class_tokens))
        return tuple(outputs)

    def forward(self, *args, is_training=False, **kwargs):
        """Return the token dict while training, else the (identity) head over the cls token."""
        ret = self.forward_features(*args, **kwargs)
        if is_training:
            return ret
        else:
            return self.head(ret["x_norm_clstoken"])
341
+
342
+
343
def init_weights_vit_timm(module: nn.Module, name: str = ""):
    """ViT weight initialization, original timm impl (for reproducibility)."""
    if not isinstance(module, nn.Linear):
        return
    trunc_normal_(module.weight, std=0.02)
    if module.bias is not None:
        nn.init.zeros_(module.bias)
349
+
350
+
351
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-S backbone (embed 384, depth 12, 6 heads) with memory-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=384,
        depth=12,
        num_heads=6,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
363
+
364
+
365
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-B backbone (embed 768, depth 12, 12 heads) with memory-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
377
+
378
+
379
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
    """ViT-L backbone (embed 1024, depth 24, 16 heads) with memory-efficient attention."""
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1024,
        depth=24,
        num_heads=16,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
391
+
392
+
393
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
    """
    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
    """
    return DinoVisionTransformer(
        patch_size=patch_size,
        embed_dim=1536,
        depth=40,
        num_heads=24,
        mlp_ratio=4,
        block_fn=partial(Block, attn_class=MemEffAttention),
        num_register_tokens=num_register_tokens,
        **kwargs,
    )
moge/model/dinov2/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
moge/model/dinov2/utils/cluster.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
+
11
+
12
+ class ClusterType(Enum):
13
+ AWS = "aws"
14
+ FAIR = "fair"
15
+ RSC = "rsc"
16
+
17
+
18
+ def _guess_cluster_type() -> ClusterType:
19
+ uname = os.uname()
20
+ if uname.sysname == "Linux":
21
+ if uname.release.endswith("-aws"):
22
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
23
+ return ClusterType.AWS
24
+ elif uname.nodename.startswith("rsc"):
25
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
26
+ return ClusterType.RSC
27
+
28
+ return ClusterType.FAIR
29
+
30
+
31
+ def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
32
+ if cluster_type is None:
33
+ return _guess_cluster_type()
34
+
35
+ return cluster_type
36
+
37
+
38
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Root checkpoint directory for the (detected) cluster, or ``None``."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    dirnames = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/") / dirnames[cluster_type]
49
+
50
+
51
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Per-user checkpoint directory ($USER under the cluster root), or ``None``."""
    root = get_checkpoint_path(cluster_type)
    if root is None:
        return None

    username = os.environ.get("USER")
    assert username is not None
    return root / username
59
+
60
+
61
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Default SLURM partition name for the (detected) cluster, or ``None``."""
    cluster_type = get_cluster_type(cluster_type)
    if cluster_type is None:
        return None

    partitions = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partitions[cluster_type]
72
+
73
+
74
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build submitit/SLURM executor kwargs for the (detected) cluster.

    Starts from a one-task-per-GPU baseline, applies cluster-specific
    adjustments, then lets caller-supplied ``kwargs`` override anything.
    """
    params = {
        "mem_gb": 0,  # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        "gpus_per_node": num_gpus_per_node,
        "tasks_per_node": num_gpus_per_node,  # one task per GPU
        "cpus_per_task": 10,
        "nodes": nodes,
        "slurm_partition": get_slurm_partition(cluster_type),
    }
    # Cluster-specific adjustments.
    detected = get_cluster_type(cluster_type)
    if detected == ClusterType.AWS:
        params["cpus_per_task"] = 12
        del params["mem_gb"]
    elif detected == ClusterType.RSC:
        params["cpus_per_task"] = 12
    # Caller overrides win over everything above.
    params.update(kwargs)
    return params
moge/model/dinov2/utils/config.py ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import logging
8
+ import os
9
+
10
+ from omegaconf import OmegaConf
11
+
12
+ import dinov2.distributed as distributed
13
+ from dinov2.logging import setup_logging
14
+ from dinov2.utils import utils
15
+ from dinov2.configs import dinov2_default_config
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
def apply_scaling_rules_to_cfg(cfg):  # to fix
    """Scale ``cfg.optim.lr`` from the base LR by sqrt(global batch size / 1024)."""
    if cfg.optim.scaling_rule != "sqrt_wrt_1024":
        raise NotImplementedError
    base_lr = cfg.optim.base_lr
    global_batch = cfg.train.batch_size_per_gpu * distributed.get_global_size()
    cfg.optim.lr = base_lr * math.sqrt(global_batch / 1024.0)
    logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
    return cfg
30
+
31
+
32
def write_config(cfg, output_dir, name="config.yaml"):
    """Log the config and save it as YAML under ``output_dir``; return the saved path."""
    logger.info(OmegaConf.to_yaml(cfg))
    saved_path = os.path.join(output_dir, name)
    with open(saved_path, "w") as stream:
        OmegaConf.save(config=cfg, f=stream)
    return saved_path
38
+
39
+
40
def get_cfg_from_args(args):
    """Merge the default config, the config file, and CLI overrides (in that order).

    Side effect: normalises ``args.output_dir`` to an absolute path and
    appends a matching ``train.output_dir`` override to ``args.opts``.
    """
    args.output_dir = os.path.abspath(args.output_dir)
    args.opts += [f"train.output_dir={args.output_dir}"]
    base_cfg = OmegaConf.create(dinov2_default_config)
    file_cfg = OmegaConf.load(args.config_file)
    return OmegaConf.merge(base_cfg, file_cfg, OmegaConf.from_cli(args.opts))
47
+
48
+
49
def default_setup(args):
    """Initialise distributed mode, logging, and per-rank RNG seeds for a run."""
    distributed.enable(overwrite=True)
    seed = getattr(args, "seed", 0)
    rank = distributed.get_global_rank()

    # Rebind the module-level logger after logging has been configured.
    global logger
    setup_logging(output=args.output_dir, level=logging.INFO)
    logger = logging.getLogger("dinov2")

    # Offset the seed by rank so workers do not share RNG streams.
    utils.fix_random_seeds(seed + rank)
    logger.info("git:\n {}\n".format(utils.get_sha()))
    logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
61
+
62
+
63
def setup(args):
    """
    Create configs and perform basic setups.
    """
    # Order matters: the output directory must exist before default_setup()
    # points logging at it, and LR scaling must run before the config is
    # persisted so the saved YAML reflects the effective learning rate.
    cfg = get_cfg_from_args(args)
    os.makedirs(args.output_dir, exist_ok=True)
    default_setup(args)
    apply_scaling_rules_to_cfg(cfg)
    write_config(cfg, args.output_dir)
    return cfg
moge/model/dinov2/utils/dtype.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Dict, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
# Anything accepted where a dtype can be specified: a string name
# ("float32"), a NumPy dtype, or a torch dtype.
TypeSpec = Union[str, np.dtype, torch.dtype]


_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
    np.dtype("bool"): torch.bool,
    np.dtype("uint8"): torch.uint8,
    np.dtype("int8"): torch.int8,
    np.dtype("int16"): torch.int16,
    np.dtype("int32"): torch.int32,
    np.dtype("int64"): torch.int64,
    np.dtype("float16"): torch.float16,
    np.dtype("float32"): torch.float32,
    np.dtype("float64"): torch.float64,
    np.dtype("complex64"): torch.complex64,
    np.dtype("complex128"): torch.complex128,
}


def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
    """Coerce a string / numpy dtype / torch dtype to the matching torch dtype.

    Raises:
        KeyError: if the numpy dtype has no torch equivalent in the table.
        TypeError: if a string does not name a valid numpy dtype.
    """
    if isinstance(dtype, torch.dtype):
        return dtype
    if isinstance(dtype, str):
        dtype = np.dtype(dtype)
    # BUGFIX: assertion message typo "nunpy" -> "numpy".
    assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
    return _NUMPY_TO_TORCH_DTYPE[dtype]
moge/model/dinov2/utils/param_groups.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ import logging
8
+
9
+
10
+ logger = logging.getLogger("dinov2")
11
+
12
+
13
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
    """
    Calculate lr decay rate for different ViT blocks.

    Stem parameters (pos_embed, patch_embed, tokens) get the deepest decay
    (layer 0); each transformer block gets decay according to its index; all
    remaining parameters (e.g. heads) get no decay.

    Args:
        name (string): parameter name.
        lr_decay_rate (float): base lr decay rate.
        num_layers (int): number of ViT blocks.
        force_is_backbone (bool): treat the name as a backbone parameter even
            without the "backbone" prefix.
        chunked_blocks (bool): parameter names come from a chunked-FSDP model.

    Returns:
        lr decay rate for the given parameter.
    """
    stem_tokens = ("pos_embed", "patch_embed", "mask_token", "cls_token", "register_tokens")
    layer_id = num_layers + 1  # default: shallowest (no decay)
    if name.startswith("backbone") or force_is_backbone:
        if any("." + token in name for token in stem_tokens):
            layer_id = 0
        elif force_is_backbone and any(token in name for token in stem_tokens):
            layer_id = 0
        elif ".blocks." in name and ".residual." not in name:
            # e.g. "backbone.blocks.<i>...." -> block index i
            layer_id = int(name[name.find(".blocks."):].split(".")[2]) + 1
        elif chunked_blocks and "blocks." in name and "residual." not in name:
            # chunked layout inserts a chunk index before the block index
            layer_id = int(name[name.find("blocks."):].split(".")[2]) + 1
        elif "blocks." in name and "residual." not in name:
            layer_id = int(name[name.find("blocks."):].split(".")[1]) + 1

    return lr_decay_rate ** (num_layers + 1 - layer_id)
49
+
50
+
51
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
    """
    Build one optimizer parameter group per trainable parameter, attaching
    layer-wise lr-decay and weight-decay multipliers.

    Args:
        model: module whose named_parameters() are grouped; the block count is
            detected from `n_blocks` (chunked FSDP), `blocks`, or `backbone.blocks`.
        lr_decay_rate (float): base of the per-layer lr decay (1.0 disables decay).
        patch_embed_lr_mult (float): extra lr multiplier for patch-embed parameters.

    Returns:
        list of dicts with keys: params, is_last_layer, lr_multiplier,
        wd_multiplier, name — consumable by fuse_params_groups / the optimizer.
    """
    chunked_blocks = False
    if hasattr(model, "n_blocks"):
        logger.info("chunked fsdp")
        n_blocks = model.n_blocks
        chunked_blocks = model.chunked_blocks
    elif hasattr(model, "blocks"):
        logger.info("first code branch")
        n_blocks = len(model.blocks)
    elif hasattr(model, "backbone"):
        logger.info("second code branch")
        n_blocks = len(model.backbone.blocks)
    else:
        logger.info("else code branch")
        n_blocks = 0
    all_param_groups = []

    for name, param in model.named_parameters():
        # Strip FSDP wrapper prefixes so layer-id parsing sees the raw names.
        name = name.replace("_fsdp_wrapped_module.", "")
        if not param.requires_grad:
            continue
        decay_rate = get_vit_lr_decay_rate(
            name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
        )
        d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}

        if "last_layer" in name:
            d.update({"is_last_layer": True})

        # Biases, norm parameters, and gamma scales are excluded from weight decay.
        if name.endswith(".bias") or "norm" in name or "gamma" in name:
            d.update({"wd_multiplier": 0.0})

        if "patch_embed" in name:
            d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})

        all_param_groups.append(d)
        logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")

    return all_param_groups
90
+
91
+
92
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
    """
    Merge per-parameter groups that share identical hyperparameter settings.

    Groups are bucketed by the values of *keys*; each bucket keeps those key
    values once and collects every member's "params" entry into a list.

    Args:
        all_params_groups: list of dicts as produced by get_params_groups_with_decay.
        keys: hyperparameter keys that define group identity.

    Returns:
        dict_values of fused group dicts (one per distinct setting).
    """
    fused = defaultdict(lambda: {"params": []})
    for group in all_params_groups:
        # Identity is the concatenation of key/value pairs.
        signature = "".join(k + str(group[k]) + "_" for k in keys)
        bucket = fused[signature]
        for k in keys:
            bucket[k] = group[k]
        bucket["params"].append(group["params"])
    return fused.values()
moge/model/dinov2/utils/utils.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import subprocess
10
+ from urllib.parse import urlparse
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
    """
    Load a checkpoint into *model* from a URL or local path (non-strict).

    Args:
        model: target nn.Module.
        pretrained_weights (str): URL or filesystem path of the checkpoint.
        checkpoint_key (str | None): if given and present in the checkpoint dict,
            load that sub-dict (e.g. "teacher") instead of the whole file.
    """
    if urlparse(pretrained_weights).scheme:  # If it looks like an URL
        state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
    else:
        state_dict = torch.load(pretrained_weights, map_location="cpu")
    if checkpoint_key is not None and checkpoint_key in state_dict:
        logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
        state_dict = state_dict[checkpoint_key]
    # remove `module.` prefix (left by DistributedDataParallel checkpoints)
    state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
    # remove `backbone.` prefix induced by multicrop wrapper
    state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
    # strict=False: tolerate missing/unexpected keys; the load message is logged below
    msg = model.load_state_dict(state_dict, strict=False)
    logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
34
+
35
+
36
def fix_random_seeds(seed=31):
    """
    Fix random seeds.

    Seeds every RNG used by training (torch CPU, all CUDA devices, numpy, and
    Python's `random`) with the same value for reproducibility.
    """
    for seeder in (torch.manual_seed, torch.cuda.manual_seed_all, np.random.seed, random.seed):
        seeder(seed)
44
+
45
+
46
def get_sha():
    """
    Return a human-readable summary of the git state of this source tree:
    commit sha, whether there are uncommitted changes, and the branch name.

    Best-effort: falls back to "N/A"/"clean" defaults when git is unavailable
    or the file is not inside a git checkout.
    """
    cwd = os.path.dirname(os.path.abspath(__file__))

    def _run(command):
        # Run a git command in this file's directory and return stripped output.
        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

    sha = "N/A"
    diff = "clean"
    branch = "N/A"
    try:
        sha = _run(["git", "rev-parse", "HEAD"])
        # Plain `git diff` first — NOTE(review): presumably to refresh the index
        # so diff-index reports accurately; confirm against git docs.
        subprocess.check_output(["git", "diff"], cwd=cwd)
        diff = _run(["git", "diff-index", "HEAD"])
        diff = "has uncommitted changes" if diff else "clean"
        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
    except Exception:
        pass  # deliberate best-effort: keep the defaults
    message = f"sha: {sha}, status: {diff}, branch: {branch}"
    return message
65
+
66
+
67
class CosineScheduler(object):
    """
    Precomputed per-iteration schedule: an optional freeze-at-zero phase,
    a linear warmup, then cosine decay from `base_value` to `final_value`.

    Indexing past `total_iters` returns `final_value`.
    """

    def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
        super().__init__()
        self.final_value = final_value
        self.total_iters = total_iters

        # Phase 1: frozen at zero (e.g. for weight decay on frozen params).
        frozen = np.zeros((freeze_iters))
        # Phase 2: linear warmup from start_warmup_value to base_value.
        warmup = np.linspace(start_warmup_value, base_value, warmup_iters)
        # Phase 3: cosine decay over the remaining iterations.
        steps = np.arange(total_iters - warmup_iters - freeze_iters)
        cosine = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * steps / len(steps)))

        self.schedule = np.concatenate((frozen, warmup, cosine))
        assert len(self.schedule) == self.total_iters

    def __getitem__(self, it):
        # Past the schedule's end, hold the final value.
        return self.final_value if it >= self.total_iters else self.schedule[it]
88
+
89
+
90
def has_batchnorms(model):
    """Return True iff any submodule of *model* is a (sync) batch-norm layer."""
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    return any(isinstance(module, bn_types) for _, module in model.named_modules())
moge/model/modules.py ADDED
@@ -0,0 +1,259 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+ from numbers import Number
8
+ import importlib
9
+ import itertools
10
+ import functools
11
+ import sys
12
+
13
+ import torch
14
+ from torch import Tensor
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+
18
+ from .dinov2.models.vision_transformer import DinoVisionTransformer
19
+ from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
20
+ from ..utils.geometry_torch import normalized_view_plane_uv
21
+
22
+
23
class ResidualConvBlock(nn.Module):
    """
    Pre-activation residual block: two (norm, activation, conv) stages plus a
    skip connection (1x1 conv when the channel count changes, identity otherwise).

    Args:
        in_channels: input channel count.
        out_channels: output channel count (defaults to in_channels).
        hidden_channels: intermediate channel count (defaults to in_channels).
        kernel_size: spatial kernel size of both convs.
        padding_mode: conv padding mode.
        activation: one of 'relu', 'leaky_relu', 'silu', 'elu'.
        in_norm: normalization before the first conv ('group_norm',
            'layer_norm', 'instance_norm', or 'none').
        hidden_norm: normalization before the second conv.
    """
    def __init__(
        self,
        in_channels: int,
        out_channels: int = None,
        hidden_channels: int = None,
        kernel_size: int = 3,
        padding_mode: str = 'replicate',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
        in_norm: Literal['group_norm', 'layer_norm', 'instance_norm', 'none'] = 'layer_norm',
        hidden_norm: Literal['group_norm', 'layer_norm', 'instance_norm'] = 'group_norm',
    ):
        super().__init__()
        out_channels = in_channels if out_channels is None else out_channels
        hidden_channels = in_channels if hidden_channels is None else hidden_channels

        def make_activation():
            # Fresh instance per call so the two stages hold distinct modules.
            if activation == 'relu':
                return nn.ReLU()
            if activation == 'leaky_relu':
                return nn.LeakyReLU(negative_slope=0.2)
            if activation == 'silu':
                return nn.SiLU()
            if activation == 'elu':
                return nn.ELU()
            raise ValueError(f'Unsupported activation function: {activation}')

        def make_norm(kind, channels):
            # 'layer_norm' here is channel-wise (GroupNorm with one group).
            if kind == 'group_norm':
                return nn.GroupNorm(channels // 32, channels)
            if kind == 'layer_norm':
                return nn.GroupNorm(1, channels)
            if kind == 'instance_norm':
                return nn.InstanceNorm2d(channels)
            return nn.Identity()

        conv_kwargs = dict(kernel_size=kernel_size, padding=kernel_size // 2, padding_mode=padding_mode)
        self.layers = nn.Sequential(
            make_norm(in_norm, in_channels),
            make_activation(),
            nn.Conv2d(in_channels, hidden_channels, **conv_kwargs),
            make_norm(hidden_norm, hidden_channels),
            make_activation(),
            nn.Conv2d(hidden_channels, out_channels, **conv_kwargs),
        )

        self.skip_connection = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
            if in_channels != out_channels else nn.Identity()
        )

    def forward(self, x):
        # Residual sum of the conv branch and the (possibly projected) skip.
        return self.layers(x) + self.skip_connection(x)
74
+
75
+
76
class DINOv2Encoder(nn.Module):
    "Wrapped DINOv2 encoder supporting gradient checkpointing. Input is RGB image in range [0, 1]."
    backbone: DinoVisionTransformer
    image_mean: torch.Tensor
    image_std: torch.Tensor
    dim_features: int

    def __init__(self, backbone: str, intermediate_layers: Union[int, List[int]], dim_out: int, **deprecated_kwargs):
        """
        Args:
            backbone: name of a dinov2 hub constructor (resolved from
                `.dinov2.hub.backbones`).
            intermediate_layers: number of final layers, or explicit layer
                indices, to tap from the backbone.
            dim_out: channel dimension the tapped features are projected to.
            **deprecated_kwargs: ignored; kept for config backward compatibility.
        """
        super(DINOv2Encoder, self).__init__()

        self.intermediate_layers = intermediate_layers

        # Load the backbone architecture (pretrained weights come via init_weights)
        self.hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), backbone)
        self.backbone_name = backbone
        self.backbone = self.hub_loader(pretrained=False)

        # Feature width read off the first attention block's qkv input size
        self.dim_features = self.backbone.blocks[0].attn.qkv.in_features
        self.num_features = intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers)

        # One 1x1 projection per tapped layer; outputs are summed in forward()
        self.output_projections = nn.ModuleList([
            nn.Conv2d(in_channels=self.dim_features, out_channels=dim_out, kernel_size=1, stride=1, padding=0,)
            for _ in range(self.num_features)
        ])

        # ImageNet normalization constants applied to the [0, 1] input
        self.register_buffer("image_mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
        self.register_buffer("image_std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    @property
    def onnx_compatible_mode(self):
        # Defaults to False when never set; mirrored onto the backbone below.
        return getattr(self, "_onnx_compatible_mode", False)

    @onnx_compatible_mode.setter
    def onnx_compatible_mode(self, value: bool):
        self._onnx_compatible_mode = value
        self.backbone.onnx_compatible_mode = value

    def init_weights(self):
        # Instantiate a pretrained copy via the hub loader and copy its weights in.
        pretrained_backbone_state_dict = self.hub_loader(pretrained=True).state_dict()
        self.backbone.load_state_dict(pretrained_backbone_state_dict)

    def enable_gradient_checkpointing(self):
        # Recompute each transformer block in backward to trade compute for memory.
        for i in range(len(self.backbone.blocks)):
            wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])

    def enable_pytorch_native_sdpa(self):
        # Swap each block's attention for torch's scaled_dot_product_attention path.
        for i in range(len(self.backbone.blocks)):
            wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)

    def forward(self, image: torch.Tensor, token_rows: Union[int, torch.LongTensor], token_cols: Union[int, torch.LongTensor], return_class_token: bool = False) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Encode an RGB image into a (token_rows, token_cols) feature map.

        Args:
            image: (B, 3, H, W) tensor in [0, 1].
            token_rows / token_cols: token grid size; the image is resized to
                14x that in each dimension (ViT patch size 14).
            return_class_token: if True, also return the class token of the
                last tapped layer.

        Returns:
            (B, dim_out, token_rows, token_cols) features; plus the class token
            when requested.
        """
        # Antialiased resize is not ONNX-exportable, hence the mode switch.
        image_14 = F.interpolate(image, (token_rows * 14, token_cols * 14), mode="bilinear", align_corners=False, antialias=not self.onnx_compatible_mode)
        image_14 = (image_14 - self.image_mean) / self.image_std

        # Get intermediate layers from the backbone
        features = self.backbone.get_intermediate_layers(image_14, n=self.intermediate_layers, return_class_token=True)

        # Project each tapped layer to dim_out and sum them into one feature map
        x = torch.stack([
            proj(feat.permute(0, 2, 1).unflatten(2, (token_rows, token_cols)).contiguous())
            for proj, (feat, clstoken) in zip(self.output_projections, features)
        ], dim=1).sum(dim=1)

        if return_class_token:
            return x, features[-1][1]
        else:
            return x
142
+
143
+
144
class Resampler(nn.Sequential):
    """
    A conv layer paired with an up- or down-sampling operator, chosen by `type_`.

    Upsamplers: 'pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose'.
    Downsamplers: 'pixel_unshuffle', 'avg_pool', 'max_pool'.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output feature map.
        type_: resampling strategy.
        scale_factor: spatial scaling ratio (default 2).
    """
    def __init__(self,
        in_channels: int,
        out_channels: int,
        type_: Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'],
        scale_factor: int = 2,
    ):
        if type_ == 'pixel_shuffle':
            nn.Sequential.__init__(self,
                nn.Conv2d(in_channels, out_channels * (scale_factor ** 2), kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.PixelShuffle(scale_factor),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
            # Copy phase-0 weights into every other sub-pixel phase so all
            # shuffled positions start identical — NOTE(review): presumably so
            # the initial output is equivalent to nearest-neighbor upsampling;
            # confirm against upstream MoGe.
            for i in range(1, scale_factor ** 2):
                self[0].weight.data[i::scale_factor ** 2] = self[0].weight.data[0::scale_factor ** 2]
                self[0].bias.data[i::scale_factor ** 2] = self[0].bias.data[0::scale_factor ** 2]
        elif type_ in ['nearest', 'bilinear']:
            nn.Sequential.__init__(self,
                nn.Upsample(scale_factor=scale_factor, mode=type_, align_corners=False if type_ == 'bilinear' else None),
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
        elif type_ == 'conv_transpose':
            nn.Sequential.__init__(self,
                nn.ConvTranspose2d(in_channels, out_channels, kernel_size=scale_factor, stride=scale_factor),
                nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
            # Broadcast the top-left tap to the whole kernel — NOTE(review):
            # likewise appears intended to start as nearest-like upsampling.
            self[0].weight.data[:] = self[0].weight.data[:, :, :1, :1]
        elif type_ == 'pixel_unshuffle':
            nn.Sequential.__init__(self,
                nn.PixelUnshuffle(scale_factor),
                nn.Conv2d(in_channels * (scale_factor ** 2), out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
            )
        elif type_ == 'avg_pool':
            nn.Sequential.__init__(self,
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.AvgPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        elif type_ == 'max_pool':
            nn.Sequential.__init__(self,
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
                nn.MaxPool2d(kernel_size=scale_factor, stride=scale_factor),
            )
        else:
            raise ValueError(f'Unsupported resampler type: {type_}')
188
+
189
class MLP(nn.Sequential):
    """
    Simple multi-layer perceptron over the given layer widths: a Linear+ReLU
    pair for every hidden transition, and a bare Linear producing the output.
    """
    def __init__(self, dims: Sequence[int]):
        layers = []
        for dim_in, dim_out in zip(dims[:-2], dims[1:-1]):
            layers.append(nn.Linear(dim_in, dim_out))
            layers.append(nn.ReLU(inplace=True))
        # Final projection has no activation.
        layers.append(nn.Linear(dims[-2], dims[-1]))
        nn.Sequential.__init__(self, *layers)
198
+
199
+
200
class ConvStack(nn.Module):
    """
    Coarse-to-fine stack of residual conv stages connected by resamplers.

    Each stage i: optional 1x1 input projection -> (added to the running state
    for i > 0) -> residual blocks -> optional 1x1 output projection. Between
    stages, the running state is resampled (typically 2x upsampled).

    Args:
        dim_in: per-stage input channels (None disables that stage's input proj).
        dim_res_blocks: per-stage working channel widths.
        dim_out: per-stage output channels (None disables that stage's output proj).
        resamplers: resampler type per transition (or one type for all).
        dim_times_res_block_hidden: hidden-width multiplier inside each residual block.
        num_res_blocks: residual blocks per stage (int or per-stage list).
        res_block_in_norm / res_block_hidden_norm: norms forwarded to ResidualConvBlock.
        activation: activation forwarded to ResidualConvBlock.
    """
    def __init__(self,
        dim_in: List[Optional[int]],
        dim_res_blocks: List[int],
        dim_out: List[Optional[int]],
        resamplers: Union[Literal['pixel_shuffle', 'nearest', 'bilinear', 'conv_transpose', 'pixel_unshuffle', 'avg_pool', 'max_pool'], List],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_in_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'layer_norm',
        res_block_hidden_norm: Literal['layer_norm', 'group_norm' , 'instance_norm', 'none'] = 'group_norm',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
    ):
        super().__init__()
        # Scalar configs are broadcast across stages via itertools.repeat.
        self.input_blocks = nn.ModuleList([
            nn.Conv2d(dim_in_, dim_res_block_, kernel_size=1, stride=1, padding=0) if dim_in_ is not None else nn.Identity()
            for dim_in_, dim_res_block_ in zip(dim_in if isinstance(dim_in, Sequence) else itertools.repeat(dim_in), dim_res_blocks)
        ])
        # One resampler per transition between consecutive stages.
        self.resamplers = nn.ModuleList([
            Resampler(dim_prev, dim_succ, scale_factor=2, type_=resampler)
            for i, (dim_prev, dim_succ, resampler) in enumerate(zip(
                dim_res_blocks[:-1],
                dim_res_blocks[1:],
                resamplers if isinstance(resamplers, Sequence) else itertools.repeat(resamplers)
            ))
        ])
        self.res_blocks = nn.ModuleList([
            nn.Sequential(
                *(
                    ResidualConvBlock(
                        dim_res_block_, dim_res_block_, dim_times_res_block_hidden * dim_res_block_,
                        activation=activation, in_norm=res_block_in_norm, hidden_norm=res_block_hidden_norm
                    ) for _ in range(num_res_blocks[i] if isinstance(num_res_blocks, list) else num_res_blocks)
                )
            ) for i, dim_res_block_ in enumerate(dim_res_blocks)
        ])
        self.output_blocks = nn.ModuleList([
            nn.Conv2d(dim_res_block_, dim_out_, kernel_size=1, stride=1, padding=0) if dim_out_ is not None else nn.Identity()
            for dim_out_, dim_res_block_ in zip(dim_out if isinstance(dim_out, Sequence) else itertools.repeat(dim_out), dim_res_blocks)
        ])

    def enable_gradient_checkpointing(self):
        # Wrap resamplers and individual residual blocks for activation recomputation.
        for i in range(len(self.resamplers)):
            self.resamplers[i] = wrap_module_with_gradient_checkpointing(self.resamplers[i])
        for i in range(len(self.res_blocks)):
            for j in range(len(self.res_blocks[i])):
                self.res_blocks[i][j] = wrap_module_with_gradient_checkpointing(self.res_blocks[i][j])

    def forward(self, in_features: List[torch.Tensor]):
        """
        Args:
            in_features: one feature map per stage (entries for later stages
                are added to the upsampled running state).

        Returns:
            List of per-stage outputs after the output projections.
        """
        out_features = []
        for i in range(len(self.res_blocks)):
            feature = self.input_blocks[i](in_features[i])
            if i == 0:
                x = feature
            elif feature is not None:
                # Fuse this stage's input with the resampled running state.
                x = x + feature
            x = self.res_blocks[i](x)
            out_features.append(self.output_blocks[i](x))
            if i < len(self.res_blocks) - 1:
                x = self.resamplers[i](x)
        return out_features
moge/model/transforms.py ADDED
@@ -0,0 +1,1344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+ from numbers import Number
8
+
9
+ import torch
10
+ import torch.nn.functional as F
11
+
12
+ import inspect
13
+ from functools import wraps
14
+
15
+ import warnings
16
+
17
def suppress_traceback(fn):
    """
    Decorator that trims the two innermost wrapper frames from the traceback of
    any exception raised by *fn*, so errors point at the caller's code.
    """
    @wraps(fn)
    def inner(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as exc:
            # Drop the frames introduced by this wrapper before re-raising.
            exc.__traceback__ = exc.__traceback__.tb_next.tb_next
            raise
    return inner
26
+
27
+
28
class no_warnings:
    """
    Apply a warnings filter around a function call or a code block.

    Usable both as a decorator (`@no_warnings()`) and as a context manager
    (`with no_warnings(): ...`). The default action is 'ignore'; extra keyword
    arguments are forwarded to `warnings.simplefilter`.
    """

    def __init__(self, action: str = 'ignore', **kwargs):
        self.action = action
        self.filter_kwargs = kwargs

    def __call__(self, fn):
        # Decorator form: a fresh catch_warnings scope per invocation.
        @wraps(fn)
        def guarded(*args, **kwargs):
            with warnings.catch_warnings():
                warnings.simplefilter(self.action, **self.filter_kwargs)
                return fn(*args, **kwargs)
        return guarded

    def __enter__(self):
        # Context-manager form: delegate scoping to warnings.catch_warnings.
        self.warnings_manager = warnings.catch_warnings()
        self.warnings_manager.__enter__()
        warnings.simplefilter(self.action, **self.filter_kwargs)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
48
+
49
+
50
def get_device(args, kwargs):
    """
    Return the common device of all tensor arguments, or None if there are none.

    Raises:
        ValueError: if tensor arguments live on different devices.
    """
    device = None
    for value in (list(args) + list(kwargs.values())):
        if not isinstance(value, torch.Tensor):
            continue
        if device is None:
            device = value.device
        elif device != value.device:
            raise ValueError("All tensors must be on the same device.")
    return device
59
+
60
+
61
def get_args_order(func, args, kwargs):
    """
    Get the order of the arguments of a function.

    Maps the supplied positional and keyword arguments back to their indices in
    *func*'s declared signature.

    Returns:
        (args_order, kwargs_order): list of signature indices for the positional
        arguments, and a dict mapping each matched keyword name to its index.
    """
    declared = inspect.getfullargspec(func).args
    index_of = {name: i for i, name in enumerate(declared)}
    # Keyword arguments claim their declared slots first.
    kwargs_order = {name: index_of[name] for name in kwargs if name in declared}
    # Positional arguments fill the remaining slots in declaration order.
    remaining = [name for name in declared if name not in kwargs_order]
    args_order = [index_of[name] for name in remaining[:len(args)]]
    return args_order, kwargs_order
77
+
78
+
79
def broadcast_args(args, kwargs, args_dim, kwargs_dim):
    """
    Broadcast the leading ("spatial") dimensions of all tensor arguments to a
    common shape, numpy-style, leaving each argument's trailing `*_dim` core
    dimensions untouched.

    Args:
        args: list of positional arguments (tensors are broadcast in place).
        kwargs: dict of keyword arguments (tensor values are broadcast).
        args_dim: per-positional-argument count of trailing core dims
            (None = do not broadcast this argument).
        kwargs_dim: same, keyed by keyword name.

    Returns:
        (args, kwargs, spatial): the broadcast arguments and the resolved
        common spatial shape as a list.

    Raises:
        ValueError: if two arguments have incompatible spatial dimensions.
    """
    spatial = []
    for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
        if isinstance(arg, torch.Tensor) and arg_dim is not None:
            arg_spatial = arg.shape[:arg.ndim - arg_dim]
            # Left-pad the running spatial shape with 1s so shapes align on the right.
            if len(arg_spatial) > len(spatial):
                spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
            # Compare right-aligned dimensions. Fixed off-by-one: the original
            # used range(len(arg_spatial)) so j=0 indexed spatial[-0] == spatial[0],
            # which compared LEADING dims and wrongly rejected e.g. (3,) vs (2, 3).
            for j in range(1, len(arg_spatial) + 1):
                if spatial[-j] < arg_spatial[-j]:
                    if spatial[-j] == 1:
                        spatial[-j] = arg_spatial[-j]
                    else:
                        raise ValueError("Cannot broadcast arguments.")
    for i, arg in enumerate(args):
        if isinstance(arg, torch.Tensor) and args_dim[i] is not None:
            args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim - args_dim[i]:]])
    for key, arg in kwargs.items():
        if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
            kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim - kwargs_dim[key]:]])
    return args, kwargs, spatial
99
+
100
@suppress_traceback
def batched(*dims):
    """
    Decorator that allows a function to be called with batched arguments.

    `dims` gives, per signature position, the number of trailing "core"
    dimensions of that argument (None = pass through untouched). The wrapper:
    numbers/lists are promoted to tensors, leading (batch) dims are broadcast
    to a common shape, arguments are flattened to a single batch dim, the
    wrapped function is called once, and results get the batch shape restored.
    """
    def decorator(func):
        @wraps(func)
        def wrapper(*args, device=torch.device('cpu'), **kwargs):
            args = list(args)
            # get arguments dimensions
            args_order, kwargs_order = get_args_order(func, args, kwargs)
            args_dim = [dims[i] for i in args_order]
            kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
            # convert to torch tensor (device taken from tensor args if any)
            device = get_device(args, kwargs) or device
            for i, arg in enumerate(args):
                if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
                    args[i] = torch.tensor(arg, device=device)
            for key, arg in kwargs.items():
                if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
                    kwargs[key] = torch.tensor(arg, device=device)
            # broadcast arguments to a common leading (spatial) shape
            args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
            # flatten all leading dims into one batch dim before the call
            for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
                if isinstance(arg, torch.Tensor) and arg_dim is not None:
                    args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
            for key, arg in kwargs.items():
                if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None:
                    kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
            # call function
            results = func(*args, **kwargs)
            type_results = type(results)
            results = list(results) if isinstance(results, (tuple, list)) else [results]
            # restore spatial dimensions on every result
            for i, result in enumerate(results):
                results[i] = result.reshape([*spatial, *result.shape[1:]])
            if type_results == tuple:
                results = tuple(results)
            elif type_results == list:
                results = list(results)
            else:
                results = results[0]
            return results
        return wrapper
    return decorator
145
+
146
# Public API of this module: camera/projection matrix constructors and
# conversions, rotation representations, and 2D transform helpers.
__all__ = [
    'perspective',
    'perspective_from_fov',
    'perspective_from_fov_xy',
    'intrinsics_from_focal_center',
    'intrinsics_from_fov',
    'intrinsics_from_fov_xy',
    'focal_to_fov',
    'fov_to_focal',
    'intrinsics_to_fov',
    'view_look_at',
    'extrinsics_look_at',
    'perspective_to_intrinsics',
    'intrinsics_to_perspective',
    'extrinsics_to_view',
    'view_to_extrinsics',
    'normalize_intrinsics',
    'crop_intrinsics',
    'pixel_to_uv',
    'pixel_to_ndc',
    'uv_to_pixel',
    'project_depth',
    'depth_buffer_to_linear',
    'project_gl',
    'project_cv',
    'unproject_gl',
    'unproject_cv',
    'skew_symmetric',
    'rotation_matrix_from_vectors',
    'euler_axis_angle_rotation',
    'euler_angles_to_matrix',
    'matrix_to_euler_angles',
    'matrix_to_quaternion',
    'quaternion_to_matrix',
    'matrix_to_axis_angle',
    'axis_angle_to_matrix',
    'axis_angle_to_quaternion',
    'quaternion_to_axis_angle',
    'slerp',
    'interpolate_extrinsics',
    'interpolate_view',
    'extrinsics_to_essential',
    'to4x4',
    'rotation_matrix_2d',
    'rotate_2d',
    'translate_2d',
    'scale_2d',
    'apply_2d',
]
195
+
196
+
197
@batched(0,0,0,0)
def perspective(
    fov_y: Union[float, torch.Tensor],
    aspect: Union[float, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenGL perspective matrix

    Args:
        fov_y (float | torch.Tensor): field of view in y axis
        aspect (float | torch.Tensor): aspect ratio
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # @batched flattens leading dims, so inside the body everything is [N].
    N = fov_y.shape[0]
    ret = torch.zeros((N, 4, 4), dtype=fov_y.dtype, device=fov_y.device)
    ret[:, 0, 0] = 1. / (torch.tan(fov_y / 2) * aspect)  # x focal term
    ret[:, 1, 1] = 1. / (torch.tan(fov_y / 2))           # y focal term
    ret[:, 2, 2] = (near + far) / (near - far)           # depth remap
    ret[:, 2, 3] = 2. * near * far / (near - far)
    ret[:, 3, 2] = -1.                                   # perspective divide by -z
    return ret
224
+
225
+
226
def perspective_from_fov(
    fov: Union[float, torch.Tensor],
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenGL perspective matrix from the field of view of the largest
    image dimension.

    Args:
        fov (float | torch.Tensor): field of view in largest dimension
        width (int | torch.Tensor): image width
        height (int | torch.Tensor): image height
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # Convert the dominant-axis FoV into the vertical FoV expected by perspective().
    half_tan = torch.tan(fov / 2)
    fov_y = 2 * torch.atan(half_tan * height / torch.maximum(width, height))
    return perspective(fov_y, width / height, near, far)
249
+
250
+
251
def perspective_from_fov_xy(
    fov_x: Union[float, torch.Tensor],
    fov_y: Union[float, torch.Tensor],
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenGL perspective matrix from the fields of view in the x and y axes.

    Args:
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis
        near (float | torch.Tensor): near plane to clip
        far (float | torch.Tensor): far plane to clip

    Returns:
        (torch.Tensor): [..., 4, 4] perspective matrix
    """
    # The aspect ratio implied by the two FoVs is tan(fov_x/2) / tan(fov_y/2).
    return perspective(fov_y, torch.tan(fov_x / 2) / torch.tan(fov_y / 2), near, far)
271
+
272
+
273
@batched(0,0,0,0)
def intrinsics_from_focal_center(
    fx: Union[float, torch.Tensor],
    fy: Union[float, torch.Tensor],
    cx: Union[float, torch.Tensor],
    cy: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenCV intrinsics matrix

    Args:
        fx (float | torch.Tensor): focal length in x axis
        fy (float | torch.Tensor): focal length in y axis
        cx (float | torch.Tensor): principal point in x axis
        cy (float | torch.Tensor): principal point in y axis

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    # @batched flattens leading dims, so fx/fy/cx/cy are all [N] here.
    N = fx.shape[0]
    zeros = torch.zeros(N, dtype=fx.dtype, device=fx.device)
    ones = torch.ones(N, dtype=fx.dtype, device=fx.device)
    # Row-major 3x3: [[fx, 0, cx], [0, fy, cy], [0, 0, 1]].
    # (Removed a dead torch.zeros((N, 3, 3)) allocation that was immediately
    # overwritten in the original.)
    return torch.stack([fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1).unflatten(-1, (3, 3))
297
+
298
+
299
@batched(0, 0, 0, 0, 0, 0)
def intrinsics_from_fov(
    fov_max: Union[float, torch.Tensor] = None,
    fov_min: Union[float, torch.Tensor] = None,
    fov_x: Union[float, torch.Tensor] = None,
    fov_y: Union[float, torch.Tensor] = None,
    width: Union[int, torch.Tensor] = None,
    height: Union[int, torch.Tensor] = None,
) -> torch.Tensor:
    """
    Get normalized OpenCV intrinsics matrix from given field of view.
    You can provide either fov_max, fov_min, fov_x or fov_y

    Args:
        fov_max (float | torch.Tensor): field of view in largest dimension
        fov_min (float | torch.Tensor): field of view in smallest dimension
        fov_x (float | torch.Tensor): field of view in x axis
        fov_y (float | torch.Tensor): field of view in y axis
        width (int | torch.Tensor): image width
        height (int | torch.Tensor): image height

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix

    Raises:
        ValueError: if none of fov_max, fov_min, fov_x, fov_y is provided.
    """
    if fov_max is not None:
        fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2))
        fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2))
    elif fov_min is not None:
        fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2))
        fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2))
    elif fov_x is not None and fov_y is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        fy = 1 / (2 * torch.tan(fov_y / 2))
    elif fov_x is not None:
        fx = 1 / (2 * torch.tan(fov_x / 2))
        fy = fx * width / height
    elif fov_y is not None:
        fy = 1 / (2 * torch.tan(fov_y / 2))
        fx = fy * height / width
    else:
        # Previously fell through with fx/fy unbound, raising an opaque
        # NameError; fail with a descriptive error instead.
        raise ValueError("One of fov_max, fov_min, fov_x or fov_y must be provided.")
    # Normalized intrinsics: principal point at the image center in uv space.
    cx = 0.5
    cy = 0.5
    ret = intrinsics_from_focal_center(fx, fy, cx, cy)
    return ret
342
+
343
+
344
+
345
def intrinsics_from_fov_xy(
    fov_x: Union[float, torch.Tensor],
    fov_y: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Get OpenCV intrinsics matrix from field of view in x and y axis

    Args:
        fov_x (float | torch.Tensor): field of view in x axis (radians)
        fov_y (float | torch.Tensor): field of view in y axis (radians)

    Returns:
        (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
    """
    # Normalized focal lengths from the half-angle tangents; principal point
    # sits at the center of the normalized image plane.
    fx = 0.5 / torch.tan(fov_x / 2)
    fy = 0.5 / torch.tan(fov_y / 2)
    center = 0.5
    return intrinsics_from_focal_center(fx, fy, center, center)
363
+
364
+
365
def focal_to_fov(focal: torch.Tensor):
    """Convert a normalized focal length to the corresponding field of view (radians)."""
    half_angle = torch.atan(0.5 / focal)
    return half_angle * 2
367
+
368
+
369
def fov_to_focal(fov: torch.Tensor):
    """Convert a field of view (radians) to the corresponding normalized focal length."""
    half_fov = fov / 2
    return 0.5 / torch.tan(half_fov)
371
+
372
+
373
def intrinsics_to_fov(intrinsics: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    "NOTE: approximate FOV by assuming centered principal point"
    # Inverse of fov -> focal: fov = 2 * atan(0.5 / f), applied to the two
    # diagonal focal entries of the normalized intrinsics.
    fx = intrinsics[..., 0, 0]
    fy = intrinsics[..., 1, 1]
    return 2 * torch.atan(0.5 / fx), 2 * torch.atan(0.5 / fy)
378
+
379
+
380
@batched(1,1,1)
def view_look_at(
    eye: torch.Tensor,
    look_at: torch.Tensor,
    up: torch.Tensor
) -> torch.Tensor:
    """
    Get OpenGL view matrix looking at something

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], view matrix
    """
    batch = eye.shape[0]
    # OpenGL cameras look down -z, so +z points from the target back to the eye.
    z_axis = eye - look_at
    x_axis = torch.cross(up, z_axis, dim=-1)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    rotation = torch.stack([x_axis, y_axis, z_axis], dim=-2)
    # Translation brings the eye to the origin: t = -R @ eye.
    translation = -torch.matmul(rotation, eye[..., None])
    view = torch.zeros((batch, 4, 4), dtype=eye.dtype, device=eye.device)
    view[:, :3, :3] = rotation
    view[:, :3, 3] = translation[:, :, 0]
    view[:, 3, 3] = 1.
    return view
412
+
413
+
414
@batched(1, 1, 1)
def extrinsics_look_at(
    eye: torch.Tensor,
    look_at: torch.Tensor,
    up: torch.Tensor
) -> torch.Tensor:
    """
    Get OpenCV extrinsics matrix looking at something

    Args:
        eye (torch.Tensor): [..., 3] the eye position
        look_at (torch.Tensor): [..., 3] the position to look at
        up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction

    Returns:
        (torch.Tensor): [..., 4, 4], extrinsics matrix
    """
    batch = eye.shape[0]
    # OpenCV cameras look down +z, with y pointing down; hence -up below.
    z_axis = look_at - eye
    x_axis = torch.cross(-up, z_axis, dim=-1)
    y_axis = torch.cross(z_axis, x_axis, dim=-1)
    x_axis = x_axis / x_axis.norm(dim=-1, keepdim=True)
    y_axis = y_axis / y_axis.norm(dim=-1, keepdim=True)
    z_axis = z_axis / z_axis.norm(dim=-1, keepdim=True)
    rotation = torch.stack([x_axis, y_axis, z_axis], dim=-2)
    # Translation brings the eye to the origin: t = -R @ eye.
    translation = -torch.matmul(rotation, eye[..., None])
    extr = torch.zeros((batch, 4, 4), dtype=eye.dtype, device=eye.device)
    extr[:, :3, :3] = rotation
    extr[:, :3, 3] = translation[:, :, 0]
    extr[:, 3, 3] = 1.
    return extr
446
+
447
+
448
@batched(2)
def perspective_to_intrinsics(
    perspective: torch.Tensor
) -> torch.Tensor:
    """
    OpenGL perspective matrix to OpenCV intrinsics

    Args:
        perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix

    Returns:
        (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics
    """
    # Rows 0, 1, 3 must have a zero in the last column for a projection matrix.
    # NOTE: torch.allclose requires a Tensor second operand — the original
    # passed the Python int 0, which raises a TypeError at runtime.
    last_col = perspective[:, [0, 1, 3], 3]
    assert torch.allclose(last_col, torch.zeros_like(last_col)), "The perspective matrix is not a projection matrix"
    # Map NDC ([-1, 1], y-up) to uv ([0, 1], y-down) and flip the handedness of
    # the camera axes (OpenGL looks down -z, OpenCV down +z).
    ret = torch.tensor([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype, device=perspective.device) \
        @ perspective[:, [0, 1, 3], :3] \
        @ torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device))
    return ret / ret[:, 2, 2, None, None]
466
+
467
+
468
@batched(2,0,0)
def intrinsics_to_perspective(
    intrinsics: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor],
) -> torch.Tensor:
    """
    OpenCV intrinsics to OpenGL perspective matrix

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip
    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL perspective matrix
    """
    batch = intrinsics.shape[0]
    # Pull focal lengths and principal point off the (normalized) intrinsics.
    fx = intrinsics[:, 0, 0]
    fy = intrinsics[:, 1, 1]
    cx = intrinsics[:, 0, 2]
    cy = intrinsics[:, 1, 2]
    proj = torch.zeros((batch, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device)
    proj[:, 0, 0] = 2 * fx
    proj[:, 1, 1] = 2 * fy
    # Off-center terms account for a principal point away from the image center.
    proj[:, 0, 2] = -2 * cx + 1
    proj[:, 1, 2] = 2 * cy - 1
    # Standard OpenGL depth mapping of [near, far] onto clip space.
    proj[:, 2, 2] = (near + far) / (near - far)
    proj[:, 2, 3] = 2. * near * far / (near - far)
    proj[:, 3, 2] = -1.
    return proj
496
+
497
+
498
@batched(2)
def extrinsics_to_view(
    extrinsics: torch.Tensor
) -> torch.Tensor:
    """
    OpenCV camera extrinsics to OpenGL view matrix

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenGL view matrix
    """
    # Negating the y and z rows converts between the OpenCV (y down, z forward)
    # and OpenGL (y up, z backward) camera conventions.
    row_signs = torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)
    return extrinsics * row_signs[:, None]
512
+
513
+
514
@batched(2)
def view_to_extrinsics(
    view: torch.Tensor
) -> torch.Tensor:
    """
    OpenGL view matrix to OpenCV camera extrinsics

    Args:
        view (torch.Tensor): [..., 4, 4] OpenGL view matrix

    Returns:
        (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix
    """
    # The conversion is an involution: flip the sign of the y and z rows,
    # exactly as in extrinsics_to_view.
    row_signs = torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)
    return view * row_signs[:, None]
528
+
529
+
530
@batched(2,0,0)
def normalize_intrinsics(
    intrinsics: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Normalize camera intrinsics(s) to uv space

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s)
    """
    zeros = torch.zeros_like(width)
    ones = torch.ones_like(width)
    # Stack along the LAST dim so each batch element keeps its 9 matrix entries
    # contiguous. The previous stack along dim 0 produced shape (9, N), whose
    # reshape to (N, 3, 3) interleaves batch elements whenever N > 1.
    # The 0.5 offsets place uv coordinates at pixel centers.
    transform = torch.stack([
        1 / width, zeros, 0.5 / width,
        zeros, 1 / height, 0.5 / height,
        zeros, zeros, ones
    ], dim=-1).reshape(*zeros.shape, 3, 3).to(intrinsics)
    return transform @ intrinsics
555
+
556
+
557
+
558
@batched(2,0,0,0,0,0,0)
def crop_intrinsics(
    intrinsics: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor],
    left: Union[int, torch.Tensor],
    top: Union[int, torch.Tensor],
    crop_width: Union[int, torch.Tensor],
    crop_height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width]

    Args:
        intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)
        left (int | torch.Tensor): [...] left crop boundary
        top (int | torch.Tensor): [...] top crop boundary
        crop_width (int | torch.Tensor): [...] crop width
        crop_height (int | torch.Tensor): [...] crop height

    Returns:
        (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s)
    """
    zeros = torch.zeros_like(width)
    ones = torch.ones_like(width)
    # Stack along the LAST dim so each batch element keeps its 9 matrix entries
    # contiguous. The previous stack along dim 0 produced shape (9, N), whose
    # reshape to (N, 3, 3) interleaves batch elements whenever N > 1.
    transform = torch.stack([
        width / crop_width, zeros, -left / crop_width,
        zeros, height / crop_height, -top / crop_height,
        zeros, zeros, ones
    ], dim=-1).reshape(*zeros.shape, 3, 3).to(intrinsics)
    return transform @ intrinsics
591
+
592
+
593
@batched(1,0,0)
def pixel_to_uv(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)
    """
    if not torch.is_floating_point(pixel):
        pixel = pixel.float()
    size = torch.stack([width, height], dim=-1).to(pixel)
    # Pixel centers sit at half-integer coordinates, hence the +0.5 shift.
    return (pixel + 0.5) / size
612
+
613
+
614
@batched(1,0,0)
def uv_to_pixel(
    uv: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Args:
        uv (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinates in image space; inverse of pixel_to_uv
    """
    size = torch.stack([width, height], dim=-1).to(uv)
    # Undo the half-pixel center offset applied by pixel_to_uv.
    return uv * size - 0.5
631
+
632
+
633
@batched(1,0,0)
def pixel_to_ndc(
    pixel: torch.Tensor,
    width: Union[int, torch.Tensor],
    height: Union[int, torch.Tensor]
) -> torch.Tensor:
    """
    Args:
        pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1)
        width (int | torch.Tensor): [...] image width(s)
        height (int | torch.Tensor): [...] image height(s)

    Returns:
        (torch.Tensor): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1)
    """
    if not torch.is_floating_point(pixel):
        pixel = pixel.float()
    # ndc = uv * [2, -2] + [-1, 1]: x maps (0,1) -> (-1,1), y is flipped since
    # ndc is y-up while image space is y-down.
    # BUG FIX: the original divided by (size * [2, -2]), putting the [2, -2]
    # factor in the denominator instead of multiplying uv by it.
    uv = (pixel + 0.5) / torch.stack([width, height], dim=-1).to(pixel)
    return uv * torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device) \
        + torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device)
653
+
654
+
655
@batched(0,0,0)
def project_depth(
    depth: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Project linear depth to depth value in screen space

    Args:
        depth (torch.Tensor): [...] depth value
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1]
    """
    # Standard z-buffer mapping: depth == near -> 0, depth == far -> 1.
    depth_range = far - near
    return (far - near * far / depth) / depth_range
673
+
674
+
675
@batched(0,0,0)
def depth_buffer_to_linear(
    depth: torch.Tensor,
    near: Union[float, torch.Tensor],
    far: Union[float, torch.Tensor]
) -> torch.Tensor:
    """
    Linearize depth value to linear depth

    Args:
        depth (torch.Tensor): [...] screen depth value, ranging in [0, 1]
        near (float | torch.Tensor): [...] near plane to clip
        far (float | torch.Tensor): [...] far plane to clip

    Returns:
        (torch.Tensor): [...] linear depth
    """
    # Exact inverse of project_depth: 0 -> near, 1 -> far.
    scaled = (far - near) * depth
    return near * far / (far - scaled)
693
+
694
+
695
@batched(2, 2, 2, 2)
def project_gl(
    points: torch.Tensor,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    perspective: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to 2D following the OpenGL convention (except for row major matrice)

    Args:
        points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        linear_depth (torch.Tensor): [..., N] linear depth
    """
    assert perspective is not None, "perspective matrix is required"

    # Promote to homogeneous coordinates if needed.
    if points.shape[-1] == 3:
        points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
    # Compose the MVP matrix right-to-left; the assert guarantees perspective.
    mvp = perspective
    for factor in (view, model):
        if factor is not None:
            mvp = mvp @ factor
    clip_coord = points @ mvp.transpose(-1, -2)
    # Perspective divide, then shift NDC [-1, 1] into screen space [0, 1].
    ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:]
    scr_coord = ndc_coord * 0.5 + 0.5
    # In OpenGL the clip-space w component equals the view-space depth.
    return scr_coord, clip_coord[..., 3]
731
+
732
+
733
@batched(2, 2, 2)
def project_cv(
    points: torch.Tensor,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Project 3D points to 2D following the OpenCV convention

    Args:
        points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last
            dimension is 4, the points are assumed to be in homogeneous coordinates
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        linear_depth (torch.Tensor): [..., N] linear depth
    """
    assert intrinsics is not None, "intrinsics matrix is required"
    # Promote to homogeneous coordinates if needed.
    if points.shape[-1] == 3:
        points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
    # World -> camera, then camera -> image plane.
    if extrinsics is not None:
        points = points @ extrinsics.transpose(-1, -2)
    projected = points[..., :3] @ intrinsics.transpose(-2, -1)
    # Perspective divide by the camera-space z coordinate.
    uv_coord = projected[..., :2] / projected[..., 2:]
    return uv_coord, projected[..., 2]
762
+
763
+
764
@batched(2, 2, 2, 2)
def unproject_gl(
    screen_coord: torch.Tensor,
    model: torch.Tensor = None,
    view: torch.Tensor = None,
    perspective: torch.Tensor = None
) -> torch.Tensor:
    """
    Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice)

    Args:
        screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1].
            The origin (0., 0., 0.) is corresponding to the left & bottom & nearest
        model (torch.Tensor): [..., 4, 4] model matrix
        view (torch.Tensor): [..., 4, 4] view matrix
        perspective (torch.Tensor): [..., 4, 4] perspective matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert perspective is not None, "perspective matrix is required"
    # Screen [0, 1] -> NDC [-1, 1], then lift to homogeneous clip coordinates.
    ndc = screen_coord * 2 - 1
    clip_coord = torch.cat([ndc, torch.ones_like(ndc[..., :1])], dim=-1)
    # Invert the composed MVP and apply it.
    mvp = perspective
    for factor in (view, model):
        if factor is not None:
            mvp = mvp @ factor
    unprojected = clip_coord @ torch.inverse(mvp).transpose(-1, -2)
    return unprojected[..., :3] / unprojected[..., 3:]
796
+
797
+
798
@batched(2, 1, 2, 2)
def unproject_cv(
    uv_coord: torch.Tensor,
    depth: torch.Tensor = None,
    extrinsics: torch.Tensor = None,
    intrinsics: torch.Tensor = None
) -> torch.Tensor:
    """
    Unproject uv coordinates to 3D view space following the OpenCV convention

    Args:
        uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1].
            The origin (0., 0.) is corresponding to the left & top
        depth (torch.Tensor): [..., N] depth value
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix
        intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix

    Returns:
        points (torch.Tensor): [..., N, 3] 3d points
    """
    assert intrinsics is not None, "intrinsics matrix is required"
    # Back-project through the inverse intrinsics: rays at unit depth.
    homogeneous_uv = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1)
    rays = homogeneous_uv @ torch.inverse(intrinsics).transpose(-2, -1)
    # Scale rays by depth if provided.
    if depth is not None:
        rays = rays * depth[..., None]
    # Camera -> world if extrinsics are given.
    if extrinsics is not None:
        homogeneous = torch.cat([rays, torch.ones_like(rays[..., :1])], dim=-1)
        rays = (homogeneous @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3]
    return rays
827
+
828
+
829
def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.

    Args:
        axis: Axis label "X" or "Y or "Z".
        angle: any shape tensor of Euler angles in radians

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    c, s = torch.cos(angle), torch.sin(angle)
    one, zero = torch.ones_like(angle), torch.zeros_like(angle)

    # Row-major entries of the elementary rotation about each axis.
    entries_by_axis = {
        "X": (one, zero, zero, zero, c, -s, zero, s, c),
        "Y": (c, zero, s, zero, one, zero, -s, zero, c),
        "Z": (c, -s, zero, s, c, zero, zero, zero, one),
    }
    if axis not in entries_by_axis:
        raise ValueError("letter must be either X, Y or Z.")

    return torch.stack(entries_by_axis[axis], -1).reshape(angle.shape + (3, 3))
857
+
858
+
859
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor:
    """
    Convert rotations given as Euler angles in radians to rotation matrices.

    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ
        convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply.

    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")
    # One elementary rotation per convention letter; each letter picks its own
    # angle component from the fixed XYZ layout of `euler_angles`.
    first, second, third = (
        euler_axis_angle_rotation(letter, euler_angles[..., 'XYZ'.index(letter)])
        for letter in convention
    )
    # Composition order: the first convention letter is applied first.
    return third @ second @ first
885
+
886
+
887
def skew_symmetric(v: torch.Tensor):
    "Skew symmetric matrix from a 3D vector"
    assert v.shape[-1] == 3, "v must be 3D"
    x, y, z = v.unbind(dim=-1)
    zero = torch.zeros_like(x)
    # [v]_x such that [v]_x @ w == cross(v, w).
    entries = [
        zero, -z, y,
        z, zero, -x,
        -y, x, zero,
    ]
    return torch.stack(entries, dim=-1).reshape(*v.shape[:-1], 3, 3)
897
+
898
+
899
def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor):
    """
    Rotation matrix that rotates v1 to v2, via Rodrigues' formula.

    Args:
        v1 (torch.Tensor): [..., 3] source direction(s)
        v2 (torch.Tensor): [..., 3] target direction(s)

    Returns:
        (torch.Tensor): [..., 3, 3] rotation matrix

    NOTE: degenerate when v1 and v2 are anti-parallel (c == -1 gives a
    division by zero); callers must avoid that case.
    """
    I = torch.eye(3).to(v1)
    v1 = F.normalize(v1, dim=-1)
    v2 = F.normalize(v2, dim=-1)
    v = torch.cross(v1, v2, dim=-1)
    c = torch.sum(v1 * v2, dim=-1)
    K = skew_symmetric(v)
    # Broadcast the scalar factor over the trailing 3x3 dims. The original
    # `[None, None]` PREPENDED two dims, which breaks batched inputs
    # (e.g. c of shape (B,) became (1, 1, B) against K @ K of shape (B, 3, 3)).
    R = I + K + (1 / (1 + c))[..., None, None] * (K @ K)
    return R
909
+
910
+
911
def _angle_from_tan(
    axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
) -> torch.Tensor:
    """
    Extract the first or third Euler angle from the two members of
    the matrix which are positive constant times its sine and cosine.

    Args:
        axis: Axis label "X" or "Y or "Z" for the angle we are finding.
        other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
            convention.
        data: Rotation matrices as tensor of shape (..., 3, 3).
        horizontal: Whether we are looking for the angle for the third axis,
            which means the relevant entries are in the same row of the
            rotation matrix. If not, they are in the same column.
        tait_bryan: Whether the first and third axes in the convention differ.

    Returns:
        Euler Angles in radians for each matrix in data as a tensor
        of shape (...).
    """

    # Indices of the sine/cosine carrying entries for each axis.
    i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
    if horizontal:
        i2, i1 = i1, i2
    # Cyclic axis pairs keep the natural sign; anti-cyclic pairs flip it.
    even = (axis + other_axis) in ["XY", "YZ", "ZX"]
    if horizontal == even:
        return torch.atan2(data[..., i1], data[..., i2])
    if tait_bryan:
        return torch.atan2(-data[..., i2], data[..., i1])
    return torch.atan2(data[..., i2], -data[..., i1])
942
+
943
+
944
def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor:
    """
    Convert rotations given as rotation matrices to Euler angles in radians.
    NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d)

    Args:
        matrix: Rotation matrices as tensor of shape (..., 3, 3).
        convention: Convention string of three uppercase letters.

    Returns:
        Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d)
    """
    # Convention must be a permutation of the letters X, Y, Z.
    if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'):
        raise ValueError(f"Invalid convention {convention}.")
    if not matrix.shape[-2:] == (3, 3):
        raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")

    i0 = 'XYZ'.index(convention[0])
    i2 = 'XYZ'.index(convention[2])
    # Tait-Bryan conventions (distinct first/third axes) use asin for the
    # middle angle; proper Euler conventions use acos.
    tait_bryan = i0 != i2
    if tait_bryan:
        central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0))
    else:
        central_angle = torch.acos(matrix[..., i2, i2])

    # Angles in composition order
    o = [
        _angle_from_tan(
            convention[0], convention[1], matrix[..., i2, :], True, tait_bryan
        ),
        central_angle,
        _angle_from_tan(
            convention[2], convention[1], matrix[..., i0], False, tait_bryan
        ),
    ]
    # Reorder from composition order back to fixed XYZ component order.
    return torch.stack([o[convention.index(c)] for c in 'XYZ'], -1)
980
+
981
+
982
def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors
        eps (float): perturbation added before taking the norm to avoid a
            zero division for the zero rotation

    Returns:
        torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters
    """
    batch_shape = axis_angle.shape[:-1]
    device, dtype = axis_angle.device, axis_angle.dtype

    angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True)
    axis = axis_angle / angle

    # angle has a trailing singleton dim, so cos/sin broadcast over 3x3 blocks.
    cos = torch.cos(angle)[..., None, :]
    sin = torch.sin(angle)[..., None, :]

    # BUG FIX: torch.split with split size 3 on a size-3 dim yields a single
    # chunk, so unpacking into (rx, ry, rz) raised a ValueError. Split into
    # chunks of size 1 to get the three (..., 1) components.
    rx, ry, rz = torch.split(axis, 1, dim=-1)
    zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device)
    # Cross-product (skew-symmetric) matrix K of the unit axis.
    K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3))

    # Rodrigues' formula: R = I + sin(a) K + (1 - cos(a)) K^2.
    ident = torch.eye(3, dtype=dtype, device=device)
    rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K)
    return rot_mat
1007
+
1008
+
1009
def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector)

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices
    """
    # Route through quaternions, which is numerically stabler than reading the
    # axis off the matrix directly.
    return quaternion_to_axis_angle(matrix_to_quaternion(rot_mat), eps=eps)
1021
+
1022
+
1023
def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector)

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert

    Returns:
        torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    vector_part = quaternion[..., 1:]
    scalar_part = quaternion[..., 0:1]
    vector_norm = torch.norm(vector_part, dim=-1, keepdim=True)
    # atan2 handles quaternions of any (non-unit) magnitude gracefully.
    angle = 2 * torch.atan2(vector_norm, scalar_part)
    unit_axis = vector_part / vector_norm.clamp(min=eps)
    return angle * unit_axis
1037
+
1038
+
1039
def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z)

    Args:
        axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors

    Returns:
        torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters
    """
    angle = torch.norm(axis_angle, dim=-1, keepdim=True)
    unit_axis = F.normalize(axis_angle, dim=-1, eps=eps)
    half = angle / 2
    # q = (cos(a/2), sin(a/2) * axis)
    return torch.cat([torch.cos(half), torch.sin(half) * unit_axis], dim=-1)
1052
+
1053
+
1054
def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Convert 3x3 rotation matrix to quaternion (w, x, y, z)

    Args:
        rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert

    Returns:
        torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices
    """
    # Extract the diagonal and off-diagonal elements of the rotation matrix
    # (m00, m11, m22 are unused here; the diagonal is read via `diag` below).
    m00, m01, m02, m10, m11, m12, m20, m21, m22 = rot_mat.flatten(-2).unbind(dim=-1)

    diag = torch.diagonal(rot_mat, dim1=-2, dim2=-1)
    # Each row of M gives the diagonal combination for |w|, |x|, |y|, |z|
    # (Shepperd's method: 4*q_i^2 = 1 + M_i . diag).
    M = torch.tensor([
        [1, 1, 1],
        [1, -1, -1],
        [-1, 1, -1],
        [-1, -1, 1]
    ], dtype=rot_mat.dtype, device=rot_mat.device)
    wxyz = (1 + diag @ M.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5)
    # Pick the numerically largest component as the sign reference.
    _, max_idx = wxyz.max(dim=-1)
    # Signs of products of quaternion component pairs, read off the matrix.
    xw = torch.sign(m21 - m12)
    yw = torch.sign(m02 - m20)
    zw = torch.sign(m10 - m01)
    yz = torch.sign(m21 + m12)
    xz = torch.sign(m02 + m20)
    xy = torch.sign(m01 + m10)
    ones = torch.ones_like(xw)
    # Resolve each component's sign relative to the largest component.
    sign = torch.where(
        max_idx[..., None] == 0,
        torch.stack([ones, xw, yw, zw], dim=-1),
        torch.where(
            max_idx[..., None] == 1,
            torch.stack([xw, ones, xy, xz], dim=-1),
            torch.where(
                max_idx[..., None] == 2,
                torch.stack([yw, xy, ones, yz], dim=-1),
                torch.stack([zw, xz, yz, ones], dim=-1)
            )
        )
    )
    quat = sign * wxyz
    quat = F.normalize(quat, dim=-1, eps=eps)
    return quat
1098
+
1099
+
1100
def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    """Converts a batch of quaternions (w, x, y, z) to rotation matrices

    Args:
        quaternion (torch.Tensor): shape (..., 4), the quaternions to convert

    Returns:
        torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions
    """
    assert quaternion.shape[-1] == 4
    quaternion = F.normalize(quaternion, dim=-1, eps=eps)
    w, x, y, z = quaternion.unbind(dim=-1)
    zero = torch.zeros_like(w)
    identity = torch.eye(3, dtype=quaternion.dtype, device=quaternion.device)
    vec = quaternion[..., 1:]
    # Symmetric part: outer(v, v) - |v|^2 I; antisymmetric part: skew(v).
    symmetric = vec[..., :, None] * vec[..., None, :] - identity * (vec ** 2).sum(dim=-1)[..., None, None]
    skew = torch.stack([
        zero, -z, y,
        z, zero, -x,
        -y, x, zero
    ], dim=-1).unflatten(-1, (3, 3))
    # R = I + 2 (vv^T - |v|^2 I + w [v]_x)
    return identity + 2 * (symmetric + w[..., None, None] * skew)
1123
+
1124
+
1125
def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Spherical linear interpolation between two rotation matrices

    Args:
        rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix
        rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix
    """
    assert rot_mat_1.shape[-2:] == (3, 3)
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device)
    # Interpolate linearly in axis-angle space, then map back to a matrix.
    vec_a = matrix_to_axis_angle(rot_mat_1)
    vec_b = matrix_to_axis_angle(rot_mat_2)
    weight = t[..., None]
    blended = (1 - weight) * vec_a + weight * vec_b
    return axis_angle_to_matrix(blended)
1144
+
1145
+
1146
def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor:
    """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation.

    Args:
        ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose
        ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose
        t (torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated camera pose
    """
    # Interpolate camera-to-world poses (inverse extrinsics), then invert back.
    pose1 = torch.inverse(ext1)
    pose2 = torch.inverse(ext2)
    return torch.inverse(interpolate_transform(pose1, pose2, t))
1158
+
1159
+
1160
def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate between two view matrices.

    View matrices share the extrinsics convention, so this delegates directly to
    `interpolate_extrinsics` (linear translation, spherical rotation).

    Args:
        view1 (torch.Tensor): shape (..., 4, 4), the first view matrix
        view2 (torch.Tensor): shape (..., 4, 4), the second view matrix
        t (Number | torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated view matrix
    """
    result = interpolate_extrinsics(view1, view2, t)
    return result
1172
+
1173
+
1174
def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]):
    """Interpolate two rigid 4x4 transforms.

    Translation is interpolated linearly; rotation via `slerp`.

    Args:
        transform1 (torch.Tensor): shape (..., 4, 4), the first transform
        transform2 (torch.Tensor): shape (..., 4, 4), the second transform
        t (Number | torch.Tensor): scalar or shape (...,), the interpolation factor

    Returns:
        torch.Tensor: shape (..., 4, 4), the interpolated transform
    """
    assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4)
    if isinstance(t, Number):
        t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device)
    pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3]
    rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t)
    transform = torch.cat([rot, pos[..., None]], dim=-1)
    # BUGFIX: the original referenced an undefined name `ext` here, which raised
    # NameError on every call; the intended operand is `transform` itself.
    transform = torch.cat([transform, torch.tensor([0, 0, 0, 1], dtype=transform.dtype, device=transform.device).expand_as(transform[..., :1, :])], dim=-2)
    return transform
1183
+
1184
+
1185
def extrinsics_to_essential(extrinsics: torch.Tensor):
    """
    extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0`

    Args:
        extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix

    Returns:
        (torch.Tensor): [..., 3, 3] essential matrix
    """
    assert extrinsics.shape[-2:] == (4, 4)
    R = extrinsics[..., :3, :3]
    t = extrinsics[..., :3, 3]
    # BUGFIX: `zeros` must match a single component of t (shape (...,)), not t
    # itself (shape (..., 3)) — the original torch.stack raised a shape-mismatch
    # error. The components are also stacked along the LAST dim so that batched
    # inputs reshape to (..., 3, 3) correctly (the original stacked along dim 0).
    zeros = torch.zeros_like(t[..., 0])
    t_x = torch.stack([
        zeros, -t[..., 2], t[..., 1],
        t[..., 2], zeros, -t[..., 0],
        -t[..., 1], t[..., 0], zeros
    ], dim=-1).reshape(*t.shape[:-1], 3, 3)
    return R @ t_x
1205
+
1206
+
1207
def to4x4(R: torch.Tensor, t: torch.Tensor):
    """
    Compose a rotation matrix and a translation vector into a 4x4 homogeneous
    transformation matrix `[[R, t], [0, 0, 0, 1]]`.

    Args:
        R (torch.Tensor): [..., 3, 3] rotation matrix
        t (torch.Tensor): [..., 3] translation vector

    Returns:
        (torch.Tensor): [..., 4, 4] transformation matrix
    """
    assert R.shape[-2:] == (3, 3)
    assert t.shape[-1] == 3
    assert R.shape[:-2] == t.shape[:-1]
    # Top 3x4 block: rotation with the translation as its last column.
    top = torch.cat([R, t[..., None]], dim=-1)
    # Constant homogeneous row, broadcast over any batch dims.
    bottom = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 1, 4)
    return torch.cat([top, bottom], dim=-2)
1225
+
1226
+
1227
def rotation_matrix_2d(theta: Union[float, torch.Tensor]):
    """
    2x2 matrix for a 2D rotation by angle `theta`.

    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)

    Returns:
        (torch.Tensor): (..., 2, 2) rotation matrix
    """
    if isinstance(theta, float):
        theta = torch.tensor(theta)
    c, s = torch.cos(theta), torch.sin(theta)
    # Row-major flattened [[c, -s], [s, c]], reshaped into the trailing 2x2.
    flat = torch.stack([c, -s, s, c], dim=-1)
    return flat.unflatten(-1, (2, 2))
1243
+
1244
+
1245
def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    3x3 homogeneous matrix for a 2D rotation around `center`:
    ```
    [[Rxx, Rxy, tx],
     [Ryx, Ryy, ty],
     [0,   0,   1]]
    ```
    Args:
        theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,)
        center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    if isinstance(theta, float):
        theta = torch.tensor(theta)
    if center is not None:
        theta = theta.to(center)
    else:
        center = torch.zeros(2).to(theta).expand(*theta.shape, -1)
    R = rotation_matrix_2d(theta)
    # Translation column that makes the rotation pivot around `center`: c - R c.
    pivot = center[..., :, None]
    top = torch.cat([R, pivot - R @ pivot], dim=-1)
    last_row = torch.tensor([[0, 0, 1]], dtype=center.dtype, device=center.device).expand(*center.shape[:-1], -1, -1)
    return torch.cat([top, last_row], dim=-2)
1274
+
1275
+
1276
def translate_2d(translation: torch.Tensor):
    """
    3x3 homogeneous matrix for a 2D translation:
    ```
    [[1, 0, tx],
     [0, 1, ty],
     [0, 0, 1]]
    ```
    Args:
        translation (torch.Tensor): translation vector, arbitrary shape (..., 2)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    batch_shape = translation.shape[:-1]
    identity = torch.eye(2, dtype=translation.dtype, device=translation.device).expand(*batch_shape, -1, -1)
    top = torch.cat([identity, translation[..., None]], dim=-1)
    last_row = torch.tensor([[0, 0, 1]], dtype=translation.dtype, device=translation.device).expand(*batch_shape, -1, -1)
    return torch.cat([top, last_row], dim=-2)
1297
+
1298
+
1299
def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None):
    """
    3x3 homogeneous matrix for a 2D uniform scale about `center`:
    ```
    [[s, 0, tx],
     [0, s, ty],
     [0, 0, 1]]
    ```
    Args:
        scale (float | torch.Tensor): scale factor, arbitrary shape (...,)
        center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0)

    Returns:
        (torch.Tensor): (..., 3, 3) transformation matrix
    """
    if isinstance(scale, float):
        scale = torch.tensor(scale)
    if center is not None:
        scale = scale.to(center)
    if center is None:
        center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1)
    # BUGFIX: multiply via scale[..., None, None] so a batched scale of shape
    # (...,) broadcasts against the (..., 2, 2) identity; the original
    # `scale * eye` only worked for scalar scale and mis-broadcast otherwise.
    identity = torch.eye(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1, -1)
    return torch.cat([
        torch.cat([
            scale[..., None, None] * identity,
            # Translation that keeps `center` fixed: c - s * c.
            center[..., :, None] - center[..., :, None] * scale[..., None, None],
        ], dim=-1),
        torch.tensor([[0, 0, 1]], dtype=scale.dtype, device=scale.device).expand(*center.shape[:-1], -1, -1),
    ], dim=-2)
1327
+
1328
+
1329
def apply_2d(transform: torch.Tensor, points: torch.Tensor):
    """
    Apply a (3x3 or 2x3) 2D affine transformation to points:
    ```
    p' = R @ p + t
    ```
    Args:
        transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix
        points (torch.Tensor): (..., N, 2) points to transform

    Returns:
        (torch.Tensor): (..., N, 2) transformed points
    """
    assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3"
    assert points.shape[-1] == 2, "points must be 2D"
    # BUGFIX: the translation column must broadcast over the points axis as
    # (..., 1, 2); the original indexed it as transform[..., :2, None, 2]
    # — shape (..., 2, 1) — which errors for N != 2 and silently computes
    # wrong values when N == 2.
    return points @ transform[..., :2, :2].mT + transform[..., None, :2, 2]
moge/model/utils.py ADDED
@@ -0,0 +1,54 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+
12
def wrap_module_with_gradient_checkpointing(module: nn.Module):
    """Patch `module` in place so its forward runs under activation checkpointing.

    The module's class is swapped for a dynamically created subclass whose
    `forward` wraps the original one in `torch.utils.checkpoint.checkpoint`
    (non-reentrant mode). The original class is stored on `_restore_cls` so the
    patch can be undone by `unwrap_module_with_gradient_checkpointing`.
    """
    from torch.utils.checkpoint import checkpoint
    original_cls = module.__class__

    def _checkpointed_forward(self, *args, **kwargs):
        # Recompute activations in backward instead of storing them.
        return checkpoint(original_cls.forward, self, *args, use_reentrant=False, **kwargs)

    wrapper_cls = type(
        '_CheckpointingWrapper',
        (original_cls,),
        {'_restore_cls': original_cls, 'forward': _checkpointed_forward},
    )
    module.__class__ = wrapper_cls
    return module
21
+
22
+
23
def unwrap_module_with_gradient_checkpointing(module: nn.Module):
    """Undo `wrap_module_with_gradient_checkpointing`, restoring the module's
    original class.

    Returns the module, for symmetry with the wrapping function (the original
    returned None; returning the module is backward compatible).
    """
    module.__class__ = module.__class__._restore_cls
    return module
25
+
26
+
27
def wrap_dinov2_attention_with_sdpa(module: nn.Module):
    """Patch a DINOv2 attention module in place to use PyTorch-native
    scaled_dot_product_attention (SDPA) in its forward pass.

    Expects `module` to expose `qkv`, `num_heads`, `proj` and `proj_drop`
    (the DINOv2 attention interface).
    """
    # NOTE(review): this is a lexicographic string comparison; fine for 2.x
    # but it would misjudge a hypothetical major version >= 10.
    assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"

    class _AttentionWrapper(module.__class__):
        def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
            batch, seq_len, channels = x.shape
            head_dim = channels // self.num_heads
            # (B, N, 3, H, C//H) -> (3, B, H, N, C//H)
            qkv = self.qkv(x).reshape(batch, seq_len, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
            q, k, v = qkv.unbind(0)  # each (B, H, N, C//H)

            attn_out = F.scaled_dot_product_attention(q, k, v, attn_bias)
            attn_out = attn_out.permute(0, 2, 1, 3).reshape(batch, seq_len, channels)

            return self.proj_drop(self.proj(attn_out))

    module.__class__ = _AttentionWrapper
    return module
44
+
45
+
46
def sync_ddp_hook(state, bucket: torch.distributed.GradBucket) -> torch.futures.Future[torch.Tensor]:
    """DDP communication hook: average gradients across the default process group.

    Divides the bucket's gradient in place by the world size, all-reduces it
    (sum), and returns an already-completed future holding the result.
    """
    group = torch.distributed.group.WORLD
    grad = bucket.buffer()
    # Pre-divide so the summed all-reduce yields the mean.
    grad.div_(group.size())
    torch.distributed.all_reduce(grad, group=group)
    future = torch.futures.Future()
    future.set_result(grad)
    return future
moge/model/v2.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file is modified from MoGe:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+ # Modifications Copyright (c) 2026 Ze-Xin Yin, Robot labs of Horizon Robotics, and D-Robotics.
6
+
7
+
8
+ from typing import *
9
+ from numbers import Number
10
+ from functools import partial
11
+ from pathlib import Path
12
+ import warnings
13
+
14
+ import torch
15
+ import torch.nn as nn
16
+ import torch.nn.functional as F
17
+ import torch.utils
18
+ import torch.utils.checkpoint
19
+ import torch.amp
20
+ import torch.version
21
+ import utils3d
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, angle_diff_vec3
25
+ from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
26
+ from .modules import DINOv2Encoder, MLP, ConvStack
27
+ from . import transforms
28
+
29
+ from einops import rearrange
30
+
31
def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor:
    """
    Get image space UV grid, ranging in [0, 1]. Each entry is the normalized
    pixel-center coordinate.

    >>> image_uv(10, 10):
    [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]],
     [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]],
      ...      ...      ...
     [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]]

    Args:
        height (int): image height
        width (int): image width
        left, top, right, bottom (int, optional): crop window; defaults to the full image

    Returns:
        torch.Tensor: shape (height, width, 2) (or the crop-window size if given)
    """
    left = 0 if left is None else left
    top = 0 if top is None else top
    right = width if right is None else right
    bottom = height if bottom is None else bottom
    # Pixel centers: (index + 0.5) / size along each axis.
    u = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype)
    v = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype)
    grid_u, grid_v = torch.meshgrid(u, v, indexing='xy')
    return torch.stack([grid_u, grid_v], dim=-1)
57
+
58
def depth_to_points(depth: torch.Tensor, intrinsics: torch.Tensor, extrinsics: torch.Tensor = None):
    """Unproject a depth map to a 3D point map using the pinhole camera model.

    Args:
        depth (torch.Tensor): (..., H, W) depth map
        intrinsics (torch.Tensor): (..., 3, 3) normalized camera intrinsics
        extrinsics (torch.Tensor, optional): (..., 4, 4) camera extrinsics

    Returns:
        torch.Tensor: per-pixel 3D points (as produced by `transforms.unproject_cv`)
    """
    height, width = depth.shape[-2:]
    uv = image_uv(width=width, height=height, dtype=depth.dtype, device=depth.device)
    # Broadcast the camera matrices over the pixel dimension.
    ext = extrinsics[..., None, :, :] if extrinsics is not None else None
    return transforms.unproject_cv(uv, depth, intrinsics=intrinsics[..., None, :, :], extrinsics=ext)
63
+
64
class MoGeModel(nn.Module):
    """Monocular geometry model (MoGe v2).

    A DINOv2 encoder feeds a shared convolutional neck whose multi-scale
    features are decoded by optional heads: an affine point map, a validity
    mask, surface normals, and a global metric scale predicted from the
    class token.
    """
    encoder: DINOv2Encoder
    neck: ConvStack
    points_head: ConvStack
    mask_head: ConvStack
    scale_head: MLP

    def __init__(self,
        encoder: Dict[str, Any],
        neck: Dict[str, Any],
        points_head: Dict[str, Any] = None,
        mask_head: Dict[str, Any] = None,
        normal_head: Dict[str, Any] = None,
        scale_head: Dict[str, Any] = None,
        remap_output: Literal['linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
        num_tokens_range: List[int] = [1200, 3600],
        **deprecated_kwargs
    ):
        """Build the model from per-submodule config dicts.

        Each `*_head` argument is a kwargs dict for the corresponding head
        constructor; when None, that head (and its output) is omitted.
        `remap_output` selects the nonlinearity applied to the raw point map.
        `num_tokens_range` is the [min, max] number of base ViT tokens used by
        `infer` when `num_tokens` is not given. Unknown kwargs are accepted
        for checkpoint compatibility and ignored with a warning.
        """
        super(MoGeModel, self).__init__()
        if deprecated_kwargs:
            warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")

        self.remap_output = remap_output
        self.num_tokens_range = num_tokens_range

        self.encoder = DINOv2Encoder(**encoder)
        self.neck = ConvStack(**neck)
        # Heads are optional: attributes only exist when configured, and the
        # rest of the class gates on hasattr(self, <head>) accordingly.
        if points_head is not None:
            self.points_head = ConvStack(**points_head)
        if mask_head is not None:
            self.mask_head = ConvStack(**mask_head)
        if normal_head is not None:
            self.normal_head = ConvStack(**normal_head)
        if scale_head is not None:
            self.scale_head = MLP(**scale_head)

    @property
    def device(self) -> torch.device:
        # Device of the first parameter; assumes all parameters are co-located.
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        # Dtype of the first parameter; assumes a homogeneous parameter dtype.
        return next(self.parameters()).dtype

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
        """
        Load a model from a checkpoint file.

        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.

        ### Returns:
        - A new instance of `MoGe` with the parameters loaded from the checkpoint.
        """
        # Prefer a local file; otherwise treat the argument as a HF Hub repo id.
        if Path(pretrained_model_name_or_path).exists():
            checkpoint_path = pretrained_model_name_or_path
        else:
            checkpoint_path = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs
            )
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=True)

        model_config = checkpoint['model_config']
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        # strict=False tolerates missing/extra keys across checkpoint versions.
        model.load_state_dict(checkpoint['model'], strict=False)

        return model

    def init_weights(self):
        """Initialize the encoder's weights (heads keep their constructor init)."""
        self.encoder.init_weights()

    def enable_gradient_checkpointing(self):
        """Enable activation checkpointing on the encoder, neck and all present heads."""
        self.encoder.enable_gradient_checkpointing()
        self.neck.enable_gradient_checkpointing()
        for head in ['points_head', 'normal_head', 'mask_head']:
            if hasattr(self, head):
                getattr(self, head).enable_gradient_checkpointing()

    def enable_pytorch_native_sdpa(self):
        """Switch the encoder's attention to PyTorch-native SDPA."""
        self.encoder.enable_pytorch_native_sdpa()

    def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
        """Apply the configured output nonlinearity to the raw (..., 3) point map.

        'exp' and 'sinh_exp' make the z component strictly positive.
        """
        if self.remap_output == 'linear':
            pass
        elif self.remap_output =='sinh':
            points = torch.sinh(points)
        elif self.remap_output == 'exp':
            xy, z = points.split([2, 1], dim=-1)
            z = torch.exp(z)
            points = torch.cat([xy * z, z], dim=-1)
        elif self.remap_output =='sinh_exp':
            xy, z = points.split([2, 1], dim=-1)
            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
        else:
            raise ValueError(f"Invalid remap output type: {self.remap_output}")
        return points

    @torch.inference_mode()
    def infer_feature_tokens(self, image: torch.Tensor, num_tokens: int, tokens_layer: int = -1) -> torch.Tensor:
        """Run encoder + neck only and return one feature level (no heads).

        `num_tokens` is rounded down to a base_h x base_w token grid matching
        the image aspect ratio; `tokens_layer` indexes the neck's output levels.
        """
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        aspect_ratio = img_w / img_h
        base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
        num_tokens = base_h * base_w

        # Backbones encoding
        features = self.encoder(image, base_h, base_w, return_class_token=False)
        features = [features, None, None, None, None]

        # Concat UVs for aspect ratio input
        for level in range(5):
            uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
            if features[level] is None:
                features[level] = uv
            else:
                features[level] = torch.concat([features[level], uv], dim=1)

        # Shared neck
        features = self.neck(features)[tokens_layer]
        return features

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        """Raw forward pass: encoder -> neck -> configured heads.

        Returns a dict with whichever of 'points' (B, H, W, 3), 'normal'
        (B, H, W, 3, unit-norm), 'mask' (B, H, W, sigmoid probabilities) and
        'metric_scale' (B, positive) are configured.
        """
        batch_size, _, img_h, img_w = image.shape
        device, dtype = image.device, image.dtype

        aspect_ratio = img_w / img_h
        base_h, base_w = int((num_tokens / aspect_ratio) ** 0.5), int((num_tokens * aspect_ratio) ** 0.5)
        num_tokens = base_h * base_w

        # Backbones encoding
        features, cls_token = self.encoder(image, base_h, base_w, return_class_token=True)
        features = [features, None, None, None, None]

        # Concat UVs for aspect ratio input
        for level in range(5):
            uv = normalized_view_plane_uv(width=base_w * 2 ** level, height=base_h * 2 ** level, aspect_ratio=aspect_ratio, dtype=dtype, device=device)
            uv = uv.permute(2, 0, 1).unsqueeze(0).expand(batch_size, -1, -1, -1)
            if features[level] is None:
                features[level] = uv
            else:
                features[level] = torch.concat([features[level], uv], dim=1)

        # Shared neck
        features = self.neck(features)

        # Heads decoding (each head consumes the neck pyramid; take its finest output)
        points, normal, mask = (getattr(self, head)(features)[-1] if hasattr(self, head) else None for head in ['points_head', 'normal_head', 'mask_head'])
        metric_scale = self.scale_head(cls_token) if hasattr(self, 'scale_head') else None

        # Resize head outputs back to the input resolution
        points, normal, mask = (F.interpolate(v, (img_h, img_w), mode='bilinear', align_corners=False, antialias=False) if v is not None else None for v in [points, normal, mask])

        # Remap output
        if points is not None:
            points = points.permute(0, 2, 3, 1)
            points = self._remap_points(points)     # slightly improves the performance in case of very large output values
        if normal is not None:
            normal = normal.permute(0, 2, 3, 1)
            normal = F.normalize(normal, dim=-1)
        if mask is not None:
            mask = mask.squeeze(1).sigmoid()
        if metric_scale is not None:
            # exp keeps the predicted scale strictly positive.
            metric_scale = metric_scale.squeeze(1).exp()

        return_dict = {
            'points': points,
            'normal': normal,
            'mask': mask,
            'metric_scale': metric_scale
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        return return_dict

    @torch.inference_mode()
    def infer(
        self,
        image: torch.Tensor,
        num_tokens: int = None,
        resolution_level: int = 9,
        force_projection: bool = True,
        apply_mask: Literal[False, True, 'blend'] = True,
        fov_x: Optional[Union[Number, torch.Tensor]] = None,
        use_fp16: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        User-friendly inference function

        ### Parameters
        - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
        - `num_tokens`: the number of base ViT tokens to use for inference. Suggested range: 1200 ~ 2500.
            More tokens will result in significantly higher accuracy and finer details, but slower inference time.
            When None, it is derived from `resolution_level` (0-9) within `self.num_tokens_range`.
        - `force_projection`: if True, the output point map will be computed using the actual depth map. Default: True
        - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
        - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
        - `use_fp16`: if True, use mixed precision to speed up inference. Default: True

        ### Returns

        A dictionary containing the following keys:
        - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
        - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
        - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
        """
        # Accept both batched and single images; remember to squeeze at the end.
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        area = original_height * original_width
        aspect_ratio = original_width / original_height

        # Determine the number of base tokens to use
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        # Forward pass
        with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
            output = self.forward(image, num_tokens=num_tokens)
        points, normal, mask, metric_scale = (output.get(k, None) for k in ['points', 'normal', 'mask', 'metric_scale'])

        # Always process the output in fp32 precision
        points, normal, mask, metric_scale, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, normal, mask, metric_scale, fov_x])
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            if mask is not None:
                mask_binary = mask > 0.5
            else:
                mask_binary = None

            if points is not None:
                # Convert affine point map to camera-space. Recover depth and intrinsics from point map.
                # NOTE: Focal here is the focal length relative to half the image diagonal
                if fov_x is None:
                    # Recover focal and shift from predicted point map
                    focal, shift = recover_focal_shift(points, mask_binary)
                else:
                    # Focal is known, recover shift only
                    focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
                    if focal.ndim == 0:
                        focal = focal[None].expand(points.shape[0])
                    _, shift = recover_focal_shift(points, mask_binary, focal=focal)
                fx, fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio, focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
                intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
                points[..., 2] += shift[..., None, None]
                if mask_binary is not None:
                    mask_binary &= points[..., 2] > 0        # in case depth contains negative values (which should never happen in practice)
                depth = points[..., 2].clone()
            else:
                depth, intrinsics = None, None

            # If projection constraint is forced, recompute the point map using the actual depth map & intrinsics
            if force_projection and depth is not None:
                points = depth_to_points(depth, intrinsics=intrinsics)

            # Apply metric scale
            if metric_scale is not None:
                if points is not None:
                    points *= metric_scale[:, None, None, None]
                if depth is not None:
                    depth *= metric_scale[:, None, None]

            # Apply mask (invalid pixels become inf points/depth and zero normals)
            if apply_mask and mask_binary is not None:
                points = torch.where(mask_binary[..., None], points, torch.inf) if points is not None else None
                depth = torch.where(mask_binary, depth, torch.inf) if depth is not None else None
                normal = torch.where(mask_binary[..., None], normal, torch.zeros_like(normal)) if normal is not None else None

        return_dict = {
            'points': points,
            'intrinsics': intrinsics,
            'depth': depth,
            'mask': mask_binary,
            'normal': normal,
            'metric_scale': metric_scale
        }
        return_dict = {k: v for k, v in return_dict.items() if v is not None}

        if omit_batch_dim:
            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}

        return return_dict
moge/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
moge/utils/download.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from pathlib import Path
7
+ from typing import *
8
+ import requests
9
+
10
+ from tqdm import tqdm
11
+
12
+
13
+ __all__ = ["download_file", "download_bytes"]
14
+
15
+
16
def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
    """Download `url` to `filepath`, optionally resuming a partial download.

    Args:
        url: source URL.
        filepath: destination path.
        headers: extra HTTP headers to send (a `Range` header is added when resuming).
        resume: when True and the file already exists, request only the remaining bytes.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
    """
    # Copy so we never mutate the caller's dict when adding the Range header.
    headers = dict(headers) if headers else {}

    # Initialize local variables
    file_path = Path(filepath)
    downloaded_bytes = 0

    # Check if we should resume the download
    if resume and file_path.exists():
        downloaded_bytes = file_path.stat().st_size
        headers['Range'] = f"bytes={downloaded_bytes}-"

    # Make a GET request to fetch the file
    with requests.get(url, stream=True, headers=headers) as response:
        response.raise_for_status()  # This will raise an HTTPError if the status is 4xx/5xx

        # BUGFIX: if we requested a byte range but the server ignored it
        # (status 200 instead of 206 Partial Content), the response body is the
        # whole file; appending it to the partial file would corrupt the
        # download. Restart from scratch in that case.
        if downloaded_bytes and response.status_code != 206:
            downloaded_bytes = 0
            mode = 'wb'
        else:
            mode = 'ab'

        # Calculate the total size to download
        total_size = downloaded_bytes + int(response.headers.get('content-length', 0))

        # Display a progress bar while downloading
        with (
            tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
            open(file_path, mode) as file,
        ):
            # Set the initial position of the progress bar
            pbar.update(downloaded_bytes)

            # Write the content to the file in chunks
            for chunk in response.iter_content(chunk_size=4096):
                file.write(chunk)
                pbar.update(len(chunk))
48
+
49
+
50
def download_bytes(url: str, headers: dict = None) -> bytes:
    """Fetch `url` and return the raw response body as bytes.

    Args:
        url: source URL.
        headers: optional extra HTTP headers.

    Raises:
        requests.HTTPError: if the server responds with a 4xx/5xx status.
    """
    with requests.get(url, stream=True, headers=headers or {}) as response:
        # Raise on any 4xx/5xx status before touching the body.
        response.raise_for_status()
        return response.content
60
+
moge/utils/geometry_numpy.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+ from functools import partial
8
+ import math
9
+
10
+ import cv2
11
+ import numpy as np
12
+ from scipy.signal import fftconvolve
13
+ import numpy as np
14
+ import utils3d
15
+
16
+ from .tools import timeit
17
+
18
+
19
+ def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
20
+ if w is None:
21
+ return np.mean(x, axis=axis)
22
+ else:
23
+ w = w.astype(x.dtype)
24
+ return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
25
+
26
+
27
+ def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
28
+ if w is None:
29
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
30
+ else:
31
+ w = w.astype(x.dtype)
32
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
33
+
34
+
35
+ def normalized_view_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
36
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
37
+ if aspect_ratio is None:
38
+ aspect_ratio = width / height
39
+
40
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
41
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
42
+
43
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
44
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
45
+ u, v = np.meshgrid(u, v, indexing='xy')
46
+ uv = np.stack([u, v], axis=-1)
47
+ return uv
48
+
49
+
50
def focal_to_fov_numpy(focal: np.ndarray):
    """Convert a normalized focal length to a field of view in radians."""
    half_fov = np.arctan(0.5 / focal)
    return 2 * half_fov
52
+
53
+
54
def fov_to_focal_numpy(fov: np.ndarray):
    """Convert a field of view in radians to a normalized focal length."""
    return 0.5 / np.tan(0.5 * fov)
56
+
57
+
58
def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Extract horizontal and vertical FoV (radians) from normalized intrinsics.

    Args:
        intrinsics (np.ndarray): (..., 3, 3) normalized camera intrinsics.

    Returns:
        Tuple[np.ndarray, np.ndarray]: (fov_x, fov_y) in radians.
    """
    fx = intrinsics[..., 0, 0]
    fy = intrinsics[..., 1, 1]
    return focal_to_fov_numpy(fx), focal_to_fov_numpy(fy)
62
+
63
+
64
def point_map_to_depth_legacy_numpy(points: np.ndarray):
    """Legacy: recover depth, FoV and z-shift from an affine point map.

    Solves a linear least-squares problem for (focal, shift) such that
    `focal * xy ~= uv * (z + shift)` over all pixels, where `uv` is the
    normalized view-plane grid.

    Args:
        points (np.ndarray): (..., H, W, 3) affine point map.

    Returns:
        tuple: (depth (..., H, W), fov_x, fov_y, shift).

    NOTE(review): `focal, shift = solution` unpacks along axis 0, which is only
    correct for unbatched input (solution shape (2,)) — verify before using
    with leading batch dims.
    """
    height, width = points.shape[-3:-1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    uv = normalized_view_plane_uv_numpy(width, height, dtype=points.dtype)  # (H, W, 2)
    _, uv = np.broadcast_arrays(points[..., :2], uv)

    # Solve least squares problem for the two unknowns (focal, shift)
    b = (uv * points[..., 2:]).reshape(*points.shape[:-3], -1)  # (..., H * W * 2)
    A = np.stack([points[..., :2], -uv], axis=-1).reshape(*points.shape[:-3], -1, 2)  # (..., H * W * 2, 2)

    # Normal equations with a small Tikhonov term for numerical stability.
    M = A.swapaxes(-2, -1) @ A
    solution = (np.linalg.inv(M + 1e-6 * np.eye(2)) @ (A.swapaxes(-2, -1) @ b[..., None])).squeeze(-1)
    focal, shift = solution

    depth = points[..., 2] + shift[..., None, None]
    fov_x = np.arctan(width / diagonal / focal) * 2
    fov_y = np.arctan(height / diagonal / focal) * 2
    return depth, fov_x, fov_y, shift
82
+
83
+
84
def solve_optimal_focal_shift(uv: np.ndarray, xyz: np.ndarray):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
    from scipy.optimize import least_squares
    uv = uv.reshape(-1, 2)
    xy = xyz[..., :2].reshape(-1, 2)
    z = xyz[..., 2].reshape(-1)

    def residuals(shift: np.ndarray):
        # For a candidate shift, the optimal focal has a closed form:
        # f = <xy_proj, uv> / <xy_proj, xy_proj>.
        xy_proj = xy / (z + shift)[:, None]
        f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
        return (f * xy_proj - uv).ravel()

    solution = least_squares(residuals, x0=0, ftol=1e-3, method='lm')
    optim_shift = solution['x'].squeeze().astype(np.float32)

    # Recompute the closed-form focal at the optimal shift.
    xy_proj = xy / (z + optim_shift)[:, None]
    optim_focal = (xy_proj * uv).sum() / np.square(xy_proj).sum()

    return optim_shift, optim_focal
102
+
103
+
104
def solve_optimal_shift(uv: np.ndarray, xyz: np.ndarray, focal: float):
    "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift"
    from scipy.optimize import least_squares
    uv_flat = uv.reshape(-1, 2)
    xy_flat = xyz[..., :2].reshape(-1, 2)
    z_flat = xyz[..., 2].reshape(-1)

    def residuals(shift: np.ndarray) -> np.ndarray:
        # Projection residual at the (known) focal for a candidate shift.
        projected = xy_flat / (z_flat + shift)[:, None]
        return (focal * projected - uv_flat).ravel()

    result = least_squares(residuals, x0=0, ftol=1e-3, method='lm')
    return result['x'].squeeze().astype(np.float32)
118
+
119
+
120
def recover_focal_shift_numpy(points: np.ndarray, mask: np.ndarray = None, focal: float = None, downsample_size: Tuple[int, int] = (64, 64)):
    """
    Recover the focal length and Z-axis shift of a point map.

    ### Parameters
    - `points`: point map of shape (H, W, 3).
    - `mask`: optional boolean validity mask of shape (H, W).
    - `focal`: if given, only the shift is solved for.
    - `downsample_size`: size of the downsampled map used by the solver
      (downsampling gives an approximate but fast solution).

    ### Returns
    - `focal`: estimated focal length, relative to the half diagonal of the map.
    - `shift`: Z-axis shift to translate the point map to camera space.
    """
    import cv2
    assert points.shape[-1] == 3, "Points should (H, W, 3)"

    height, width = points.shape[-3], points.shape[-2]

    uv = normalized_view_plane_uv_numpy(width=width, height=height)

    if mask is None:
        points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
        uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
    else:
        (points_lr, uv_lr), mask_lr = mask_aware_nearest_resize_numpy((points, uv), mask, downsample_size)
        # BUGFIX: keep only valid pixels. Previously the full low-res maps
        # (including junk values at invalid pixels) were fed to the solver;
        # the torch counterpart filters by the mask.
        points_lr, uv_lr = points_lr[mask_lr], uv_lr[mask_lr]

    if points_lr.size < 2:
        # Degenerate input: fall back to identity focal and zero shift.
        return 1., 0.

    if focal is None:
        # BUGFIX: solve_optimal_focal_shift returns (shift, focal); the
        # previous `focal, shift = ...` unpacking was swapped.
        shift, focal = solve_optimal_focal_shift(uv_lr, points_lr)
    else:
        shift = solve_optimal_shift(uv_lr, points_lr, focal)

    return focal, shift
144
+
145
+
146
def mask_aware_nearest_resize_numpy(
    inputs: Union[np.ndarray, Tuple[np.ndarray, ...], None],
    mask: np.ndarray,
    size: Tuple[int, int],
    return_index: bool = False
) -> Tuple[Union[np.ndarray, Tuple[np.ndarray, ...], None], np.ndarray, Tuple[np.ndarray, ...]]:
    """
    Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.

    Only pixels marked valid in `mask` are considered as nearest-neighbor
    candidates, so invalid pixels never bleed into the resized output.

    ### Parameters
    - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
    - `mask`: input 2D mask of shape (..., H, W)
    - `size`: target size (width, height)

    ### Returns
    - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
    - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
    - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension.
    """
    height, width = mask.shape[-2:]
    target_width, target_height = size
    # Fractional and integer footprint of one target pixel in source pixels.
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv.
    # Zero-pad so every target window lies fully inside the padded arrays.
    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))

    # Gather the target pixels's local window
    target_centers = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
    target_lefttop = target_centers - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_window = np.round(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)

    target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(*([-1] * (mask.ndim - 2)), target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)

    # Compute nearest neighbor in the local window for each pixel
    dist = np.square(target_window_centers - target_centers[..., None])
    dist = dist[..., 0, :] + dist[..., 1, :]
    # Invalid candidates get infinite distance so argmin never selects them
    # (if a whole window is invalid, target_mask marks that pixel invalid).
    dist = np.where(target_window_mask, dist, np.inf)  # (..., target_height, tgt_width, filter_size)
    nearest_in_window = np.argmin(dist, axis=-1, keepdims=True)  # (..., target_height, tgt_width, 1)
    nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1)  # (..., target_height, tgt_width)
    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
    target_mask = np.any(target_window_mask, axis=-1)
    batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]

    index = (*batch_indices, nearest_i, nearest_j)

    if inputs is None:
        outputs = None
    elif isinstance(inputs, np.ndarray):
        outputs = inputs[index]
    elif isinstance(inputs, Sequence):
        outputs = tuple(x[index] for x in inputs)
    else:
        raise ValueError(f'Invalid input type: {type(inputs)}')

    if return_index:
        return outputs, target_mask, index
    else:
        return outputs, target_mask
219
+
220
+
221
def mask_aware_area_resize_numpy(image: np.ndarray, mask: np.ndarray, target_width: int, target_height: int) -> Tuple[Tuple[np.ndarray, ...], np.ndarray]:
    """
    Downsample a 2D map by mask-aware area (box) averaging: each target pixel
    is the area-weighted mean of the valid source pixels overlapping its footprint.

    ### Parameters
    - `image`: Input 2D image of shape (..., H, W, C), or (..., H, W) without a channel dim.
    - `mask`: Input 2D mask of shape (..., H, W)
    - `target_width`: target width of the resized map
    - `target_height`: target height of the resized map

    ### Returns
    - `target_image`: area-averaged image of shape (..., target_height, target_width[, C]).
    - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
    """
    height, width = mask.shape[-2:]

    # NOTE(review): this heuristic treats the image as channel-less when its
    # last two dims equal (H, W); it can misfire if W coincidentally equals C.
    if image.shape[-2:] == (height, width):
        omit_channel_dim = True
    else:
        omit_channel_dim = False
    if omit_channel_dim:
        image = image[..., None]

    # Zero out invalid pixels so they contribute nothing to the weighted sum.
    image = np.where(mask[..., None], image, 0)

    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f) + 1, math.ceil(filter_w_f) + 1
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv (non-copy)
    uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
    indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
    padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
    windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
    windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))

    # Gather the target pixels's local window
    target_center = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
    target_lefttop = target_center - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_bottomright = target_center + np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
    target_window = np.floor(target_lefttop).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)

    target_window_centers = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)

    # Compute pixel area in the local windows:
    # overlap of each source pixel (a 1x1 box around its center) with the
    # target pixel's footprint, clipped to zero when they do not intersect.
    target_window_lefttop = np.maximum(target_window_centers - 0.5, target_lefttop[..., None])
    target_window_bottomright = np.minimum(target_window_centers + 0.5, target_bottomright[..., None])
    target_window_area = (target_window_bottomright - target_window_lefttop).clip(0, None)
    target_window_area = np.where(target_window_mask, target_window_area[..., 0, :] * target_window_area[..., 1, :], 0)

    # Weighted sum by area
    target_window_image = image.reshape(*image.shape[:-3], height * width, -1)[..., target_window_indices, :].swapaxes(-2, -1)
    # A target pixel is valid only if enough valid source area covers it.
    target_mask = np.sum(target_window_area, axis=-1) >= 0.25
    target_image = weighted_mean_numpy(target_window_image, target_window_area[..., None, :], axis=-1)

    if omit_channel_dim:
        target_image = target_image[..., 0]

    return target_image, target_mask
289
+
290
+
291
def norm3d(x: np.ndarray) -> np.ndarray:
    "Faster `np.linalg.norm(x, axis=-1)` for 3D vectors"
    xs, ys, zs = x[..., 0], x[..., 1], x[..., 2]
    return np.sqrt(np.square(xs) + np.square(ys) + np.square(zs))
294
+
295
+
296
def depth_occlusion_edge_numpy(depth: np.ndarray, mask: np.ndarray, thickness: int = 1, tol: float = 0.1):
    """
    Detect occlusion edges in a depth map by comparing each pixel's disparity
    against the masked mean disparity of its local window.

    ### Parameters
    - `depth`: (H, W) depth map.
    - `mask`: (H, W) boolean validity mask.
    - `thickness`: edge half-width; also sets the local window radius.
    - `tol`: relative disparity deviation that counts as an edge.

    ### Returns
    - `edge_mask`: boolean map of pixels near a foreground/background depth discontinuity.
    """
    # Work in disparity (1/depth); invalid pixels contribute 0.
    disp = np.where(mask, 1 / depth, 0)
    disp_pad = np.pad(disp, (thickness, thickness), constant_values=0)
    mask_pad = np.pad(mask, (thickness, thickness), constant_values=False)
    kernel_size = 2 * thickness + 1
    disp_window = utils3d.numpy.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, axis=(-2, -1))  # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.numpy.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, axis=(-2, -1))  # [..., H, W, kernel_size ** 2]

    # Pixels noticeably nearer than their neighborhood are foreground edges;
    # noticeably farther ones are background edges.
    disp_mean = weighted_mean_numpy(disp_window, mask_window, axis=(-2, -1))
    fg_edge_mask = mask & (disp > (1 + tol) * disp_mean)
    bg_edge_mask = mask & (disp_mean > (1 + tol) * disp)

    # Keep only locations where dilated fg and bg edges overlap, i.e. true
    # occlusion boundaries with both sides present.
    edge_mask = (cv2.dilate(fg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0) \
        & (cv2.dilate(bg_edge_mask.astype(np.uint8), np.ones((3, 3), dtype=np.uint8), iterations=thickness) > 0)

    return edge_mask
312
+
313
+
314
def disk_kernel(radius: int) -> np.ndarray:
    """
    Build a normalized disk-shaped averaging kernel.

    Args:
        radius (int): Radius of the disk (in pixels).

    Returns:
        np.ndarray: (2*radius+1, 2*radius+1) convolution kernel summing to 1.
    """
    coords = np.arange(-radius, radius + 1)
    xs, ys = np.meshgrid(coords, coords)
    # Mark every cell whose center lies inside the circle of the given radius.
    kernel = ((xs ** 2 + ys ** 2) <= radius ** 2).astype(np.float32)
    kernel /= np.sum(kernel)
    return kernel
332
+
333
+
334
def disk_blur(image: np.ndarray, radius: int) -> np.ndarray:
    """
    Blur an image with a disk kernel using FFT convolution.

    Args:
        image (np.ndarray): 2D (grayscale) or 3D (multi-channel) image.
        radius (int): Blur radius in pixels; 0 returns the input unchanged.

    Returns:
        np.ndarray: Blurred image of the same shape.
    """
    if radius == 0:
        return image
    kernel = disk_kernel(radius)
    if image.ndim == 2:
        return fftconvolve(image, kernel, mode='same')
    if image.ndim == 3:
        # Convolve each channel independently and restack along the last axis.
        blurred_channels = [
            fftconvolve(image[..., c], kernel, mode='same')
            for c in range(image.shape[2])
        ]
        return np.stack(blurred_channels, axis=-1)
    raise ValueError("Image must be 2D or 3D.")
359
+
360
+
361
def depth_of_field(
    img: np.ndarray,
    disp: np.ndarray,
    focus_disp : float,
    max_blur_radius : int = 10,
) -> np.ndarray:
    """
    Apply a depth-of-field effect to an image, driven by a disparity map.

    Args:
        img (numpy.ndarray): (H, W, 3) input image.
        disp (numpy.ndarray): (H, W) disparity map of the scene.
        focus_disp (float): Disparity value that stays in focus.
        max_blur_radius (int): Maximum blur radius (in pixels).

    Returns:
        numpy.ndarray: (H, W, 3) output image with depth of field effect applied.
    """
    # Precalculate dilated disparity maps, one per candidate blur radius.
    max_disp = np.max(disp)
    disp = disp / max_disp
    focus_disp = focus_disp / max_disp
    dilated_disp = []
    for radius in range(max_blur_radius + 1):
        dilated_disp.append(cv2.dilate(disp, cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*radius+1, 2*radius+1)), iterations=1))

    # Determine the blur radius for each pixel based on the depth map.
    # Nearer (larger-disparity) neighbors are allowed to enlarge a pixel's
    # blur so foreground bokeh spills over edges instead of stopping at them.
    blur_radii = np.clip(abs(disp - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
    for radius in range(max_blur_radius + 1):
        dialted_blur_radii = np.clip(abs(dilated_disp[radius] - focus_disp) * max_blur_radius, 0, max_blur_radius).astype(np.int32)
        mask = (dialted_blur_radii >= radius) & (dialted_blur_radii >= blur_radii) & (dilated_disp[radius] > disp)
        blur_radii[mask] = dialted_blur_radii[mask]
    blur_radii = np.clip(blur_radii, 0, max_blur_radius)
    # Smooth the radius field to avoid visible seams between blur levels.
    blur_radii = cv2.blur(blur_radii, (5, 5))

    # Precalculate the blurred image for each blur radius actually used.
    unique_radii = np.unique(blur_radii)
    precomputed = {}
    for radius in range(max_blur_radius + 1):
        if radius not in unique_radii:
            continue
        precomputed[radius] = disk_blur(img, radius)

    # Composite: each pixel takes its value from the matching blur level.
    output = np.zeros_like(img)
    for r in unique_radii:
        mask = blur_radii == r
        output[mask] = precomputed[r][mask]

    return output
moge/utils/geometry_torch.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+ import math
8
+ from collections import namedtuple
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torch.nn as nn
13
+ import torch.nn.functional as F
14
+ import torch.types
15
+ import utils3d
16
+
17
+ from .tools import timeit
18
+ from .geometry_numpy import solve_optimal_focal_shift, solve_optimal_shift
19
+
20
+
21
def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    """Mean of `x` weighted by `w` along `dim`; plain mean when `w` is None."""
    if w is None:
        return x.mean(dim=dim, keepdim=keepdim)
    w = w.to(x.dtype)
    weighted_sum = (x * w).mean(dim=dim, keepdim=keepdim)
    normalizer = w.mean(dim=dim, keepdim=keepdim).add(eps)
    return weighted_sum / normalizer
27
+
28
+
29
def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    """(Weighted) harmonic mean of `x` along `dim`, stabilized by `eps`."""
    reciprocals = x.add(eps).reciprocal()
    if w is None:
        return reciprocals.mean(dim=dim, keepdim=keepdim).reciprocal()
    w = w.to(x.dtype)
    return weighted_mean(reciprocals, w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
35
+
36
+
37
def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
    """
    (Weighted) geometric mean of `x` along `dim`.

    ### Parameters
    - `x`: input tensor (stabilized by adding `eps` before the log).
    - `w`: optional weights, broadcastable to `x`.
    - `dim` / `keepdim`: reduction dimension(s) and whether to keep them.
    - `eps`: numerical stabilizer.
    """
    if w is None:
        # BUGFIX: `keepdim` was previously ignored on the unweighted branch
        # (the sibling weighted_mean/harmonic_mean honor it).
        return x.add(eps).log().mean(dim=dim, keepdim=keepdim).exp()
    else:
        w = w.to(x.dtype)
        return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
43
+
44
+
45
def normalized_view_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
    "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
    if aspect_ratio is None:
        aspect_ratio = width / height

    # Half-extents of the view plane, normalized so the half-diagonal is 1.
    span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
    span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5

    # Pixel-center sample positions along each axis.
    u_edge = span_x * (width - 1) / width
    v_edge = span_y * (height - 1) / height
    u = torch.linspace(-u_edge, u_edge, width, dtype=dtype, device=device)
    v = torch.linspace(-v_edge, v_edge, height, dtype=dtype, device=device)
    grid_u, grid_v = torch.meshgrid(u, v, indexing='xy')
    return torch.stack([grid_u, grid_v], dim=-1)
58
+
59
+
60
def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
    """
    Depthwise Gaussian blur for NCHW tensors with replicate padding.

    ### Parameters
    - `input`: tensor of shape (N, C, H, W).
    - `kernel_size`: side length of the square Gaussian kernel (odd sizes preserve H, W).
    - `sigma`: standard deviation of the Gaussian.

    ### Returns
    - Blurred tensor with the same channel count as `input`.
    """
    kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
    kernel = kernel / kernel.sum()
    # Separable outer product -> 2D kernel.
    kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
    # BUGFIX: with `groups=C`, conv2d requires a weight of shape (C, 1, k, k);
    # the previous (1, 1, k, k) kernel failed for any multi-channel input.
    kernel = kernel.expand(input.shape[1], 1, kernel_size, kernel_size)
    input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
    input = F.conv2d(input, kernel, groups=input.shape[1])
    return input
67
+
68
+
69
def focal_to_fov(focal: torch.Tensor):
    """Normalized focal length -> field of view in radians."""
    half_fov = torch.atan(0.5 / focal)
    return half_fov * 2
71
+
72
+
73
def fov_to_focal(fov: torch.Tensor):
    """Field of view in radians -> normalized focal length (inverse of `focal_to_fov`)."""
    half_angle_tan = torch.tan(fov / 2)
    return 0.5 / half_angle_tan
75
+
76
+
77
def angle_diff_vec3(v1: torch.Tensor, v2: torch.Tensor, eps: float = 1e-12):
    """Angle (radians) between 3D vectors along the last dim, stable near 0 and pi via atan2."""
    cross_norm = torch.cross(v1, v2, dim=-1).norm(dim=-1)
    dot = (v1 * v2).sum(dim=-1)
    return torch.atan2(cross_norm + eps, dot)
79
+
80
def intrinsics_to_fov(intrinsics: torch.Tensor):
    """
    Returns field of view in radians from normalized intrinsics matrix.
    ### Parameters:
    - intrinsics: torch.Tensor of shape (..., 3, 3)

    ### Returns:
    - fov_x: torch.Tensor of shape (...)
    - fov_y: torch.Tensor of shape (...)
    """
    fov_x = 2 * torch.atan(0.5 / intrinsics[..., 0, 0])
    fov_y = 2 * torch.atan(0.5 / intrinsics[..., 1, 1])
    return fov_x, fov_y
93
+
94
+
95
def point_map_to_depth_legacy(points: torch.Tensor):
    """
    Legacy closed-form recovery of depth, FoV and Z shift from a point map,
    via a linear least-squares fit of a shared (focal, shift) pair.

    ### Parameters
    - `points`: point map of shape (..., H, W, 3).

    ### Returns
    - `depth`: (..., H, W) depth after applying the recovered shift.
    - `fov_x`, `fov_y`: (...) fields of view in radians.
    - `shift`: (...) recovered Z-axis shift.
    """
    height, width = points.shape[-3:-1]
    diagonal = (height ** 2 + width ** 2) ** 0.5
    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)

    # Solve least squares problem
    b = (uv * points[..., 2:]).flatten(-3, -1)  # (..., H * W * 2)
    A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2)  # (..., H * W * 2, 2)

    M = A.transpose(-2, -1) @ A
    # Tikhonov-regularized normal equations; solution holds (focal, shift) in the last dim.
    solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
    focal, shift = solution.unbind(-1)

    depth = points[..., 2] + shift[..., None, None]
    fov_x = torch.atan(width / diagonal / focal) * 2
    fov_y = torch.atan(height / diagonal / focal) * 2
    return depth, fov_x, fov_y, shift
112
+
113
+
114
def view_plane_uv_to_focal(uv: torch.Tensor):
    """Closed-form least-squares focal length from a predicted view-plane UV map of shape (..., H, W, 2)."""
    reference_uv = normalized_view_plane_uv(width=uv.shape[-2], height=uv.shape[-3], device=uv.device, dtype=uv.dtype)
    numerator = (uv * reference_uv).sum()
    denominator = uv.square().sum().add(1e-12)
    return numerator / denominator
118
+
119
+
120
def recover_focal_shift(points: torch.Tensor, mask: torch.Tensor = None, focal: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
    """
    Recover the depth map and FoV from a point map with unknown z shift and focal.

    Note that it assumes:
    - the optical center is at the center of the map
    - the map is undistorted
    - the map is isometric in the x and y directions

    ### Parameters:
    - `points: torch.Tensor` of shape (..., H, W, 3)
    - `mask: torch.Tensor` of shape (..., H, W). Optional valid-pixel mask.
    - `focal: torch.Tensor` of shape (...). Optional known focal; if given, only the shift is solved.
    - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.

    ### Returns:
    - `focal`: torch.Tensor of shape (...) the estimated focal length, relative to the half diagonal of the map
    - `shift`: torch.Tensor of shape (...) Z-axis shift to translate the point map to camera space
    """
    shape = points.shape
    height, width = points.shape[-3], points.shape[-2]
    diagonal = (height ** 2 + width ** 2) ** 0.5

    # Flatten all leading batch dims so the per-item solver loop below is simple.
    points = points.reshape(-1, *shape[-3:])
    mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
    focal = focal.reshape(-1) if focal is not None else None
    uv = normalized_view_plane_uv(width, height, dtype=points.dtype, device=points.device)  # (H, W, 2)

    # Downsample point map, uv grid and mask to keep the solver fast.
    points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
    uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
    mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0

    # The optimization itself runs per item on CPU/numpy (scipy-based solvers).
    uv_lr_np = uv_lr.cpu().numpy()
    points_lr_np = points_lr.detach().cpu().numpy()
    focal_np = focal.cpu().numpy() if focal is not None else None
    mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
    optim_shift, optim_focal = [], []
    for i in range(points.shape[0]):
        points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
        uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
        if uv_lr_i_np.shape[0] < 2:
            # Degenerate (almost fully masked) item: identity focal, zero shift.
            optim_focal.append(1)
            optim_shift.append(0)
            continue
        if focal is None:
            optim_shift_i, optim_focal_i = solve_optimal_focal_shift(uv_lr_i_np, points_lr_i_np)
            optim_focal.append(float(optim_focal_i))
        else:
            optim_shift_i = solve_optimal_shift(uv_lr_i_np, points_lr_i_np, focal_np[i])
        optim_shift.append(float(optim_shift_i))
    optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype).reshape(shape[:-3])

    if focal is None:
        optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype).reshape(shape[:-3])
    else:
        optim_focal = focal.reshape(shape[:-3])

    return optim_focal, optim_shift
176
+
177
+
178
def mask_aware_nearest_resize(
    inputs: Union[torch.Tensor, Sequence[torch.Tensor], None],
    mask: torch.BoolTensor,
    size: Tuple[int, int],
    return_index: bool = False
) -> Tuple[Union[torch.Tensor, Sequence[torch.Tensor], None], torch.BoolTensor, Tuple[torch.LongTensor, ...]]:
    """
    Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.

    Only pixels valid in `mask` are candidates, so invalid pixels never leak
    into the resized output.

    ### Parameters
    - `inputs`: a single or a list of input 2D map(s) of shape (..., H, W, ...).
    - `mask`: input 2D mask of shape (..., H, W)
    - `size`: target size (target_width, target_height)

    ### Returns
    - `*resized_maps`: resized map(s) of shape (..., target_height, target_width, ...).
    - `resized_mask`: mask of the resized map of shape (..., target_height, target_width)
    - `nearest_idx`: if return_index is True, nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension, .
    """
    height, width = mask.shape[-2:]
    target_width, target_height = size
    device = mask.device
    # Fractional and integer footprint of one target pixel in source pixels.
    filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
    filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
    filter_size = filter_h_i * filter_w_i
    padding_h, padding_w = filter_h_i // 2 + 1, filter_w_i // 2 + 1

    # Window the original mask and uv.
    # Zero-pad so every target window lies fully inside the padded tensors.
    uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
    indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
    padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
    padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
    padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
    padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
    padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
    padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
    windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
    windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
    windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))

    # Gather the target pixels's local window
    target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
    target_lefttop = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
    target_window = torch.round(target_lefttop).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)

    target_window_uv = windowed_uv[target_window[..., 1], target_window[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size)  # (target_height, tgt_width, 2, filter_size)
    target_window_mask = windowed_mask[..., target_window[..., 1], target_window[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size)  # (..., target_height, tgt_width, filter_size)
    target_window_indices = windowed_indices[target_window[..., 1], target_window[..., 0], :, :].reshape(target_height, target_width, filter_size)  # (target_height, tgt_width, filter_size)
    target_window_indices = target_window_indices.expand_as(target_window_mask)

    # Compute nearest neighbor in the local window for each pixel.
    # Invalid candidates get infinite distance so argmin never picks them.
    dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf)  # (..., target_height, tgt_width, filter_size)
    nearest = torch.argmin(dist, dim=-1, keepdim=True)  # (..., target_height, tgt_width, 1)
    nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1)  # (..., target_height, tgt_width)
    target_mask = torch.any(target_window_mask, dim=-1)
    nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
    batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]

    index = (*batch_indices, nearest_i, nearest_j)

    if inputs is None:
        outputs = None
    elif isinstance(inputs, torch.Tensor):
        outputs = inputs[index]
    elif isinstance(inputs, Sequence):
        outputs = tuple(x[index] for x in inputs)
    else:
        raise ValueError(f'Invalid input type: {type(inputs)}')

    if return_index:
        return outputs, target_mask, index
    else:
        return outputs, target_mask
251
+
252
+
253
def theshold_depth_change(depth: torch.Tensor, mask: torch.Tensor, pooler: Literal['min', 'max'], rtol: float = 0.2, kernel_size: int = 3):
    """
    Flag pixels whose depth differs by more than `rtol` (relative) from the
    min/max of their local `kernel_size` neighborhood.

    NOTE(review): the name is a typo of "threshold_depth_change"; kept as-is
    for caller compatibility.

    ### Parameters
    - `depth`: (..., H, W) depth map.
    - `mask`: (..., H, W) validity mask; invalid pixels are excluded from pooling.
    - `pooler`: 'max' flags pixels much nearer than some neighbor; 'min' flags pixels much farther.
    - `rtol`: relative depth-change threshold.
    - `kernel_size`: neighborhood size for pooling.

    ### Returns
    - Boolean tensor of shape (..., H, W).
    """
    *batch_shape, height, width = depth.shape
    depth = depth.reshape(-1, 1, height, width)
    mask = mask.reshape(-1, 1, height, width)
    if pooler =='max':
        # Invalid pixels become -inf so they never win the max pooling.
        pooled_depth = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
        output_mask = pooled_depth > depth * (1 + rtol)
    elif pooler =='min':
        # Min pooling implemented as negated max pooling over negated depth.
        pooled_depth = -F.max_pool2d(-torch.where(mask, depth, torch.inf), kernel_size, stride=1, padding=kernel_size // 2)
        output_mask = pooled_depth < depth * (1 - rtol)
    else:
        raise ValueError(f'Unsupported pooler: {pooler}')
    output_mask = output_mask.reshape(*batch_shape, height, width)
    return output_mask
267
+
268
+
269
def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
    """
    Detect occlusion edges in a depth map via a local affine disparity fit.

    NOTE(review): this definition is shadowed by a second `depth_occlusion_edge`
    defined later in this module; only the later one is effective at import time.

    ### Parameters
    - `depth`: (..., H, W) depth map.
    - `mask`: (..., H, W) validity mask.
    - `kernel_size`: local window size for the affine fit.
    - `tol`: relative disparity deviation that counts as an edge.

    ### Returns
    - `fg_edge_mask`, `bg_edge_mask`: boolean masks of foreground/background edge pixels.
    """
    device, dtype = depth.device, depth.dtype

    # Work in disparity (1/depth); invalid pixels contribute 0.
    disp = torch.where(mask, 1 / depth, 0)
    disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
    mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
    disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2)  # [..., H, W, kernel_size ** 2]
    mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1)).flatten(-2)  # [..., H, W, kernel_size ** 2]

    # Design matrix of an affine (a*x + b*y + c) disparity model over the window;
    # invalid pixels are zeroed out of the fit via the mask.
    x = torch.linspace(-kernel_size // 2, kernel_size // 2, kernel_size, device=device, dtype=dtype)
    A = torch.stack([*torch.meshgrid(x, x, indexing='xy'), torch.ones((kernel_size, kernel_size), device=device, dtype=dtype)], dim=-1).reshape(kernel_size ** 2, 3)  # [kernel_size ** 2, 3]
    A = mask_window[..., None] * A
    I = torch.eye(3, device=device, dtype=dtype)

    # Regularized least-squares projection of each window onto the affine model.
    affine_disp_window = (disp_window[..., None, :] @ A @ torch.inverse(A.mT @ A + 1e-5 * I) @ A.mT).clamp_min(1e-12)[..., 0, :]  # [..., H, W, kernel_size ** 2]
    diff = torch.where(mask_window, torch.maximum(affine_disp_window, disp_window) / torch.minimum(affine_disp_window, disp_window) - 1, 0)

    # Edge where the affine model fails to explain the window's disparity.
    edge_mask = mask & (diff > tol).any(dim=-1)

    # Split into foreground (nearer than neighborhood mean) and background edges.
    disp_mean = weighted_mean(disp_window, mask_window, dim=-1)
    fg_edge_mask = edge_mask & (disp > disp_mean)
    # fg_edge_mask = edge_mask & theshold_depth_change(depth, mask, pooler='max', rtol=tol, kernel_size=kernel_size)
    bg_edge_mask = edge_mask & ~fg_edge_mask
    return fg_edge_mask, bg_edge_mask
293
+
294
+
295
def depth_occlusion_edge(depth: torch.FloatTensor, mask: torch.BoolTensor, kernel_size: int = 3, tol: float = 0.1):
    """Split masked depth pixels into foreground/background occlusion-edge masks.

    Simpler variant (this definition overrides the earlier one with the same name):
    a pixel is a foreground edge if its disparity (1/depth) exceeds the local masked
    mean by a relative factor > tol, and a background edge in the symmetric case.
    Each side is then kept only where the opposite side occurs within a slightly
    larger window, so fg/bg edges always appear in facing pairs.
    Relies on `utils3d.torch.sliding_window_2d` and a module-level `weighted_mean`.
    """
    device, dtype = depth.device, depth.dtype  # dtype currently unused; kept for parity with the variant above

    disp = torch.where(mask, 1 / depth, 0)
    disp_pad = F.pad(disp, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=0)
    mask_pad = F.pad(mask, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), value=False)
    disp_window = utils3d.torch.sliding_window_2d(disp_pad, (kernel_size, kernel_size), 1, dim=(-2, -1))  # [..., H, W, kernel_size, kernel_size]
    mask_window = utils3d.torch.sliding_window_2d(mask_pad, (kernel_size, kernel_size), 1, dim=(-2, -1))  # [..., H, W, kernel_size, kernel_size]

    # Local mean disparity over the masked window samples.
    disp_mean = weighted_mean(disp_window, mask_window, dim=(-2, -1))
    fg_edge_mask = mask & (disp / disp_mean > 1 + tol)
    bg_edge_mask = mask & (disp_mean / disp > 1 + tol)

    # Keep an edge pixel only if the opposite-side edge exists nearby (max_pool2d as dilation).
    fg_edge_mask = fg_edge_mask & F.max_pool2d(bg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()
    bg_edge_mask = bg_edge_mask & F.max_pool2d(fg_edge_mask.float(), kernel_size + 2, stride=1, padding=kernel_size // 2 + 1).bool()

    return fg_edge_mask, bg_edge_mask
312
+
313
+
314
def dilate_with_mask(input: torch.Tensor, mask: torch.BoolTensor, filter: Literal['min', 'max', 'mean', 'median'] = 'mean', iterations: int = 1) -> Tuple[torch.Tensor, torch.BoolTensor]:
    """Iteratively fill invalid pixels of `input` from their valid 4-neighbors.

    Each iteration grows the valid region by one 4-connected ring: every invalid
    pixel that has at least one valid 4-neighbor receives the min/max/mean/median
    of those neighbors (per `filter`); already-valid pixels are left untouched.

    Parameters:
        input: values over a (..., H, W) grid.
        mask: bool validity mask, same trailing shape as `input`.
        filter: aggregation used to fill from valid neighbors.
        iterations: number of one-ring growth steps.
    Returns:
        (filled input, grown validity mask).
        FIX: the original return annotation claimed a single Tensor, but a
        (tensor, mask) tuple has always been returned.
    """
    # 4-connected (cross-shaped) structuring element.
    kernel = torch.tensor([[False, True, False], [True, True, True], [False, True, False]], device=input.device, dtype=torch.bool)
    for _ in range(iterations):
        input_window = utils3d.torch.sliding_window_2d(F.pad(input, (1, 1, 1, 1), mode='constant', value=0), window_size=3, stride=1, dim=(-2, -1))
        mask_window = kernel & utils3d.torch.sliding_window_2d(F.pad(mask, (1, 1, 1, 1), mode='constant', value=False), window_size=3, stride=1, dim=(-2, -1))
        if filter == 'min':
            # FIX: Tensor.min/max do not accept a tuple `dim`; amin/amax do.
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.inf).amin(dim=(-2, -1)))
        elif filter == 'max':
            input = torch.where(mask, input, torch.where(mask_window, input_window, -torch.inf).amax(dim=(-2, -1)))
        elif filter == 'mean':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).nanmean(dim=(-2, -1)))
        elif filter == 'median':
            input = torch.where(mask, input, torch.where(mask_window, input_window, torch.nan).flatten(-2).nanmedian(dim=-1).values)
        # Flatten before `any` for compatibility with torch versions lacking tuple-dim any().
        mask = mask_window.flatten(-2).any(dim=-1)
    return input, mask
329
+
330
+
331
def refine_depth_with_normal(depth: torch.Tensor, normal: torch.Tensor, intrinsics: torch.Tensor, iterations: int = 10, damp: float = 1e-3, eps: float = 1e-12, kernel_size: int = 5) -> torch.Tensor:
    """Refine a depth map so its log-depth gradients better agree with a normal map.

    Builds per-pixel log-depth gradients implied by `normal` under camera
    `intrinsics`, turns them into a weighted discrete Laplacian, then runs damped
    Jacobi-style iterations on log-depth. Border pixels (within kernel_size // 2)
    are left unchanged.

    Parameters:
        depth: (..., H, W) depth map. NOTE: no longer mutated in place (see FIX below).
        normal: (..., H, W, 3) camera-space normal map — assumed layout; TODO confirm.
        intrinsics: (..., 3, 3) normalized camera intrinsics matrix.
        iterations: number of smoothing iterations.
        damp: damping toward the original log-depth.
        eps: numerical floor for depths/weights.
        kernel_size: neighborhood size for gradient aggregation.
    Returns:
        Refined depth map, same shape as `depth`.
    """
    device, dtype = depth.device, depth.dtype
    height, width = depth.shape[-2:]
    radius = kernel_size // 2

    # Normalized-UV offsets of each window sample relative to the center pixel.
    duv = torch.stack(torch.meshgrid(torch.linspace(-radius / width, radius / width, kernel_size, device=device, dtype=dtype), torch.linspace(-radius / height, radius / height, kernel_size, device=device, dtype=dtype), indexing='xy'), dim=-1).to(dtype=dtype, device=device)

    # FIX: was `depth.clamp_min_(eps)` which silently mutated the caller's tensor in place.
    log_depth = depth.clamp_min(eps).log()
    log_depth_diff = utils3d.torch.sliding_window_2d(log_depth, window_size=kernel_size, stride=1, dim=(-2, -1)) - log_depth[..., radius:-radius, radius:-radius, None, None]

    # Edge-aware weights: samples with large log-depth jumps contribute less.
    weight = torch.exp(-(log_depth_diff / duv.norm(dim=-1).clamp_min_(eps) / 10).square())
    tot_weight = weight.sum(dim=(-2, -1)).clamp_min_(eps)

    uv = utils3d.torch.image_uv(height=height, width=width, device=device, dtype=dtype)
    K_inv = torch.inverse(intrinsics)

    # Analytic gradient of log-depth implied by the normal under the pinhole model.
    grad = -(normal[..., None, :2] @ K_inv[..., None, None, :2, :2]).squeeze(-2) \
        / (normal[..., None, 2:] + normal[..., None, :2] @ (K_inv[..., None, None, :2, :2] @ uv[..., :, None] + K_inv[..., None, None, :2, 2:])).squeeze(-2)
    laplacian = (weight * ((utils3d.torch.sliding_window_2d(grad, window_size=kernel_size, stride=1, dim=(-3, -2)) + grad[..., radius:-radius, radius:-radius, :, None, None]) * (duv.permute(2, 0, 1) / 2)).sum(dim=-3)).sum(dim=(-2, -1))

    # Clamp to keep a few outlier normals from destabilizing the iteration.
    laplacian = laplacian.clamp(-0.1, 0.1)
    log_depth_refine = log_depth.clone()

    # Damped Jacobi iterations with 0.1/0.9 under-relaxation; borders stay fixed.
    for _ in range(iterations):
        log_depth_refine[..., radius:-radius, radius:-radius] = 0.1 * log_depth_refine[..., radius:-radius, radius:-radius] + 0.9 * (damp * log_depth[..., radius:-radius, radius:-radius] - laplacian + (weight * utils3d.torch.sliding_window_2d(log_depth_refine, window_size=kernel_size, stride=1, dim=(-2, -1))).sum(dim=(-2, -1))) / (tot_weight + damp)

    depth_refine = log_depth_refine.exp()

    return depth_refine
moge/utils/io.py ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ import os
7
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
8
+ from typing import IO
9
+ import zipfile
10
+ import json
11
+ import io
12
+ from typing import *
13
+ from pathlib import Path
14
+ import re
15
+ from PIL import Image, PngImagePlugin
16
+
17
+ import numpy as np
18
+ import cv2
19
+
20
+ from .tools import timeit
21
+
22
+
23
def save_glb(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_uvs: np.ndarray,
    texture: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """Export a textured triangle mesh as a GLB file via trimesh.

    Parameters:
        save_path: output path (format inferred by trimesh from the extension).
        vertices: (V, 3) float vertex positions.
        faces: (F, 3) int triangle indices.
        vertex_uvs: (V, 2) per-vertex texture coordinates.
        texture: (H, W, 3) uint8 base-color image applied as a PBR material.
        vertex_normals: optional (V, 3) per-vertex normals.
    """
    # Imported locally so trimesh/PIL stay optional dependencies of this module.
    import trimesh
    import trimesh.visual
    from PIL import Image

    trimesh.Trimesh(
        vertices=vertices,
        vertex_normals=vertex_normals,
        faces=faces,
        visual = trimesh.visual.texture.TextureVisuals(
            uv=vertex_uvs,
            material=trimesh.visual.material.PBRMaterial(
                baseColorTexture=Image.fromarray(texture),
                metallicFactor=0.5,
                roughnessFactor=1.0
            )
        ),
        process=False  # keep vertex order and indices exactly as given
    ).export(save_path)
49
+
50
+
51
def save_ply(
    save_path: Union[str, os.PathLike],
    vertices: np.ndarray,
    faces: np.ndarray,
    vertex_colors: np.ndarray,
    vertex_normals: Optional[np.ndarray] = None,
):
    """Export a vertex-colored triangle mesh as a PLY file via trimesh.

    Parameters:
        save_path: output path (format inferred by trimesh from the extension).
        vertices: (V, 3) float vertex positions.
        faces: (F, 3) int triangle indices.
        vertex_colors: (V, 3) or (V, 4) per-vertex colors.
        vertex_normals: optional (V, 3) per-vertex normals.
    """
    # Imported locally so trimesh stays an optional dependency of this module.
    # FIX: removed unused local imports (`trimesh.visual`, `PIL.Image`).
    import trimesh

    trimesh.Trimesh(
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors,
        vertex_normals=vertex_normals,
        process=False  # keep vertex order and indices exactly as given
    ).export(save_path)
69
+
70
+
71
def read_image(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read an image file, returning a uint8 RGB array of shape (H, W, 3).

    Accepts a filesystem path or a binary file-like object. Decoding is done by
    OpenCV (BGR) and converted to RGB before returning.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    # cv2.imdecode yields BGR; convert to the RGB convention used by this module.
    image = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_COLOR), cv2.COLOR_BGR2RGB)
    return image
81
+
82
+
83
def write_image(path: Union[str, os.PathLike, IO], image: np.ndarray, quality: int = 95):
    """
    Write an image, given a uint8 RGB array of shape (H, W, 3).

    Accepts a filesystem path or a binary file-like object.
    NOTE: the data is always JPEG-encoded (lossy, with the given `quality`),
    regardless of the extension in `path`.
    """
    data = cv2.imencode('.jpg', cv2.cvtColor(image, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_JPEG_QUALITY, quality])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)
92
+
93
+
94
def read_depth(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, float]:
    """
    Read a depth image encoded by `write_depth`, returning (depth, unit).

    The 16-bit PNG stores depth logarithmically between the `near`/`far` values
    recorded in the PNG text metadata: pixel 0 decodes to NaN (unknown),
    65535 to +inf, and 1..65534 to `near ** (1 - t) * far ** t`.

    Returns:
        depth: float32 array of shape (H, W).
        unit: the depth unit from metadata, or None if absent.

    NOTE(review): if the 'near'/'far' text chunks are missing, `float(None)`
    raises TypeError — assumes the file was produced by `write_depth`.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    pil_image = Image.open(io.BytesIO(data))
    near = float(pil_image.info.get('near'))
    far = float(pil_image.info.get('far'))
    unit = float(pil_image.info.get('unit')) if 'unit' in pil_image.info else None
    depth = np.array(pil_image)
    # Sentinel codes: 0 = unknown (NaN), 65535 = infinity.
    mask_nan, mask_inf = depth == 0, depth == 65535
    # Map 1..65534 to [0, 1], then exponentially interpolate between near and far.
    depth = (depth.astype(np.float32) - 1) / 65533
    depth = near ** (1 - depth) * far ** depth
    depth[mask_nan] = np.nan
    depth[mask_inf] = np.inf
    return depth, unit
113
+
114
+
115
def write_depth(
    path: Union[str, os.PathLike, IO],
    depth: np.ndarray,
    unit: float = None,
    max_range: float = 1e5,
    compression_level: int = 7,
):
    """
    Encode and write a depth image as 16-bit PNG format.
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `depth: np.ndarray`
        The depth array, float32 array of shape (H, W).
        May contain `NaN` for invalid values and `Inf` for infinite values.
    - `unit: float = None`
        The unit of the depth values.
    - `max_range: float = 1e5`
        Maximum far/near ratio preserved by the encoding.
    - `compression_level: int = 7`
        PNG compression level.

    Depth values are encoded as follows:
    - 0: unknown
    - 1 ~ 65534: depth values in logarithmic
    - 65535: infinity

    metadata is stored in the PNG file as text fields:
    - `near`: the minimum depth value
    - `far`: the maximum depth value
    - `unit`: the unit of the depth values (optional)

    NOTE(review): assumes at least one finite depth value; an all-NaN/inf input
    makes `.min()` on an empty selection raise.
    """
    mask_values, mask_nan, mask_inf = np.isfinite(depth), np.isnan(depth), np.isinf(depth)

    depth = depth.astype(np.float32)
    # FIX: removed dead statement `mask_finite = depth` (assigned but never used).
    near = max(depth[mask_values].min(), 1e-5)
    far = max(near * 1.1, min(depth[mask_values].max(), near * max_range))
    # Logarithmic quantization into 1..65534; 0 and 65535 are reserved sentinels.
    depth = 1 + np.round((np.log(np.nan_to_num(depth, nan=0).clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65533).astype(np.uint16)  # 1~65534
    depth[mask_nan] = 0
    depth[mask_inf] = 65535

    pil_image = Image.fromarray(depth)
    pnginfo = PngImagePlugin.PngInfo()
    pnginfo.add_text('near', str(near))
    pnginfo.add_text('far', str(far))
    if unit is not None:
        pnginfo.add_text('unit', str(unit))
    pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
160
+
161
+
162
def read_segmentation(path: Union[str, os.PathLike, IO]) -> Tuple[np.ndarray, Dict[str, int]]:
    """
    Read a segmentation mask written by `write_segmentation`.
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to read from.
    ### Returns:
    - `Tuple[np.ndarray, Dict[str, int]]`
        A tuple containing:
        - `mask`: uint8 or uint16 numpy.ndarray of shape (H, W).
        - `labels`: Dict[str, int]. The label mapping, a dictionary of {label_name: label_id},
          or None if the PNG has no 'labels' text chunk.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    pil_image = Image.open(io.BytesIO(data))
    # The label mapping is stored as a JSON string in the PNG 'labels' text chunk.
    labels = json.loads(pil_image.info['labels']) if 'labels' in pil_image.info else None
    mask = np.array(pil_image)
    return mask, labels
182
+
183
+
184
def write_segmentation(path: Union[str, os.PathLike, IO], mask: np.ndarray, labels: Dict[str, int] = None, compression_level: int = 7):
    """
    Write a segmentation mask and label mapping, as PNG format.
    ### Parameters:
    - `path: Union[str, os.PathLike, IO]`
        The file path or file object to write to.
    - `mask: np.ndarray`
        The segmentation mask, uint8 or uint16 array of shape (H, W).
    - `labels: Dict[str, int] = None`
        The label mapping, a dictionary of {label_name: label_id}.
    - `compression_level: int = 7`
        The compression level for PNG compression.
    """
    assert mask.dtype == np.uint8 or mask.dtype == np.uint16, f"Unsupported dtype {mask.dtype}"
    pil_image = Image.fromarray(mask)
    pnginfo = PngImagePlugin.PngInfo()
    if labels is not None:
        # Store the mapping as compact ASCII JSON in a PNG text chunk.
        labels_json = json.dumps(labels, ensure_ascii=True, separators=(',', ':'))
        pnginfo.add_text('labels', labels_json)
    pil_image.save(path, pnginfo=pnginfo, compress_level=compression_level)
204
+
205
+
206
+
207
def read_normal(path: Union[str, os.PathLike, IO]) -> np.ndarray:
    """
    Read a normal image written by `write_normal`, returning a float32 array
    of shape (H, W, 3) of unit normals; all-zero pixels decode to NaN.
    """
    if isinstance(path, (str, os.PathLike)):
        data = Path(path).read_bytes()
    else:
        data = path.read()
    normal = cv2.cvtColor(cv2.imdecode(np.frombuffer(data, np.uint8), cv2.IMREAD_UNCHANGED), cv2.COLOR_BGR2RGB)
    # All-zero pixels are the sentinel for invalid normals.
    mask_nan = np.all(normal == 0, axis=-1)
    # Undo the [0, 65535] quantization; y and z are stored sign-flipped (see write_normal).
    normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
    # Renormalize to compensate for quantization error.
    normal = normal / (np.sqrt(np.square(normal[..., 0]) + np.square(normal[..., 1]) + np.square(normal[..., 2])) + 1e-12)
    normal[mask_nan] = np.nan
    return normal
221
+
222
+
223
def write_normal(path: Union[str, os.PathLike, IO], normal: np.ndarray, compression_level: int = 7) -> np.ndarray:
    """
    Write a normal image as 16-bit PNG, given a float32 array of shape (H, W, 3).

    Components are mapped from [-1, 1] to [0, 65535] with y and z sign-flipped;
    NaN normals are stored as all-zero pixels (the sentinel `read_normal` expects).
    """
    mask_nan = np.isnan(normal).any(axis=-1)
    normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
    normal[mask_nan] = 0
    data = cv2.imencode('.png', cv2.cvtColor(normal, cv2.COLOR_RGB2BGR), [cv2.IMWRITE_PNG_COMPRESSION, compression_level])[1].tobytes()
    if isinstance(path, (str, os.PathLike)):
        Path(path).write_bytes(data)
    else:
        path.write(data)
235
+
236
+
237
def read_meta(path: Union[str, os.PathLike, IO]) -> Dict[str, Any]:
    """Load a JSON metadata file into a dictionary."""
    text = Path(path).read_text()
    return json.loads(text)
239
+
240
def write_meta(path: Union[str, os.PathLike, IO], meta: Dict[str, Any]):
    """Serialize `meta` as JSON and write it to `path`."""
    serialized = json.dumps(meta)
    Path(path).write_text(serialized)
moge/utils/panorama.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ import os
7
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
8
+ from pathlib import Path
9
+ from typing import *
10
+ import itertools
11
+ import json
12
+ import warnings
13
+
14
+ import cv2
15
+ import numpy as np
16
+ from numpy import ndarray
17
+ from tqdm import tqdm, trange
18
+ from scipy.sparse import csr_array, hstack, vstack
19
+ from scipy.ndimage import convolve
20
+ from scipy.sparse.linalg import lsmr
21
+
22
+ import utils3d
23
+
24
+
25
def get_panorama_cameras():
    """Return virtual cameras covering the full sphere for panorama splitting.

    One camera per icosahedron vertex, each looking outward from the origin with
    a 90x90-degree FoV (shared intrinsics), z-up.

    Returns:
        (extrinsics, intrinsics): float32 extrinsics array and a list with the
        same intrinsics repeated once per camera.
    """
    vertices, _ = utils3d.numpy.icosahedron()
    intrinsics = utils3d.numpy.intrinsics_from_fov(fov_x=np.deg2rad(90), fov_y=np.deg2rad(90))
    extrinsics = utils3d.numpy.extrinsics_look_at([0, 0, 0], vertices, [0, 0, 1]).astype(np.float32)
    return extrinsics, [intrinsics] * len(vertices)
30
+
31
+
32
def spherical_uv_to_directions(uv: np.ndarray):
    """Map equirectangular UV coordinates in [0, 1]^2 to unit direction vectors.

    u runs westward (theta = (1 - u) * 2*pi), v runs from the +z pole (v = 0)
    to the -z pole (v = 1). Inverse of `directions_to_spherical_uv`.
    """
    theta = (1 - uv[..., 0]) * (2 * np.pi)
    phi = uv[..., 1] * np.pi
    sin_phi = np.sin(phi)
    directions = np.stack([sin_phi * np.cos(theta), sin_phi * np.sin(theta), np.cos(phi)], axis=-1)
    return directions
36
+
37
+
38
def directions_to_spherical_uv(directions: np.ndarray):
    """Map direction vectors to equirectangular UV coordinates in [0, 1]^2.

    Inputs need not be unit length; they are normalized first.
    Inverse of `spherical_uv_to_directions`.
    """
    unit = directions / np.linalg.norm(directions, axis=-1, keepdims=True)
    # NB: % binds tighter than -, so this is 1 - ((atan2 / 2pi) mod 1).
    u = 1 - np.arctan2(unit[..., 1], unit[..., 0]) / (2 * np.pi) % 1.0
    v = np.arccos(unit[..., 2]) / np.pi
    return np.stack([u, v], axis=-1)
43
+
44
+
45
def split_panorama_image(image: np.ndarray, extrinsics: np.ndarray, intrinsics: np.ndarray, resolution: int):
    """Resample an equirectangular panorama into one square perspective view per camera.

    Parameters:
        image: (H, W, C) equirectangular panorama.
        extrinsics / intrinsics: per-camera matrices, e.g. from `get_panorama_cameras`.
        resolution: output side length of each perspective view.
    Returns:
        List of (resolution, resolution, C) views, one per camera.
    """
    height, width = image.shape[:2]
    uv = utils3d.numpy.image_uv(width=resolution, height=resolution)
    splitted_images = []
    for i in range(len(extrinsics)):
        # For each output pixel, find the panorama location its view ray hits.
        spherical_uv = directions_to_spherical_uv(utils3d.numpy.unproject_cv(uv, extrinsics=extrinsics[i], intrinsics=intrinsics[i]))
        pixels = utils3d.numpy.uv_to_pixel(spherical_uv, width=width, height=height).astype(np.float32)

        splitted_image = cv2.remap(image, pixels[..., 0], pixels[..., 1], interpolation=cv2.INTER_LINEAR)
        splitted_images.append(splitted_image)
    return splitted_images
56
+
57
+
58
def poisson_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> csr_array:
    """Build the sparse 5-point Laplacian operator for a height x width grid.

    Each row encodes -4 * x[center] + sum of the 4 neighbors. With wrap_x/wrap_y
    the grid is periodic along that axis; otherwise edge padding replicates the
    border pixel (duplicate column entries are summed by the sparse format, so
    every row still sums to zero).

    FIX: return annotation said Tuple[csr_array, ndarray], but only the matrix
    `A` is returned.

    Returns:
        csr_array of shape (height * width, height * width).
    """
    grid_index = np.arange(height * width).reshape(height, width)
    grid_index = np.pad(grid_index, ((0, 0), (1, 1)), mode='wrap' if wrap_x else 'edge')
    grid_index = np.pad(grid_index, ((1, 1), (0, 0)), mode='wrap' if wrap_y else 'edge')

    data = np.array([[-4, 1, 1, 1, 1]], dtype=np.float32).repeat(height * width, axis=0).reshape(-1)
    indices = np.stack([
        grid_index[1:-1, 1:-1],
        grid_index[:-2, 1:-1],     # up
        grid_index[2:, 1:-1],      # down
        grid_index[1:-1, :-2],     # left
        grid_index[1:-1, 2:]       # right
    ], axis=-1).reshape(-1)
    indptr = np.arange(0, height * width * 5 + 1, 5)  # exactly 5 entries per row
    A = csr_array((data, indices, indptr), shape=(height * width, height * width))

    return A
75
+
76
+
77
def grad_equation(width: int, height: int, wrap_x: bool = False, wrap_y: bool = False) -> csr_array:
    """Build the sparse forward-difference (gradient) operator for a height x width grid.

    Rows come in two groups: horizontal differences x[i, j] - x[i, j + 1] first,
    then vertical differences x[i, j] - x[i + 1, j]. With wrap_x/wrap_y the grid is
    periodic along that axis (one extra wrapped column/row of pairs is added).

    FIX: return annotation said Tuple[csr_array, np.ndarray], but only the matrix
    `A` is returned.

    Returns:
        csr_array of shape (#horizontal pairs + #vertical pairs, height * width).
    """
    grid_index = np.arange(width * height).reshape(height, width)
    if wrap_x:
        grid_index = np.pad(grid_index, ((0, 0), (0, 1)), mode='wrap')
    if wrap_y:
        grid_index = np.pad(grid_index, ((0, 1), (0, 0)), mode='wrap')

    # Each row holds a (+1, -1) pair: +1 on the first pixel, -1 on its neighbor.
    data = np.concatenate([
        np.concatenate([
            np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),    # x[i,j]
            -np.ones((grid_index.shape[0], grid_index.shape[1] - 1), dtype=np.float32).reshape(-1, 1),   # x[i,j-1]
        ], axis=1).reshape(-1),
        np.concatenate([
            np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),    # x[i,j]
            -np.ones((grid_index.shape[0] - 1, grid_index.shape[1]), dtype=np.float32).reshape(-1, 1),   # x[i-1,j]
        ], axis=1).reshape(-1),
    ])
    indices = np.concatenate([
        np.concatenate([
            grid_index[:, :-1].reshape(-1, 1),
            grid_index[:, 1:].reshape(-1, 1),
        ], axis=1).reshape(-1),
        np.concatenate([
            grid_index[:-1, :].reshape(-1, 1),
            grid_index[1:, :].reshape(-1, 1),
        ], axis=1).reshape(-1),
    ])
    indptr = np.arange(0, grid_index.shape[0] * (grid_index.shape[1] - 1) * 2 + (grid_index.shape[0] - 1) * grid_index.shape[1] * 2 + 1, 2)  # 2 entries per row
    A = csr_array((data, indices, indptr), shape=(grid_index.shape[0] * (grid_index.shape[1] - 1) + (grid_index.shape[0] - 1) * grid_index.shape[1], height * width))

    return A
108
+
109
+
110
def merge_panorama_depth(width: int, height: int, distance_maps: List[np.ndarray], pred_masks: List[np.ndarray], extrinsics: List[np.ndarray], intrinsics: List[np.ndarray]):
    """Fuse per-view distance maps into one equirectangular panorama depth map.

    Warps each view's log-distance onto the panorama, accumulates masked
    gradients and Laplacians, and solves the resulting overdetermined linear
    system in log space with LSMR. Recurses on a half-resolution version
    (above 256 px) to provide a warm start for the solver.

    Parameters:
        width, height: output panorama resolution.
        distance_maps: per-view distance maps (one per camera).
        pred_masks: per-view boolean validity masks.
        extrinsics, intrinsics: per-view camera matrices.
    Returns:
        (panorama_depth, panorama_mask): float32 (height, width) depth and the
        boolean union of all warped validity masks.
    """
    # Coarse-to-fine: solve a half-size problem first and upsample it as x0.
    if max(width, height) > 256:
        panorama_depth_init, _ = merge_panorama_depth(width // 2, height // 2, distance_maps, pred_masks, extrinsics, intrinsics)
        panorama_depth_init = cv2.resize(panorama_depth_init, (width, height), cv2.INTER_LINEAR)
    else:
        panorama_depth_init = None

    uv = utils3d.numpy.image_uv(width=width, height=height)
    spherical_directions = spherical_uv_to_directions(uv)

    # Warp each view to the panorama
    panorama_log_distance_grad_maps, panorama_grad_masks = [], []
    panorama_log_distance_laplacian_maps, panorama_laplacian_masks = [], []
    panorama_pred_masks = []
    for i in range(len(distance_maps)):
        # Project every panorama direction into view i; keep only rays that land
        # strictly inside the image with positive depth.
        projected_uv, projected_depth = utils3d.numpy.project_cv(spherical_directions, extrinsics=extrinsics[i], intrinsics=intrinsics[i])
        projection_valid_mask = (projected_depth > 0) & (projected_uv > 0).all(axis=-1) & (projected_uv < 1).all(axis=-1)

        projected_pixels = utils3d.numpy.uv_to_pixel(np.clip(projected_uv, 0, 1), width=distance_maps[i].shape[1], height=distance_maps[i].shape[0]).astype(np.float32)

        log_splitted_distance = np.log(distance_maps[i])
        panorama_log_distance_map = np.where(projection_valid_mask, cv2.remap(log_splitted_distance, projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_LINEAR, borderMode=cv2.BORDER_REPLICATE), 0)
        panorama_pred_mask = projection_valid_mask & (cv2.remap(pred_masks[i].astype(np.uint8), projected_pixels[..., 0], projected_pixels[..., 1], cv2.INTER_NEAREST, borderMode=cv2.BORDER_REPLICATE) > 0)

        # calculate gradient map (x wraps around the panorama seam)
        padded = np.pad(panorama_log_distance_map, ((0, 0), (0, 1)), mode='wrap')
        grad_x, grad_y = padded[:, :-1] - padded[:, 1:], padded[:-1, :] - padded[1:, :]

        padded = np.pad(panorama_pred_mask, ((0, 0), (0, 1)), mode='wrap')
        mask_x, mask_y = padded[:, :-1] & padded[:, 1:], padded[:-1, :] & padded[1:, :]

        panorama_log_distance_grad_maps.append((grad_x, grad_y))
        panorama_grad_masks.append((mask_x, mask_y))

        # calculate laplacian map (5-point stencil; valid only where all 5 taps are valid)
        padded = np.pad(panorama_log_distance_map, ((1, 1), (0, 0)), mode='edge')
        padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
        laplacian = convolve(padded, np.array([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=np.float32))[1:-1, 1:-1]

        padded = np.pad(panorama_pred_mask, ((1, 1), (0, 0)), mode='edge')
        padded = np.pad(padded, ((0, 0), (1, 1)), mode='wrap')
        mask = convolve(padded.astype(np.uint8), np.array([[0, 1, 0], [1, 1, 1], [0, 1, 0]], dtype=np.uint8))[1:-1, 1:-1] == 5

        panorama_log_distance_laplacian_maps.append(laplacian)
        panorama_laplacian_masks.append(mask)

        panorama_pred_masks.append(panorama_pred_mask)

    # Average per-view gradients/Laplacians over the views that observe each pixel.
    panorama_log_distance_grad_x = np.stack([grad_map[0] for grad_map in panorama_log_distance_grad_maps], axis=0)
    panorama_log_distance_grad_y = np.stack([grad_map[1] for grad_map in panorama_log_distance_grad_maps], axis=0)
    panorama_grad_mask_x = np.stack([mask_map[0] for mask_map in panorama_grad_masks], axis=0)
    panorama_grad_mask_y = np.stack([mask_map[1] for mask_map in panorama_grad_masks], axis=0)

    panorama_log_distance_grad_x = np.sum(panorama_log_distance_grad_x * panorama_grad_mask_x, axis=0) / np.sum(panorama_grad_mask_x, axis=0).clip(1e-3)
    panorama_log_distance_grad_y = np.sum(panorama_log_distance_grad_y * panorama_grad_mask_y, axis=0) / np.sum(panorama_grad_mask_y, axis=0).clip(1e-3)

    panorama_laplacian_maps = np.stack(panorama_log_distance_laplacian_maps, axis=0)
    panorama_laplacian_masks = np.stack(panorama_laplacian_masks, axis=0)
    panorama_laplacian_map = np.sum(panorama_laplacian_maps * panorama_laplacian_masks, axis=0) / np.sum(panorama_laplacian_masks, axis=0).clip(1e-3)

    grad_x_mask = np.any(panorama_grad_mask_x, axis=0).reshape(-1)
    grad_y_mask = np.any(panorama_grad_mask_y, axis=0).reshape(-1)
    grad_mask = np.concatenate([grad_x_mask, grad_y_mask])
    laplacian_mask = np.any(panorama_laplacian_masks, axis=0).reshape(-1)

    # Solve overdetermined system
    A = vstack([
        grad_equation(width, height, wrap_x=True, wrap_y=False)[grad_mask],
        poisson_equation(width, height, wrap_x=True, wrap_y=False)[laplacian_mask],
    ])
    b = np.concatenate([
        panorama_log_distance_grad_x.reshape(-1)[grad_x_mask],
        panorama_log_distance_grad_y.reshape(-1)[grad_y_mask],
        panorama_laplacian_map.reshape(-1)[laplacian_mask]
    ])
    x, *_ = lsmr(
        A, b,
        atol=1e-5, btol=1e-5,
        x0=np.log(panorama_depth_init).reshape(-1) if panorama_depth_init is not None else None,
        show=False,
    )

    panorama_depth = np.exp(x).reshape(height, width).astype(np.float32)
    panorama_mask = np.any(panorama_pred_masks, axis=0)

    return panorama_depth, panorama_mask
196
+
moge/utils/pipeline.py ADDED
@@ -0,0 +1,508 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import *
2
+ from abc import abstractmethod
3
+ from queue import Empty, Full
4
+ from threading import Thread
5
+ from queue import Queue
6
+ from multiprocessing import Process
7
+ from threading import Thread, Event
8
+ import multiprocessing
9
+ import threading
10
+ import inspect
11
+ import time
12
+ import uuid
13
+ from copy import deepcopy
14
+ import itertools
15
+ import functools
16
+
17
+ # Copied from the MoGe project:
18
+ # https://github.com/microsoft/MoGe
19
+ # Original license: MIT
20
+ # Copyright (c) the MoGe authors
21
+
22
+ __all__ = [
23
+ 'Node',
24
+ 'Link',
25
+ 'ConcurrentNode',
26
+ 'Worker',
27
+ 'WorkerFunction',
28
+ 'Provider',
29
+ 'ProviderFunction',
30
+ 'Sequential',
31
+ 'Batch',
32
+ 'Unbatch',
33
+ 'Parallel',
34
+ 'Graph',
35
+ 'Buffer',
36
+ ]
37
+
38
+ TERMINATE_CHECK_INTERVAL = 0.5
39
+
40
+
41
+ class _ItemWrapper:
42
+ def __init__(self, data: Any, id: Union[int, List[int]] = None):
43
+ self.data = data
44
+ self.id = id
45
+
46
+
47
class Terminate(Exception):
    """Raised internally to unwind a node's loop once its terminate flag is set."""
49
+
50
+
51
def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper:
    """Blocking `queue.get` that periodically polls `terminate_flag`.

    Raises `Terminate` as soon as the flag is observed set, and `queue.Empty`
    if `timeout` (seconds) elapses without obtaining an item. With no timeout,
    blocks until an item arrives or the flag is set.
    """
    while True:
        try:
            # Wake up at least every TERMINATE_CHECK_INTERVAL seconds to re-check the flag.
            item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL))
            if terminate_flag.is_set():
                raise Terminate()
            return item
        except Empty:
            if terminate_flag.is_set():
                raise Terminate()

            if timeout is not None:
                # NOTE(review): decrements by the fixed interval rather than the
                # actual elapsed time, so the effective timeout is approximate.
                timeout -= TERMINATE_CHECK_INTERVAL
                if timeout <= 0:
                    raise Empty()
66
+
67
+
68
def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event):
    """Blocking `queue.put` that periodically polls `terminate_flag`.

    Retries every TERMINATE_CHECK_INTERVAL seconds while the queue is full;
    raises `Terminate` as soon as the flag is observed set.
    """
    while True:
        try:
            queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL)
            if terminate_flag.is_set():
                raise Terminate()
            return
        except Full:
            if terminate_flag.is_set():
                raise Terminate()
78
+
79
class Node:
    """Base pipeline node: an input queue and an output queue plus lifecycle hooks.

    Subclasses implement `start`/`terminate`/`join`; callers feed data with
    `put` and collect results with `get`. Usable as a context manager, which
    starts the node on entry and stops it on exit.
    """
    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        self.input: Queue = Queue(maxsize=in_buffer_size)
        self.output: Queue = Queue(maxsize=out_buffer_size)
        self.in_buffer_size = in_buffer_size
        self.out_buffer_size = out_buffer_size

    @abstractmethod
    def start(self):
        """Begin processing (spawn threads/processes as needed)."""
        pass

    @abstractmethod
    def terminate(self):
        """Signal the node to stop; does not wait for completion."""
        pass

    def stop(self):
        """Terminate and then wait for the node to finish."""
        self.terminate()
        self.join()

    @abstractmethod
    def join(self):
        """Block until the node's workers have exited."""
        pass

    def put(self, data: Any, key: str = None, block: bool = True) -> None:
        """Enqueue one input item. `key` is accepted but currently unused."""
        item = _ItemWrapper(data)
        self.input.put(item, block=block)

    def get(self, key: str = None, block: bool = True) -> Any:
        """Dequeue one result, unwrapping the internal item envelope."""
        item: _ItemWrapper = self.output.get(block=block)
        return item.data

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.terminate()
        self.join()
117
+
118
+
119
class ConcurrentNode(Node):
    """Node whose loop runs in a background thread or process.

    Subclasses implement `_loop_fn`, which receives the node's input/output
    queues and a terminate flag of the matching kind (threading or
    multiprocessing Event).
    """
    job: Union[Thread, Process]

    def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        super().__init__(in_buffer_size, out_buffer_size)
        self.running_as = running_as

    @abstractmethod
    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        """Body executed in the worker thread/process."""
        pass

    def start(self):
        # NOTE(review): any `running_as` other than 'thread'/'process' leaves
        # `job`/`terminate_flag` unbound and raises UnboundLocalError below.
        if self.running_as == 'thread':
            terminate_flag = threading.Event()
            job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
        elif self.running_as == 'process':
            terminate_flag = multiprocessing.Event()
            job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
        job.start()
        self.job = job
        self.terminate_flag = terminate_flag

    def terminate(self):
        self.terminate_flag.set()

    def join(self):
        self.job.join()
146
+
147
+
148
class Worker(ConcurrentNode):
    """Concurrent node that applies `work` to each input item and emits the result.

    Item ids are propagated from input to output so downstream consumers can
    correlate results with their originating items.
    """
    def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None:
        super().__init__(running_as, in_buffer_size, out_buffer_size)

    def init(self) -> None:
        """
        Called once when the thread or process is started, to initialize any
        resources that are only held in the worker thread/process.
        """
        pass

    @abstractmethod
    def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]:
        """
        This method defines the job that the node should do for each input item.
        An item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue.
        The method is executed concurrently with other nodes.
        """
        pass

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        # Runs in the worker thread/process until the terminate flag fires.
        self.init()
        try:
            while True:
                item = _get_queue_item(input, terminate_flag)
                result = self.work(item.data)
                _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag)

        except Terminate:
            return
177
+
178
+
179
class Provider(ConcurrentNode):
    """
    A node that provides data to successive nodes. It takes no input and provides data to the output queue.
    """
    def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None:
        # in_buffer_size is 0: a provider never consumes from its input queue.
        super().__init__(running_as, 0, out_buffer_size)

    def init(self) -> None:
        """
        Called once when the thread or process is started, to initialize any
        resources that are only held in the worker thread/process.
        """
        pass

    @abstractmethod
    def provide(self) -> Generator[Any, None, None]:
        """Yield the items this provider feeds into the pipeline."""
        pass

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        # Runs in the worker thread/process; stops when provide() is exhausted
        # or the terminate flag fires mid-put.
        self.init()
        try:
            for data in self.provide():
                _put_queue_item(output, _ItemWrapper(data), terminate_flag)
        except Terminate:
            return
203
+
204
+
205
class WorkerFunction(Worker):
    """Adapter that wraps a plain callable as a `Worker` node.

    Each input item is passed to `fn` and the return value is emitted.
    """
    def __init__(self, fn: Callable, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
        # FIX: `running_as` was annotated with the bare string 'thread' (a bogus
        # annotation) and had no default; give it the proper Literal annotation
        # and a default of 'thread', matching the other node classes.
        super().__init__(running_as, in_buffer_size, out_buffer_size)
        self.fn = fn

    def work(self, *args, **kwargs):
        return self.fn(*args, **kwargs)
212
+
213
+
214
class ProviderFunction(Provider):
    """A Provider that delegates its `provide` to a plain generator function."""

    def __init__(self, fn: Callable, running_as: Literal['thread', 'process'] = 'thread', out_buffer_size: int = 1) -> None:
        # FIX: `running_as` was annotated with the bare string 'thread' (not a
        # valid type) and had no default; now Literal[...] with default 'thread',
        # consistent with the rest of the module.
        super().__init__(running_as, out_buffer_size)
        self.fn = fn  # generator function yielding the items to provide

    def provide(self):
        yield from self.fn()
222
+
223
+
224
class Link:
    # A background thread that continuously forwards items from one queue (src)
    # to another (dst); used by Graph to wire node outputs to node inputs.
    def __init__(self, src: Queue, dst: Queue):
        self.src = src
        self.dst = dst

    def _thread_fn(self):
        # Forward items until Terminate is raised by the queue helpers
        # (presumably when terminate_flag is set -- defined elsewhere in this module).
        try:
            while True:
                item = _get_queue_item(self.src, self.terminate_flag)
                _put_queue_item(self.dst, item, self.terminate_flag)
        except Terminate:
            return

    def start(self):
        self.terminate_flag = threading.Event()
        self.thread = Thread(target=self._thread_fn)
        self.thread.start()

    def terminate(self):
        # Only signals the flag; the forwarding thread exits on its next queue operation.
        self.terminate_flag.set()

    def join(self):
        self.thread.join()
247
+
248
+
249
class Graph(Node):
    """
    Graph pipeline of nodes and links.

    Nodes are registered with `add()` and wired with `link()`/`chain()`;
    `None` stands for the graph's own input/output boundary.
    """
    nodes: List[Node]
    links: List[Link]

    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
        super().__init__(in_buffer_size, out_buffer_size)
        self.nodes = []
        self.links = []

    def add(self, node: Node):
        "Register a node as part of this graph (does not wire it to anything)."
        self.nodes.append(node)

    def link(self, src: Optional[Node], dst: Optional[Node]):
        """
        Links the output of the source node to the input of the destination node.
        If the source or destination node is None, the pipeline's input or output is used.
        """
        # FIX: annotation previously claimed Union[Node, Tuple[Node, str]], but the
        # implementation only ever handles a Node or None.
        src_queue = self.input if src is None else src.output
        dst_queue = self.output if dst is None else dst.input
        self.links.append(Link(src_queue, dst_queue))

    def chain(self, nodes: Iterable[Optional[Node]]):
        """
        Link the output of each node to the input of the next node.
        """
        nodes = list(nodes)
        for i in range(len(nodes) - 1):
            self.link(nodes[i], nodes[i + 1])

    def start(self):
        for node in self.nodes:
            node.start()
        for link in self.links:
            link.start()

    def terminate(self):
        for node in self.nodes:
            node.terminate()
        for link in self.links:
            link.terminate()

    def join(self):
        for node in self.nodes:
            node.join()
        for link in self.links:
            link.join()

    def __iter__(self):
        providers = [node for node in self.nodes if isinstance(node, Provider)]
        if len(providers) == 0:
            raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.")
        with self:
            # NOTE: this loops forever; iteration only ends when the consumer stops.
            # while all(provider.job.is_alive() for provider in providers):
            while True:
                yield self.get()

    def __call__(self, data: Any) -> Any:
        """
        Submit data to the pipeline's input queue, and return the output data asynchronously.
        NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work.
        """
        # FIX: previously an unimplemented TODO that silently returned None;
        # fail loudly so callers cannot mistake the stub for a real result.
        raise NotImplementedError("Graph.__call__ is not implemented yet.")
314
+
315
+
316
class Sequential(Graph):
    """
    Pipeline of nodes in sequential order, where each node takes the output of the previous node as input.
    The order of input and output items is preserved (FIFO).
    """
    def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
        """
        Initialize the pipeline with a list of nodes to execute sequentially.
        ### Parameters:
        - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
        - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'.
        - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 1.
        - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 1.
        """
        # FIX: docstring previously claimed the buffer sizes default to 0 (unlimited),
        # contradicting the actual signature defaults of 1.
        super().__init__(in_buffer_size, out_buffer_size)
        for node in nodes:
            if isinstance(node, Node):
                pass
            elif callable(node):
                if inspect.isgeneratorfunction(node):
                    node = ProviderFunction(node, function_running_as)
                else:
                    node = WorkerFunction(node, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(node)}")
            self.add(node)
        # None at both ends wires the chain to the pipeline's own input/output.
        self.chain([None, *self.nodes, None])
343
+
344
+
345
class Parallel(Node):
    """
    A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available.
    NOTE: It is FIFO if and only if all the nested nodes are FIFO.
    """
    nodes: List[Node]

    def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'):
        super().__init__(in_buffer_size, out_buffer_size)
        self.nodes = []
        for node in nodes:
            if isinstance(node, Node):
                pass
            elif isinstance(node, Callable):
                # Bare functions are wrapped: generator functions become providers,
                # everything else becomes a worker.
                if inspect.isgeneratorfunction(node):
                    node = ProviderFunction(node, function_running_as)
                else:
                    node = WorkerFunction(node, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(node)}")
            self.nodes.append(node)
        # Records, per dispatched item, which node's output queue it will arrive
        # on, so the collector thread can re-emit results in dispatch order (FIFO).
        self.output_order = Queue()
        # Serializes (take item, record order, hand to node) so the order stored
        # in `output_order` matches the actual hand-off order.
        self.lock = threading.Lock()

    def _in_thread_fn(self, node: Node):
        # One dispatcher thread per nested node: pulls from the shared input and
        # feeds this particular node.
        try:
            while True:
                with self.lock:
                    # A better idea: first make sure its node is vacant, then get it a new item.
                    # Currently we will not be able to know which node is busy until there is at least one item already waiting in the queue of the node.
                    # This could lead to suboptimal scheduling.
                    item = _get_queue_item(self.input, self.terminate_flag)
                    self.output_order.put(node.output)
                    _put_queue_item(node.input, item, self.terminate_flag)
        except Terminate:
            return

    def _out_thread_fn(self):
        # Single collector thread: pops the next expected output queue, waits for
        # its result, and forwards it -- preserving dispatch order.
        try:
            while True:
                queue = _get_queue_item(self.output_order, self.terminate_flag)
                item = _get_queue_item(queue, self.terminate_flag)
                _put_queue_item(self.output, item, self.terminate_flag)
        except Terminate:
            return

    def start(self):
        self.terminate_flag = threading.Event()
        self.in_threads = []
        for node in self.nodes:
            thread = Thread(target=self._in_thread_fn, args=(node,))
            thread.start()
            self.in_threads.append(thread)
        thread = Thread(target=self._out_thread_fn)
        thread.start()
        self.out_thread = thread
        for node in self.nodes:
            node.start()

    def terminate(self):
        self.terminate_flag.set()
        for node in self.nodes:
            node.terminate()

    def join(self):
        for thread in self.in_threads:
            thread.join()
        self.out_thread.join()
413
+
414
+
415
class UnorderedParallel(Graph):
    """
    Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available.
    NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input.
    """
    def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
        """
        Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node.
        ### Parameters:
        - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
        - function_running_as: Whether to wrap the function as a thread or process worker. Defaults to 'thread'.
        - in_buffer_size: Maximum size of the input queue of the pipeline. Defaults to 1.
        - out_buffer_size: Maximum size of the output queue of the pipeline. Defaults to 1.
        """
        # FIX: docstring previously claimed the buffer sizes default to 0 (unlimited),
        # contradicting the actual signature defaults of 1.
        super().__init__(in_buffer_size, out_buffer_size)
        for node in nodes:
            if isinstance(node, Node):
                pass
            elif callable(node):
                if inspect.isgeneratorfunction(node):
                    node = ProviderFunction(node, function_running_as)
                else:
                    node = WorkerFunction(node, function_running_as)
            else:
                raise ValueError(f"Invalid node type: {type(node)}")
            self.add(node)
        # Wire every node directly between the pipeline's input and output;
        # whichever node is free picks up the next item.
        for node in self.nodes:
            self.chain([None, node, None])
443
+
444
+
445
class Batch(ConcurrentNode):
    """
    Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes.
    The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node,
    i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size.
    """
    def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1):
        assert batch_size > 0, "Batch size must be greater than 0."
        super().__init__('thread', in_buffer_size, out_buffer_size)
        self.batch_size = batch_size  # number of items per emitted batch
        self.patience = patience      # max seconds to wait after the first item; None = wait indefinitely

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        try:
            while True:
                batch_id, batch_data = [], []
                # Try to fill the batch
                for i in range(self.batch_size):
                    if i == 0 or self.patience is None:
                        # Block indefinitely for the first item, so an empty batch
                        # is never emitted.
                        timeout = None
                    else:
                        # Remaining patience, counted from when the first item of
                        # this batch arrived.
                        timeout = self.patience - (time.time() - earliest_time)
                        if timeout < 0:
                            break
                    try:
                        item = _get_queue_item(input, terminate_flag, timeout)
                    except Empty:
                        # Patience ran out: flush whatever has been collected so far.
                        break

                    if i == 0:
                        earliest_time = time.time()
                    batch_data.append(item.data)
                    batch_id.append(item.id)

                # The batch keeps the per-item ids so Unbatch can restore pairing.
                batch = _ItemWrapper(batch_data, batch_id)
                _put_queue_item(output, batch, terminate_flag)
        except Terminate:
            return
483
+
484
+
485
class Unbatch(ConcurrentNode):
    """
    Ungroups every batch (a list of items) into individual items and passes them to successive nodes.
    """
    def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
        super().__init__('thread', in_buffer_size, out_buffer_size)

    def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
        try:
            while True:
                batch = _get_queue_item(input, terminate_flag)
                # batch.id may be falsy (untracked items); pad with None ids then.
                for id, data in zip(batch.id or itertools.repeat(None), batch.data):
                    item = _ItemWrapper(data, id)
                    _put_queue_item(output, item, terminate_flag)
        except Terminate:
            return
501
+
502
+
503
class Buffer(Node):
    "A FIFO node that buffers items in a queue. Useful to achieve better temporal balance when its successor node has a variable processing time."
    def __init__(self, size: int):
        super().__init__(size, size)
        self.size = size
        # Input and output share a single bounded queue: items simply pass
        # through the buffer with no processing.
        self.input = self.output = Queue(maxsize=size)
moge/utils/tools.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+ import time
8
+ from pathlib import Path
9
+ from numbers import Number
10
+ from functools import wraps
11
+ import warnings
12
+ import math
13
+ import json
14
+ import os
15
+ import importlib
16
+ import importlib.util
17
+
18
+
19
def catch_exception(fn):
    """Decorator that catches any exception from `fn`, prints a short traceback, and returns None instead of propagating."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception:
            import traceback
            # FIX: end was 'r' (printed a literal letter); '\r' keeps the message on
            # one line, e.g. when running alongside a progress bar.
            print(f"Exception in {fn.__name__}", end='\r')
            # print({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})
            traceback.print_exc(chain=False)
            time.sleep(0.1)  # give the console a moment to flush before continuing
            return None
    return wrapper
32
+
33
+
34
class CallbackOnException:
    """Context manager that swallows a given exception type and invokes a callback instead.

    Any other exception propagates unchanged.
    """

    def __init__(self, callback: Callable, exception: type):
        self.exception = exception
        self.callback = callback

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        should_handle = isinstance(exc_val, self.exception)
        if should_handle:
            self.callback()
        # Returning True suppresses the matched exception.
        return should_handle
47
+
48
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
    """Yield the key paths (as tuples) of every leaf value in a nested dict."""
    for key, value in d.items():
        if isinstance(value, dict):
            # Recurse and prepend the current key to each sub-path.
            yield from ((key,) + suffix for suffix in traverse_nested_dict_keys(value))
        else:
            yield (key,)
55
+
56
+
57
def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
    """Look up a key path in a nested dict, returning `default` when any level is missing.

    FIX: the previous implementation chained `d.get(k, default)` and only stopped
    when the intermediate result was None, so a non-None `default` (or any
    non-dict intermediate value) would itself be used as the container for the
    next lookup and crash with AttributeError.
    """
    for k in keys:
        if not isinstance(d, dict) or k not in d:
            return default
        d = d[k]
    return d
63
+
64
def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
    """Assign `value` at the given key path, creating intermediate dicts as needed."""
    *parents, leaf = keys
    node = d
    for k in parents:
        node = node.setdefault(k, {})
    node[leaf] = value
68
+
69
+
70
def key_average(list_of_dicts: list) -> Dict[str, Any]:
    """
    Returns a dictionary with the average value of each key in the input list of dictionaries.

    Missing and NaN values are skipped; a key with no usable values averages to NaN.
    """
    all_paths = set()
    for entry in list_of_dicts:
        all_paths.update(traverse_nested_dict_keys(entry))
    averaged = {}
    for path in sorted(all_paths):
        values = [
            v for entry in list_of_dicts
            if (v := get_nested_dict(entry, path)) is not None and not math.isnan(v)
        ]
        mean = sum(values) / len(values) if values else float('nan')
        set_nested_dict(averaged, path, mean)
    return averaged
88
+
89
+
90
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
    """
    Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
    """
    prefix = () if parent_key is None else parent_key
    flat: Dict[Tuple[str, ...], Any] = {}
    for key, value in d.items():
        path = prefix + (key,)
        if isinstance(value, MutableMapping):
            flat.update(flatten_nested_dict(value, path))
        else:
            flat[path] = value
    return flat
104
+
105
+
106
def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
    """
    Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
    """
    result: Dict[str, Any] = {}
    for path, value in d.items():
        node = result
        for part in path[:-1]:
            node = node.setdefault(part, {})
        node[path[-1]] = value
    return result
119
+
120
+
121
def read_jsonl(file):
    """Read a JSON-lines file into a list of objects (one object per line)."""
    import json
    with open(file, 'r') as f:
        return [json.loads(line) for line in f]
126
+
127
+
128
def write_jsonl(data: List[dict], file):
    """Write a list of objects to a JSON-lines file (one object per line)."""
    import json
    lines = [json.dumps(item) + '\n' for item in data]
    with open(file, 'w') as f:
        f.writelines(lines)
133
+
134
+
135
def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
    """Build a pandas DataFrame with hierarchical (MultiIndex) columns from a list of nested dicts."""
    import pandas as pd
    flattened = [flatten_nested_dict(d) for d in data]
    frame = pd.DataFrame(flattened).sort_index(axis=1)
    # Tuple column keys become a proper MultiIndex.
    frame.columns = pd.MultiIndex.from_tuples(frame.columns)
    return frame
142
+
143
+
144
def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
    """Recursively apply string replacements to every string inside a nested list/dict structure.

    Containers are modified in place; the (possibly replaced) value is also returned.
    """
    if isinstance(d, str):
        for needle, replacement in mapping.items():
            d = d.replace(needle, replacement)
        return d
    if isinstance(d, list):
        for idx in range(len(d)):
            d[idx] = recursive_replace(d[idx], mapping)
    elif isinstance(d, dict):
        for key in d:
            d[key] = recursive_replace(d[key], mapping)
    return d
155
+
156
+
157
class timeit:
    """Wall-clock timer usable as a context manager or as a decorator.

    With `average=True`, completed timings are accumulated per `name` in the
    class-level `_history` and the running average is reported instead of the
    individual time.
    """
    _history: Dict[str, List['timeit']] = {}

    def __init__(self, name: str = None, verbose: bool = True, average: bool = False):
        self.name = name
        self.verbose = verbose
        self.start = None
        self.end = None
        self.average = average
        if average and name not in timeit._history:
            timeit._history[name] = []

    def __call__(self, func: Callable):
        import inspect
        # FIX: the nested timeit previously dropped `verbose`/`average`, so e.g.
        # timeit(name, average=True) used as a decorator silently ignored those
        # settings. Also add functools.wraps to preserve the wrapped metadata.
        name = self.name or func.__qualname__
        if inspect.iscoroutinefunction(func):
            @wraps(func)
            async def wrapper(*args, **kwargs):
                with timeit(name, verbose=self.verbose, average=self.average):
                    ret = await func(*args, **kwargs)
                return ret
            return wrapper
        else:
            @wraps(func)
            def wrapper(*args, **kwargs):
                with timeit(name, verbose=self.verbose, average=self.average):
                    ret = func(*args, **kwargs)
                return ret
            return wrapper

    def __enter__(self):
        self.start = time.time()
        return self

    @property
    def time(self) -> float:
        """Elapsed seconds of the completed timing."""
        assert self.start is not None, "Time not yet started."
        assert self.end is not None, "Time not yet ended."
        return self.end - self.start

    @property
    def average_time(self) -> float:
        """Average elapsed seconds over all recorded timings with this name."""
        assert self.average, "Average time not available."
        return sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])

    @property
    def history(self) -> List['timeit']:
        return timeit._history.get(self.name, [])

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.end = time.time()
        if self.average:
            timeit._history[self.name].append(self)
        if self.verbose:
            if self.average:
                avg = self.average_time
                print(f"{self.name or 'It'} took {avg:.6f} seconds in average.")
            else:
                print(f"{self.name or 'It'} took {self.time:.6f} seconds.")
213
+
214
+
215
def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
    """Remove the longest common prefix and the longest common suffix from every string.

    FIX: the previous implementation relied on the loop variables after the loop:
    when a loop completed without `break` (e.g. identical strings, or a
    single-element list) the common prefix/suffix length was off by one, and an
    empty first string raised NameError. Loops are now bounded by the shortest
    string, so short strings can no longer trigger an IndexError either.
    """
    if not strings:
        return []
    first = strings[0]
    min_len = min(len(s) for s in strings)

    # Longest common prefix length across all strings.
    prefix = min_len
    for i in range(min_len):
        if any(s[i] != first[i] for s in strings):
            prefix = i
            break

    # Longest common suffix length, never overlapping the stripped prefix.
    suffix = 0
    for i in range(1, min_len - prefix + 1):
        if any(s[-i] != first[-i] for s in strings):
            break
        suffix = i

    return [s[prefix:len(s) - suffix] for s in strings]
227
+
228
+
229
def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
    # Decorator factory: `multithead_execute(inputs, n)(fn)` runs fn over all
    # inputs on a thread pool, updating a tqdm progress bar per completed item.
    # Results are discarded; exceptions are printed (catch_exception) instead of
    # propagating, with wrapper frames trimmed (suppress_traceback).
    from concurrent.futures import ThreadPoolExecutor
    from contextlib import nullcontext  # NOTE(review): imported but unused
    from tqdm import tqdm

    # Reuse the caller-supplied progress bar if given, otherwise create one.
    if pbar is not None:
        pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
    else:
        pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)

    def decorator(fn: Callable):
        # `pbar` is itself used as a context manager so it is closed on exit.
        with (
            ThreadPoolExecutor(max_workers=num_workers) as executor,
            pbar
        ):
            pbar.refresh()
            @catch_exception
            @suppress_traceback
            def _fn(input):
                ret = fn(input)
                pbar.update()
                return ret
            # Executor.map submits all tasks eagerly; the (lazy) result iterator
            # is deliberately not consumed since results are not needed.
            executor.map(_fn, inputs)
            # Redundant with the `with` block, but harmless: wait for completion.
            executor.shutdown(wait=True)

    return decorator
255
+
256
+
257
def suppress_traceback(fn):
    """Decorator that trims the two innermost wrapper frames off an exception's
    traceback, so reported errors point at the caller's code rather than the
    decorating machinery."""
    @wraps(fn)
    def wrapper(*args, **kwargs):
        try:
            return fn(*args, **kwargs)
        except Exception as e:
            tb = e.__traceback__
            # FIX: guard against short tracebacks -- `tb_next` can be None (e.g.
            # when the exception originates directly in C code), which previously
            # raised AttributeError and masked the real error.
            if tb is not None and tb.tb_next is not None:
                e.__traceback__ = tb.tb_next.tb_next
            raise
    return wrapper
266
+
267
+
268
class no_warnings:
    """Filter warnings (default: ignore all), usable as a decorator or a context manager.

    Extra keyword arguments are forwarded to `warnings.simplefilter`.
    """

    def __init__(self, action: str = 'ignore', **kwargs):
        self.action = action
        self.filter_kwargs = kwargs

    def _apply_filter(self) -> None:
        # Install this instance's warning filter in the active catch_warnings scope.
        warnings.simplefilter(self.action, **self.filter_kwargs)

    def __call__(self, fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            with warnings.catch_warnings():
                self._apply_filter()
                return fn(*args, **kwargs)
        return wrapped

    def __enter__(self):
        self.warnings_manager = warnings.catch_warnings()
        self.warnings_manager.__enter__()
        self._apply_filter()

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.warnings_manager.__exit__(exc_type, exc_val, exc_tb)
288
+
289
+
290
def import_file_as_module(file_path: Union[str, os.PathLike], module_name: str):
    """Load a Python source file as a module object registered under `module_name`.

    The module is executed immediately; it is NOT inserted into sys.modules.
    """
    spec = importlib.util.spec_from_file_location(module_name, file_path)
    loaded = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(loaded)
    return loaded
moge/utils/vis.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+
8
+ import numpy as np
9
+ import matplotlib
10
+
11
+
12
def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Colorize a depth map through its disparity (1/depth) with a matplotlib colormap; returns uint8 RGB."""
    valid = depth > 0 if mask is None else (depth > 0) & mask
    depth = np.where(valid, depth, np.nan)  # invalid pixels become NaN
    disp = 1 / depth
    if normalize:
        # Robust normalization: stretch between the 0.1% and 99% disparity quantiles.
        lo, hi = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.99)
        disp = (disp - lo) / (hi - lo)
    rgb = matplotlib.colormaps[cmap](1.0 - disp)[..., :3]
    rgb = np.nan_to_num(rgb, copy=False)  # NaNs (invalid pixels) -> 0 (black)
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
24
+
25
+
26
def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
    """Colorize a depth map directly (affinely normalized to its quantile range); returns uint8 RGB."""
    if mask is not None:
        depth = np.where(mask, depth, np.nan)

    lo, hi = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
    normalized = (depth - lo) / (hi - lo)
    rgb = np.nan_to_num(matplotlib.colormaps[cmap](normalized)[..., :3], copy=False)
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
35
+
36
+
37
def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
    """Colorize a disparity map with a matplotlib colormap; returns uint8 RGB."""
    if mask is not None:
        disparity = np.where(mask, disparity, np.nan)

    if normalize:
        # Robust normalization between the 0.1% and 99.9% quantiles.
        lo, hi = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
        disparity = (disparity - lo) / (hi - lo)
    rgb = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity)[..., :3], copy=False)
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
47
+
48
+
49
def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray:
    """Map integer segment ids onto a repeating 20-color cycle of the colormap; returns uint8 RGB."""
    rgb = matplotlib.colormaps[cmap]((segmentation % 20) / 20)[..., :3]
    return np.ascontiguousarray((rgb.clip(0, 1) * 255).astype(np.uint8))
53
+
54
+
55
def colorize_normal(normal: np.ndarray, mask: np.ndarray = None) -> np.ndarray:
    """Encode normal vectors as RGB: components remapped from [-1, 1] to [0, 1], with the
    y and z components negated; returns uint8. Masked-out pixels map to mid-gray zero vectors."""
    if mask is not None:
        normal = np.where(mask[..., None], normal, 0)
    remapped = normal * [0.5, -0.5, -0.5] + 0.5
    return (remapped.clip(0, 1) * 255).astype(np.uint8)
61
+
62
+
63
def colorize_error_map(error_map: np.ndarray, mask: np.ndarray = None, cmap: str = 'plasma', value_range: Tuple[float, float] = None):
    """Colorize a scalar error map with a matplotlib colormap; masked-out pixels become black; returns uint8 RGB."""
    if value_range is None:
        vmin, vmax = np.nanmin(error_map), np.nanmax(error_map)
    else:
        vmin, vmax = value_range
    colormap = matplotlib.colormaps[cmap]
    colored = colormap(((error_map - vmin) / (vmax - vmin)).clip(0, 1))[..., :3]
    if mask is not None:
        colored = np.where(mask[..., None], colored, 0)
    return np.ascontiguousarray((colored.clip(0, 1) * 255).astype(np.uint8))
moge/utils/webfile.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ import requests
7
+ from typing import *
8
+
9
+ __all__ = ["WebFile"]
10
+
11
+
12
class WebFile:
    """
    Random-access, file-like reader over an HTTP(S) URL using Range requests.

    Supports `seekable`/`seek`/`tell`/`read` so it can back consumers that
    expect a seekable binary file (e.g. zipfile).
    """
    def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None):
        self.url = url
        self.session = session or requests.Session()
        self.session.headers.update(headers or {})
        self._offset = 0  # current read position in bytes
        # Passing `size` skips the extra size request (used when sharing a WebFile).
        self.size = size if size is not None else self._fetch_size()

    def _fetch_size(self):
        # Streamed GET: the body is never downloaded, only headers are read.
        with self.session.get(self.url, stream=True) as response:
            response.raise_for_status()
            content_length = response.headers.get("Content-Length")
            if content_length is None:
                raise ValueError("Missing Content-Length in header")
            return int(content_length)

    def _fetch_data(self, offset: int, n: int) -> bytes:
        # HTTP Range end positions are inclusive, so the last valid index is
        # size - 1. FIX: previously clamped to self.size, one byte past EOF.
        headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size - 1)}"}
        response = self.session.get(self.url, headers=headers)
        response.raise_for_status()
        return response.content

    def seekable(self) -> bool:
        return True

    def tell(self) -> int:
        return self._offset

    def available(self) -> int:
        # Bytes remaining between the current offset and end of file.
        return self.size - self._offset

    def seek(self, offset: int, whence: int = 0) -> None:
        if whence == 0:        # absolute
            new_offset = offset
        elif whence == 1:      # relative to current position
            new_offset = self._offset + offset
        elif whence == 2:      # relative to end of file
            new_offset = self.size + offset
        else:
            raise ValueError("Invalid value for whence")

        # Clamp into [0, size] rather than raising, mirroring lenient file semantics.
        self._offset = max(0, min(new_offset, self.size))

    def read(self, n: Optional[int] = None) -> bytes:
        """Read up to `n` bytes from the current offset (all remaining bytes when n is None or negative)."""
        if n is None or n < 0:
            n = self.available()
        else:
            n = min(n, self.available())

        if n == 0:
            return b''

        data = self._fetch_data(self._offset, n)
        self._offset += len(data)

        return data

    def close(self) -> None:
        # Nothing to release: the (possibly shared) session is deliberately left open.
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        pass
77
+
78
+
moge/utils/webzipfile.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copied from the MoGe project:
2
+ # https://github.com/microsoft/MoGe
3
+ # Original license: MIT
4
+ # Copyright (c) the MoGe authors
5
+
6
+ from typing import *
7
+ import io
8
+ import os
9
+ from zipfile import (
10
+ ZipInfo, BadZipFile, ZipFile, ZipExtFile,
11
+ sizeFileHeader, structFileHeader, stringFileHeader,
12
+ _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS,
13
+ _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED
14
+ )
15
+ import struct
16
+ from requests import Session
17
+
18
+ from .webfile import WebFile
19
+
20
+
21
class _SharedWebFile(WebFile):
    # Lightweight per-reader view of an already-opened WebFile: shares the
    # parent's URL, session and size (avoiding a second Content-Length request)
    # while keeping an independent read offset, so multiple ZIP members can be
    # read concurrently.
    def __init__(self, webfile: WebFile, pos: int):
        super().__init__(webfile.url, webfile.session, size=webfile.size)
        self.seek(pos)
25
+
26
+
27
class WebZipFile(ZipFile):
    "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads."
    # NOTE(review): `open` below closely mirrors CPython's ZipFile.open, but hands each
    # reader its own _SharedWebFile instead of a shared, lock-guarded file object.
    # It relies on zipfile's private helpers/constants (imported at module top),
    # which may change between Python versions -- verify against the target runtime.
    def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None):
        """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x',
        or append 'a'."""
        webf = WebFile(url, session=session, headers=headers)
        super().__init__(webf, mode='r')

    def open(self, name, mode="r", pwd=None, *, force_zip64=False):
        """Return file-like object for 'name'.

        name is a string for the file name within the ZIP file, or a ZipInfo
        object.

        mode should be 'r' to read a file already in the ZIP file, or 'w' to
        write to a file newly added to the archive.

        pwd is the password to decrypt files (only used for reading).

        When writing, if the file size is not known in advance but may exceed
        2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large
        files. If the size is known in advance, it is best to pass a ZipInfo
        instance for name, with zinfo.file_size set.
        """
        if mode not in {"r", "w"}:
            raise ValueError('open() requires mode "r" or "w"')
        if pwd and (mode == "w"):
            raise ValueError("pwd is only supported for reading files")
        if not self.fp:
            raise ValueError(
                "Attempt to use ZIP archive that was already closed")

        assert mode == "r", "Only read mode is supported for now"

        # Make sure we have an info object
        if isinstance(name, ZipInfo):
            # 'name' is already an info object
            zinfo = name
        elif mode == 'w':
            zinfo = ZipInfo(name)
            zinfo.compress_type = self.compression
            zinfo._compresslevel = self.compresslevel
        else:
            # Get info object for name
            zinfo = self.getinfo(name)

        if mode == 'w':
            return self._open_to_write(zinfo, force_zip64=force_zip64)

        if self._writing:
            raise ValueError("Can't read from the ZIP file while there "
                    "is an open writing handle on it. "
                    "Close the writing handle before trying to read.")

        # Open for reading:
        self._fileRefCnt += 1
        # Each reader gets its own WebFile positioned at the member's local
        # header, so no shared-file lock is needed (unlike stdlib ZipFile).
        zef_file = _SharedWebFile(self.fp, zinfo.header_offset)

        try:
            # Skip the file header:
            fheader = zef_file.read(sizeFileHeader)
            if len(fheader) != sizeFileHeader:
                raise BadZipFile("Truncated file header")
            fheader = struct.unpack(structFileHeader, fheader)
            if fheader[_FH_SIGNATURE] != stringFileHeader:
                raise BadZipFile("Bad magic number for file header")

            fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
            if fheader[_FH_EXTRA_FIELD_LENGTH]:
                zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1)

            if zinfo.flag_bits & _MASK_COMPRESSED_PATCH:
                # Zip 2.7: compressed patched data
                raise NotImplementedError("compressed patched data (flag bit 5)")

            if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION:
                # strong encryption
                raise NotImplementedError("strong encryption (flag bit 6)")

            if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME:
                # UTF-8 filename
                fname_str = fname.decode("utf-8")
            else:
                fname_str = fname.decode(self.metadata_encoding or "cp437")

            if fname_str != zinfo.orig_filename:
                raise BadZipFile(
                    'File name in directory %r and header %r differ.'
                    % (zinfo.orig_filename, fname))

            # check for encrypted flag & handle password
            is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED
            if is_encrypted:
                if not pwd:
                    pwd = self.pwd
                if pwd and not isinstance(pwd, bytes):
                    raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__)
                if not pwd:
                    raise RuntimeError("File %r is encrypted, password "
                                       "required for extraction" % name)
            else:
                pwd = None

            return ZipExtFile(zef_file, mode, zinfo, pwd, True)
        except:
            # Ensure the per-reader WebFile is released if header parsing fails.
            zef_file.close()
            raise