Spaces:
Runtime error
Runtime error
Commit
Β·
201e424
1
Parent(s):
735bc3b
add flash3d model and unidepth code
Browse files- flash3d/networks/depth_decoder.py +81 -0
- flash3d/networks/gaussian_decoder.py +196 -0
- flash3d/networks/gaussian_predictor.py +293 -0
- flash3d/networks/layers.py +295 -0
- flash3d/networks/resnet_encoder.py +115 -0
- flash3d/networks/unidepth.py +577 -0
- flash3d/networks/unidepth_extension.py +205 -0
- flash3d/unidepth/layers/__init__.py +21 -0
- flash3d/unidepth/layers/activation.py +15 -0
- flash3d/unidepth/layers/attention.py +308 -0
- flash3d/unidepth/layers/convnext.py +44 -0
- flash3d/unidepth/layers/drop_path.py +25 -0
- flash3d/unidepth/layers/layer_scale.py +17 -0
- flash3d/unidepth/layers/mlp.py +34 -0
- flash3d/unidepth/layers/nystrom_attention.py +74 -0
- flash3d/unidepth/layers/positional_encoding.py +228 -0
- flash3d/unidepth/layers/upsample.py +69 -0
- flash3d/unidepth/models/__init__.py +5 -0
- flash3d/unidepth/models/backbones/__init__.py +9 -0
- flash3d/unidepth/models/backbones/convnext.py +590 -0
- flash3d/unidepth/models/backbones/convnext2.py +288 -0
- flash3d/unidepth/models/backbones/dinov2.py +552 -0
- flash3d/unidepth/models/backbones/metadinov2/__init__.py +12 -0
- flash3d/unidepth/models/backbones/metadinov2/attention.py +85 -0
- flash3d/unidepth/models/backbones/metadinov2/block.py +284 -0
- flash3d/unidepth/models/backbones/metadinov2/dino_head.py +68 -0
- flash3d/unidepth/models/backbones/metadinov2/drop_path.py +37 -0
- flash3d/unidepth/models/backbones/metadinov2/layer_scale.py +28 -0
- flash3d/unidepth/models/backbones/metadinov2/mlp.py +41 -0
- flash3d/unidepth/models/backbones/metadinov2/patch_embed.py +101 -0
- flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py +63 -0
- flash3d/unidepth/models/encoder.py +184 -0
- flash3d/unidepth/models/unidepthv1/__init__.py +5 -0
- flash3d/unidepth/models/unidepthv1/decoder.py +542 -0
- flash3d/unidepth/models/unidepthv1/unidepthv1.py +329 -0
- flash3d/unidepth/ops/__init__.py +9 -0
- flash3d/unidepth/ops/losses.py +429 -0
- flash3d/unidepth/ops/scheduler.py +70 -0
- flash3d/unidepth/utils/__init__.py +35 -0
- flash3d/unidepth/utils/constants.py +21 -0
- flash3d/unidepth/utils/distributed.py +179 -0
- flash3d/unidepth/utils/ema_torch.py +342 -0
- flash3d/unidepth/utils/evaluation_depth.py +173 -0
- flash3d/unidepth/utils/geometric.py +248 -0
- flash3d/unidepth/utils/misc.py +403 -0
- flash3d/unidepth/utils/positional_embedding.py +274 -0
- flash3d/unidepth/utils/sht.py +1637 -0
- flash3d/unidepth/utils/visualization.py +201 -0
- flash3d/util/vis3d.py +135 -0
flash3d/networks/depth_decoder.py
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
| 4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
| 5 |
+
# available in the LICENSE file.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
from collections import OrderedDict
|
| 12 |
+
from networks.layers import upsample, ConvBlock, Conv3x3
|
| 13 |
+
|
| 14 |
+
from einops import rearrange
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class DepthDecoder(nn.Module):
|
| 18 |
+
def __init__(self, cfg, num_ch_enc, num_output_channels=1, use_skips=True):
|
| 19 |
+
super(DepthDecoder, self).__init__()
|
| 20 |
+
|
| 21 |
+
self.cfg = cfg
|
| 22 |
+
depth_num = cfg.model.gaussians_per_pixel - 1 if "unidepth" in cfg.model.name else cfg.model.gaussians_per_pixel
|
| 23 |
+
self.num_output_channels = num_output_channels * depth_num
|
| 24 |
+
self.use_skips = use_skips
|
| 25 |
+
self.upsample_mode = 'nearest'
|
| 26 |
+
self.scales = cfg.model.scales
|
| 27 |
+
|
| 28 |
+
self.num_ch_enc = num_ch_enc
|
| 29 |
+
self.num_ch_dec = np.array([16, 32, 64, 128, 256])
|
| 30 |
+
|
| 31 |
+
# decoder
|
| 32 |
+
self.convs = OrderedDict()
|
| 33 |
+
for i in range(4, -1, -1):
|
| 34 |
+
# upconv_0
|
| 35 |
+
num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
|
| 36 |
+
num_ch_out = self.num_ch_dec[i]
|
| 37 |
+
self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
|
| 38 |
+
|
| 39 |
+
# upconv_1
|
| 40 |
+
num_ch_in = self.num_ch_dec[i]
|
| 41 |
+
if self.use_skips and i > 0:
|
| 42 |
+
num_ch_in += self.num_ch_enc[i - 1]
|
| 43 |
+
num_ch_out = self.num_ch_dec[i]
|
| 44 |
+
self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
|
| 45 |
+
|
| 46 |
+
for s in self.scales:
|
| 47 |
+
out = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
|
| 48 |
+
self.convs[("dispconv", s)] = out
|
| 49 |
+
nn.init.xavier_uniform_(out.conv.weight, cfg.model.depth_scale)
|
| 50 |
+
nn.init.constant_(out.conv.bias, cfg.model.depth_bias)
|
| 51 |
+
|
| 52 |
+
self.decoder = nn.ModuleList(list(self.convs.values()))
|
| 53 |
+
if cfg.model.depth_type in ["disp", "disp_inc"]:
|
| 54 |
+
self.activate = nn.Sigmoid()
|
| 55 |
+
elif cfg.model.depth_type == "depth":
|
| 56 |
+
self.activate = nn.Softplus()
|
| 57 |
+
elif cfg.model.depth_type == "depth_inc":
|
| 58 |
+
self.activate = torch.exp
|
| 59 |
+
|
| 60 |
+
def forward(self, input_features):
|
| 61 |
+
outputs = {}
|
| 62 |
+
x = input_features[-1]
|
| 63 |
+
for i in range(4, -1, -1):
|
| 64 |
+
x = self.convs[("upconv", i, 0)](x)
|
| 65 |
+
x = [upsample(x)]
|
| 66 |
+
if self.use_skips and i > 0:
|
| 67 |
+
x += [input_features[i - 1]]
|
| 68 |
+
x = torch.cat(x, 1)
|
| 69 |
+
x = self.convs[("upconv", i, 1)](x)
|
| 70 |
+
if i in self.scales:
|
| 71 |
+
depth_num = self.cfg.model.gaussians_per_pixel - 1 if "unidepth" in self.cfg.model.name else self.cfg.model.gaussians_per_pixel
|
| 72 |
+
if self.cfg.model.depth_type == "depth_inc":
|
| 73 |
+
outputs[("depth", i)] = rearrange(self.activate(torch.clamp(self.convs[("dispconv", i)](x), min=-10.0, max=6.0)),
|
| 74 |
+
'b (n c) ...-> (b n) c ...', n = depth_num)
|
| 75 |
+
elif self.cfg.model.depth_type in ["disp", "disp_inc"]:
|
| 76 |
+
outputs[("disp", i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)),
|
| 77 |
+
'b (n c) ...-> (b n) c ...', n = depth_num)
|
| 78 |
+
else:
|
| 79 |
+
outputs[(self.cfg.model.depth_type, i)] = rearrange(self.activate(self.convs[("dispconv", i)](x)),
|
| 80 |
+
'b (n c) ...-> (b n) c ...', n = depth_num)
|
| 81 |
+
return outputs
|
flash3d/networks/gaussian_decoder.py
ADDED
|
@@ -0,0 +1,196 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
import numpy as np
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def upsample(x):
|
| 9 |
+
"""Upsample input tensor by a factor of 2
|
| 10 |
+
"""
|
| 11 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class Conv3x3(nn.Module):
|
| 15 |
+
"""Layer to pad and convolve input
|
| 16 |
+
"""
|
| 17 |
+
def __init__(self, in_channels, out_channels, use_refl=True):
|
| 18 |
+
super(Conv3x3, self).__init__()
|
| 19 |
+
|
| 20 |
+
if use_refl:
|
| 21 |
+
self.pad = nn.ReflectionPad2d(1)
|
| 22 |
+
else:
|
| 23 |
+
self.pad = nn.ZeroPad2d(1)
|
| 24 |
+
self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
|
| 25 |
+
|
| 26 |
+
def forward(self, x):
|
| 27 |
+
out = self.pad(x)
|
| 28 |
+
out = self.conv(out)
|
| 29 |
+
return out
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class ConvBlock(nn.Module):
|
| 33 |
+
"""Layer to perform a convolution followed by ELU
|
| 34 |
+
"""
|
| 35 |
+
def __init__(self, in_channels, out_channels):
|
| 36 |
+
super(ConvBlock, self).__init__()
|
| 37 |
+
|
| 38 |
+
self.conv = Conv3x3(in_channels, out_channels)
|
| 39 |
+
self.nonlin = nn.ELU(inplace=True)
|
| 40 |
+
|
| 41 |
+
def forward(self, x):
|
| 42 |
+
out = self.conv(x)
|
| 43 |
+
out = self.nonlin(out)
|
| 44 |
+
return out
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_splits_and_inits(cfg):
|
| 48 |
+
split_dimensions = []
|
| 49 |
+
scale_inits = []
|
| 50 |
+
bias_inits = []
|
| 51 |
+
|
| 52 |
+
for g_idx in range(cfg.model.gaussians_per_pixel):
|
| 53 |
+
if cfg.model.predict_offset:
|
| 54 |
+
split_dimensions += [3]
|
| 55 |
+
scale_inits += [cfg.model.xyz_scale]
|
| 56 |
+
bias_inits += [cfg.model.xyz_bias]
|
| 57 |
+
|
| 58 |
+
split_dimensions += [1, 3, 4, 3]
|
| 59 |
+
scale_inits += [cfg.model.opacity_scale,
|
| 60 |
+
cfg.model.scale_scale,
|
| 61 |
+
1.0,
|
| 62 |
+
5.0]
|
| 63 |
+
bias_inits += [cfg.model.opacity_bias,
|
| 64 |
+
np.log(cfg.model.scale_bias),
|
| 65 |
+
0.0,
|
| 66 |
+
0.0]
|
| 67 |
+
|
| 68 |
+
if cfg.model.max_sh_degree != 0:
|
| 69 |
+
sh_num = (cfg.model.max_sh_degree + 1) ** 2 - 1
|
| 70 |
+
sh_num_rgb = sh_num * 3
|
| 71 |
+
split_dimensions.append(sh_num_rgb)
|
| 72 |
+
scale_inits.append(cfg.model.sh_scale)
|
| 73 |
+
bias_inits.append(0.0)
|
| 74 |
+
if not cfg.model.one_gauss_decoder:
|
| 75 |
+
break
|
| 76 |
+
|
| 77 |
+
return split_dimensions, scale_inits, bias_inits,
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class GaussianDecoder(nn.Module):
|
| 81 |
+
def __init__(self, cfg, num_ch_enc, use_skips=True):
|
| 82 |
+
super(GaussianDecoder, self).__init__()
|
| 83 |
+
|
| 84 |
+
self.cfg = cfg
|
| 85 |
+
self.use_skips = use_skips
|
| 86 |
+
self.upsample_mode = 'nearest'
|
| 87 |
+
|
| 88 |
+
self.num_ch_enc = num_ch_enc
|
| 89 |
+
self.num_ch_dec = np.array(cfg.model.num_ch_dec)
|
| 90 |
+
|
| 91 |
+
split_dimensions, scale, bias = get_splits_and_inits(cfg)
|
| 92 |
+
|
| 93 |
+
# [offset], opacity, scaling, rotation, feat_dc
|
| 94 |
+
assert not cfg.model.unified_decoder
|
| 95 |
+
|
| 96 |
+
self.split_dimensions = split_dimensions
|
| 97 |
+
|
| 98 |
+
self.num_output_channels = sum(self.split_dimensions)
|
| 99 |
+
|
| 100 |
+
# decoder
|
| 101 |
+
self.convs = OrderedDict()
|
| 102 |
+
for i in range(4, -1, -1):
|
| 103 |
+
# upconv_0
|
| 104 |
+
num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
|
| 105 |
+
num_ch_out = self.num_ch_dec[i]
|
| 106 |
+
self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
|
| 107 |
+
|
| 108 |
+
# upconv_1
|
| 109 |
+
num_ch_in = self.num_ch_dec[i]
|
| 110 |
+
if self.use_skips and i > 0:
|
| 111 |
+
num_ch_in += self.num_ch_enc[i - 1]
|
| 112 |
+
num_ch_out = self.num_ch_dec[i]
|
| 113 |
+
self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)
|
| 114 |
+
|
| 115 |
+
self.out = nn.Conv2d(self.num_ch_dec[0], self.num_output_channels, 1)
|
| 116 |
+
|
| 117 |
+
out_channels = self.split_dimensions
|
| 118 |
+
start_channels = 0
|
| 119 |
+
for out_channel, b, s in zip(out_channels, bias, scale):
|
| 120 |
+
nn.init.xavier_uniform_(
|
| 121 |
+
self.out.weight[start_channels:start_channels+out_channel,
|
| 122 |
+
:, :, :], s)
|
| 123 |
+
nn.init.constant_(
|
| 124 |
+
self.out.bias[start_channels:start_channels+out_channel], b)
|
| 125 |
+
start_channels += out_channel
|
| 126 |
+
|
| 127 |
+
self.decoder = nn.ModuleList(list(self.convs.values()))
|
| 128 |
+
|
| 129 |
+
self.scaling_activation = torch.exp
|
| 130 |
+
self.opacity_activation = torch.sigmoid
|
| 131 |
+
self.rotation_activation = torch.nn.functional.normalize
|
| 132 |
+
self.scaling_lambda = cfg.model.scale_lambda
|
| 133 |
+
self.sigmoid = nn.Sigmoid()
|
| 134 |
+
|
| 135 |
+
def forward(self, input_features):
|
| 136 |
+
self.outputs = {}
|
| 137 |
+
|
| 138 |
+
# decoder
|
| 139 |
+
x = input_features[-1]
|
| 140 |
+
for i in range(4, -1, -1):
|
| 141 |
+
x = self.convs[("upconv", i, 0)](x)
|
| 142 |
+
x = [upsample(x)]
|
| 143 |
+
if self.use_skips and i > 0:
|
| 144 |
+
x += [input_features[i - 1]]
|
| 145 |
+
x = torch.cat(x, 1)
|
| 146 |
+
x = self.convs[("upconv", i, 1)](x)
|
| 147 |
+
|
| 148 |
+
x = self.out(x)
|
| 149 |
+
|
| 150 |
+
split_network_outputs = x.split(self.split_dimensions, dim=1)
|
| 151 |
+
|
| 152 |
+
offset_list = []
|
| 153 |
+
opacity_list = []
|
| 154 |
+
scaling_list = []
|
| 155 |
+
rotation_list = []
|
| 156 |
+
feat_dc_list = []
|
| 157 |
+
feat_rest_list = []
|
| 158 |
+
|
| 159 |
+
assert not self.cfg.model.unified_decoder
|
| 160 |
+
|
| 161 |
+
for i in range(self.cfg.model.gaussians_per_pixel):
|
| 162 |
+
assert self.cfg.model.max_sh_degree > 0
|
| 163 |
+
if self.cfg.model.predict_offset:
|
| 164 |
+
offset_s, opacity_s, scaling_s, \
|
| 165 |
+
rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*6:(i+1)*6]
|
| 166 |
+
offset_list.append(offset_s[:, None, ...])
|
| 167 |
+
else:
|
| 168 |
+
opacity_s, scaling_s, rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*5:(i+1)*5]
|
| 169 |
+
opacity_list.append(opacity_s[:, None, ...])
|
| 170 |
+
scaling_list.append(scaling_s[:, None, ...])
|
| 171 |
+
rotation_list.append(rotation_s[:, None, ...])
|
| 172 |
+
feat_dc_list.append(feat_dc_s[:, None, ...])
|
| 173 |
+
feat_rest_list.append(features_rest_s[:, None, ...])
|
| 174 |
+
if not self.cfg.model.one_gauss_decoder:
|
| 175 |
+
break
|
| 176 |
+
|
| 177 |
+
# squeezing will remove dimension if there is only one gaussian per pixel
|
| 178 |
+
opacity = torch.cat(opacity_list, dim=1).squeeze(1)
|
| 179 |
+
scaling = torch.cat(scaling_list, dim=1).squeeze(1)
|
| 180 |
+
rotation = torch.cat(rotation_list, dim=1).squeeze(1)
|
| 181 |
+
feat_dc = torch.cat(feat_dc_list, dim=1).squeeze(1)
|
| 182 |
+
features_rest = torch.cat(feat_rest_list, dim=1).squeeze(1)
|
| 183 |
+
|
| 184 |
+
out = {
|
| 185 |
+
("gauss_opacity", 0): self.opacity_activation(opacity),
|
| 186 |
+
("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda,
|
| 187 |
+
("gauss_rotation", 0): self.rotation_activation(rotation),
|
| 188 |
+
("gauss_features_dc", 0): feat_dc,
|
| 189 |
+
("gauss_features_rest", 0): features_rest
|
| 190 |
+
}
|
| 191 |
+
|
| 192 |
+
if self.cfg.model.predict_offset:
|
| 193 |
+
offset = torch.cat(offset_list, dim=1).squeeze(1)
|
| 194 |
+
out[("gauss_offset", 0)] = offset
|
| 195 |
+
return out
|
| 196 |
+
|
flash3d/networks/gaussian_predictor.py
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
import logging
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
|
| 8 |
+
from networks.layers import BackprojectDepth, disp_to_depth
|
| 9 |
+
from networks.resnet_encoder import ResnetEncoder
|
| 10 |
+
from networks.depth_decoder import DepthDecoder
|
| 11 |
+
from networks.gaussian_decoder import GaussianDecoder
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def default_param_group(model):
|
| 15 |
+
return [{'params': model.parameters()}]
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def to_device(inputs, device):
|
| 19 |
+
for key, ipt in inputs.items():
|
| 20 |
+
if isinstance(ipt, torch.Tensor):
|
| 21 |
+
inputs[key] = ipt.to(device)
|
| 22 |
+
return inputs
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class GaussianPredictor(nn.Module):
|
| 26 |
+
def __init__(self, cfg):
|
| 27 |
+
super().__init__()
|
| 28 |
+
self.cfg = cfg
|
| 29 |
+
|
| 30 |
+
# checking height and width are multiples of 32
|
| 31 |
+
# assert cfg.dataset.width % 32 == 0, "'width' must be a multiple of 32"
|
| 32 |
+
|
| 33 |
+
models = {}
|
| 34 |
+
self.parameters_to_train = []
|
| 35 |
+
|
| 36 |
+
self.num_scales = len(cfg.model.scales)
|
| 37 |
+
|
| 38 |
+
assert cfg.model.frame_ids[0] == 0, "frame_ids must start with 0"
|
| 39 |
+
|
| 40 |
+
if cfg.model.use_stereo:
|
| 41 |
+
cfg.model.frame_ids.append("s")
|
| 42 |
+
|
| 43 |
+
model_name = cfg.model.name
|
| 44 |
+
if model_name == "resnet":
|
| 45 |
+
models["encoder"] = ResnetEncoder(
|
| 46 |
+
cfg.model.num_layers,
|
| 47 |
+
cfg.model.weights_init == "pretrained",
|
| 48 |
+
cfg.model.resnet_bn_order
|
| 49 |
+
)
|
| 50 |
+
self.parameters_to_train += default_param_group(models["encoder"])
|
| 51 |
+
if not cfg.model.unified_decoder:
|
| 52 |
+
models["depth"] = DepthDecoder(
|
| 53 |
+
cfg, models["encoder"].num_ch_enc)
|
| 54 |
+
self.parameters_to_train += default_param_group(models["depth"])
|
| 55 |
+
if cfg.model.gaussian_rendering:
|
| 56 |
+
for i in range(cfg.model.gaussians_per_pixel):
|
| 57 |
+
gauss_decoder = GaussianDecoder(
|
| 58 |
+
cfg, models["encoder"].num_ch_enc,
|
| 59 |
+
)
|
| 60 |
+
self.parameters_to_train += default_param_group(gauss_decoder)
|
| 61 |
+
models["gauss_decoder_"+str(i)] = gauss_decoder
|
| 62 |
+
elif model_name == "unidepth":
|
| 63 |
+
from networks.unidepth import UniDepthSplatter
|
| 64 |
+
models["unidepth"] = UniDepthSplatter(cfg)
|
| 65 |
+
self.parameters_to_train += models["unidepth"].get_parameter_groups()
|
| 66 |
+
elif model_name in ["unidepth_unprojector_vit", "unidepth_unprojector_cnvnxtl"]:
|
| 67 |
+
from networks.unidepth import UniDepthUnprojector
|
| 68 |
+
models["unidepth"] = UniDepthUnprojector(cfg)
|
| 69 |
+
self.parameters_to_train += models["unidepth"].get_parameter_groups()
|
| 70 |
+
elif model_name in ["unidepth_extension_vit", "unidepth_extension_cnvnxtl"]:
|
| 71 |
+
from networks.unidepth_extension import UniDepthExtended
|
| 72 |
+
models["unidepth_extended"] = UniDepthExtended(cfg)
|
| 73 |
+
self.parameters_to_train += models["unidepth_extended"].get_parameter_groups()
|
| 74 |
+
|
| 75 |
+
self.models = nn.ModuleDict(models)
|
| 76 |
+
|
| 77 |
+
backproject_depth = {}
|
| 78 |
+
H = cfg.dataset.height
|
| 79 |
+
W = cfg.dataset.width
|
| 80 |
+
for scale in cfg.model.scales:
|
| 81 |
+
h = H // (2 ** scale)
|
| 82 |
+
w = W // (2 ** scale)
|
| 83 |
+
if cfg.model.shift_rays_half_pixel == "zero":
|
| 84 |
+
shift_rays_half_pixel = 0
|
| 85 |
+
elif cfg.model.shift_rays_half_pixel == "forward":
|
| 86 |
+
shift_rays_half_pixel = 0.5
|
| 87 |
+
elif cfg.model.shift_rays_half_pixel == "backward":
|
| 88 |
+
shift_rays_half_pixel = -0.5
|
| 89 |
+
else:
|
| 90 |
+
raise NotImplementedError
|
| 91 |
+
backproject_depth[str(scale)] = BackprojectDepth(
|
| 92 |
+
cfg.optimiser.batch_size * cfg.model.gaussians_per_pixel,
|
| 93 |
+
# backprojection can be different if padding was used
|
| 94 |
+
h + 2 * self.cfg.dataset.pad_border_aug,
|
| 95 |
+
w + 2 * self.cfg.dataset.pad_border_aug,
|
| 96 |
+
shift_rays_half_pixel=shift_rays_half_pixel
|
| 97 |
+
)
|
| 98 |
+
self.backproject_depth = nn.ModuleDict(backproject_depth)
|
| 99 |
+
|
| 100 |
+
def set_train(self):
|
| 101 |
+
"""Convert all models to training mode
|
| 102 |
+
"""
|
| 103 |
+
for m in self.models.values():
|
| 104 |
+
m.train()
|
| 105 |
+
self._is_train = True
|
| 106 |
+
|
| 107 |
+
def set_eval(self):
|
| 108 |
+
"""Convert all models to testing/evaluation mode
|
| 109 |
+
"""
|
| 110 |
+
for m in self.models.values():
|
| 111 |
+
m.eval()
|
| 112 |
+
self._is_train = False
|
| 113 |
+
|
| 114 |
+
def is_train(self):
|
| 115 |
+
return self._is_train
|
| 116 |
+
|
| 117 |
+
def forward(self, inputs):
|
| 118 |
+
cfg = self.cfg
|
| 119 |
+
B = cfg.optimiser.batch_size
|
| 120 |
+
|
| 121 |
+
if cfg.model.name == "resnet":
|
| 122 |
+
do_flip = self.is_train() and \
|
| 123 |
+
cfg.train.lazy_flip_augmentation and \
|
| 124 |
+
(torch.rand(1) > .5).item()
|
| 125 |
+
# Otherwise, we only feed the image with frame_id 0 through the depth encoder
|
| 126 |
+
input_img = inputs["color_aug", 0, 0]
|
| 127 |
+
if do_flip:
|
| 128 |
+
input_img = torch.flip(input_img, dims=(-1, ))
|
| 129 |
+
features = self.models["encoder"](input_img)
|
| 130 |
+
if not cfg.model.unified_decoder:
|
| 131 |
+
outputs = self.models["depth"](features)
|
| 132 |
+
else:
|
| 133 |
+
outputs = dict()
|
| 134 |
+
|
| 135 |
+
if self.cfg.model.gaussian_rendering:
|
| 136 |
+
# gauss_feats = self.models["gauss_encoder"](inputs["color_aug", 0, 0])
|
| 137 |
+
input_f_id = 0
|
| 138 |
+
gauss_feats = features
|
| 139 |
+
gauss_outs = dict()
|
| 140 |
+
for i in range(self.cfg.model.gaussians_per_pixel):
|
| 141 |
+
outs = self.models["gauss_decoder_"+str(i)](gauss_feats)
|
| 142 |
+
for key, v in outs.items():
|
| 143 |
+
gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1)
|
| 144 |
+
for key, v in gauss_outs.items():
|
| 145 |
+
gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...')
|
| 146 |
+
outputs |= gauss_outs
|
| 147 |
+
outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}
|
| 148 |
+
else:
|
| 149 |
+
for scale in cfg.model.scales:
|
| 150 |
+
outputs[("disp", 0, scale)] = outputs[("disp", scale)]
|
| 151 |
+
|
| 152 |
+
# unflip all outputs
|
| 153 |
+
if do_flip:
|
| 154 |
+
for k, v in outputs.items():
|
| 155 |
+
outputs[k] = torch.flip(v, dims=(-1, ))
|
| 156 |
+
elif "unidepth" in cfg.model.name:
|
| 157 |
+
if cfg.model.name in ["unidepth",
|
| 158 |
+
"unidepth_unprojector_vit",
|
| 159 |
+
"unidepth_unprojector_cnvnxtl"]:
|
| 160 |
+
outputs = self.models["unidepth"](inputs)
|
| 161 |
+
elif cfg.model.name in ["unidepth_extension_vit",
|
| 162 |
+
"unidepth_extension_cnvnxtl"]:
|
| 163 |
+
outputs = self.models["unidepth_extended"](inputs)
|
| 164 |
+
|
| 165 |
+
input_f_id = 0
|
| 166 |
+
outputs = {(key[0], input_f_id, key[1]): v for key, v in outputs.items()}
|
| 167 |
+
|
| 168 |
+
input_f_id = 0
|
| 169 |
+
scale = 0
|
| 170 |
+
if not ("depth", input_f_id, scale) in outputs:
|
| 171 |
+
disp = outputs[("disp", input_f_id, scale)]
|
| 172 |
+
_, depth = disp_to_depth(disp, cfg.model.min_depth, cfg.model.max_depth)
|
| 173 |
+
outputs[("depth", input_f_id, scale)] = depth
|
| 174 |
+
|
| 175 |
+
self.compute_gauss_means(inputs, outputs)
|
| 176 |
+
|
| 177 |
+
return outputs
|
| 178 |
+
|
| 179 |
+
def target_tensor_image_dims(self, inputs):
|
| 180 |
+
B, _, H, W = inputs["color", 0, 0].shape
|
| 181 |
+
return B, H, W
|
| 182 |
+
|
| 183 |
+
def compute_gauss_means(self, inputs, outputs):
|
| 184 |
+
cfg = self.cfg
|
| 185 |
+
input_f_id = 0
|
| 186 |
+
scale = 0
|
| 187 |
+
depth = outputs[("depth", input_f_id, scale)]
|
| 188 |
+
B, _, H, W = depth.shape
|
| 189 |
+
if ("inv_K_src", scale) in inputs:
|
| 190 |
+
inv_K = inputs[("inv_K_src", scale)]
|
| 191 |
+
else:
|
| 192 |
+
inv_K = outputs[("inv_K_src", input_f_id, scale)]
|
| 193 |
+
if self.cfg.model.gaussians_per_pixel > 1:
|
| 194 |
+
inv_K = rearrange(inv_K[:,None,...].
|
| 195 |
+
repeat(1, self.cfg.model.gaussians_per_pixel, 1, 1),
|
| 196 |
+
'b n ... -> (b n) ...')
|
| 197 |
+
xyz = self.backproject_depth[str(scale)](
|
| 198 |
+
depth, inv_K
|
| 199 |
+
)
|
| 200 |
+
inputs[("inv_K_src", scale)] = inv_K
|
| 201 |
+
if cfg.model.predict_offset:
|
| 202 |
+
offset = outputs[("gauss_offset", input_f_id, scale)]
|
| 203 |
+
if cfg.model.scaled_offset:
|
| 204 |
+
offset = offset * depth.detach()
|
| 205 |
+
offset = offset.view(B, 3, -1)
|
| 206 |
+
zeros = torch.zeros(B, 1, H * W, device=depth.device)
|
| 207 |
+
offset = torch.cat([offset, zeros], 1)
|
| 208 |
+
xyz = xyz + offset # [B, 4, W*H]
|
| 209 |
+
outputs[("gauss_means", input_f_id, scale)] = xyz
|
| 210 |
+
|
| 211 |
+
def checkpoint_dir(self):
|
| 212 |
+
return Path("checkpoints")
|
| 213 |
+
|
| 214 |
+
def save_model(self, optimizer, step, ema=None):
|
| 215 |
+
"""Save model weights to disk
|
| 216 |
+
"""
|
| 217 |
+
save_folder = self.checkpoint_dir()
|
| 218 |
+
save_folder.mkdir(exist_ok=True, parents=True)
|
| 219 |
+
|
| 220 |
+
save_path = save_folder / f"model_{step:07}.pth"
|
| 221 |
+
logging.info(f"saving checkpoint to {str(save_path)}")
|
| 222 |
+
|
| 223 |
+
model = ema.ema_model if ema is not None else self
|
| 224 |
+
save_dict = {
|
| 225 |
+
"model": model.state_dict(),
|
| 226 |
+
"version": "1.0",
|
| 227 |
+
"optimiser": optimizer.state_dict(),
|
| 228 |
+
"step": step
|
| 229 |
+
}
|
| 230 |
+
torch.save(save_dict, save_path)
|
| 231 |
+
|
| 232 |
+
num_ckpts = self.cfg.optimiser.num_keep_ckpts
|
| 233 |
+
ckpts = sorted(list(save_folder.glob("model_*.pth")), reverse=True)
|
| 234 |
+
if len(ckpts) > num_ckpts:
|
| 235 |
+
for ckpt in ckpts[num_ckpts:]:
|
| 236 |
+
ckpt.unlink()
|
| 237 |
+
|
| 238 |
+
def load_model(self, weights_path, optimizer=None):
|
| 239 |
+
"""Load model(s) from disk
|
| 240 |
+
"""
|
| 241 |
+
weights_path = Path(weights_path)
|
| 242 |
+
|
| 243 |
+
# determine if it is an old or new saving format
|
| 244 |
+
if weights_path.is_dir() and weights_path.joinpath("encoder.pth").exists():
|
| 245 |
+
self.load_model_old(weights_path, optimizer)
|
| 246 |
+
return
|
| 247 |
+
|
| 248 |
+
logging.info(f"Loading weights from {weights_path}...")
|
| 249 |
+
state_dict = torch.load(weights_path)
|
| 250 |
+
if "version" in state_dict and state_dict["version"] == "1.0":
|
| 251 |
+
new_dict = {}
|
| 252 |
+
for k, v in state_dict["model"].items():
|
| 253 |
+
if "backproject_depth" in k:
|
| 254 |
+
new_dict[k] = self.state_dict()[k].clone()
|
| 255 |
+
else:
|
| 256 |
+
new_dict[k] = v.clone()
|
| 257 |
+
# for k, v in state_dict["model"].items():
|
| 258 |
+
# if "backproject_depth" in k and ("pix_coords" in k or "ones" in k):
|
| 259 |
+
# # model has these parameters set as a function of batch size
|
| 260 |
+
# # when batch size changes in eval this results in a loading error
|
| 261 |
+
# state_dict["model"][k] = v[:1, ...]
|
| 262 |
+
self.load_state_dict(new_dict, strict=False)
|
| 263 |
+
else:
|
| 264 |
+
# TODO remove loading according to the old format
|
| 265 |
+
for name in self.cfg.train.models_to_load:
|
| 266 |
+
if name not in self.models:
|
| 267 |
+
continue
|
| 268 |
+
self.models[name].load_state_dict(state_dict[name])
|
| 269 |
+
|
| 270 |
+
# loading adam state
|
| 271 |
+
if optimizer is not None:
|
| 272 |
+
optimizer.load_state_dict(state_dict["optimiser"])
|
| 273 |
+
self.step = state_dict["step"]
|
| 274 |
+
|
| 275 |
+
def load_model_old(self, weights_folder, optimizer=None):
|
| 276 |
+
for n in self.cfg.train.models_to_load:
|
| 277 |
+
print(f"Loading {n} weights...")
|
| 278 |
+
path = weights_folder / f"{n}.pth"
|
| 279 |
+
if n not in self.models:
|
| 280 |
+
continue
|
| 281 |
+
model_dict = self.models[n].state_dict()
|
| 282 |
+
pretrained_dict = torch.load(path)
|
| 283 |
+
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
| 284 |
+
model_dict.update(pretrained_dict)
|
| 285 |
+
self.models[n].load_state_dict(model_dict)
|
| 286 |
+
|
| 287 |
+
# loading adam state
|
| 288 |
+
optimizer_load_path = weights_folder / "adam.pth"
|
| 289 |
+
if optimizer is not None and optimizer_load_path.is_file():
|
| 290 |
+
print("Loading Adam weights")
|
| 291 |
+
optimizer_state = torch.load(optimizer_load_path)
|
| 292 |
+
optimizer.load_state_dict(optimizer_state["adam"])
|
| 293 |
+
self.step = optimizer_state["step"]
|
flash3d/networks/layers.py
ADDED
|
@@ -0,0 +1,295 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
| 4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
| 5 |
+
# available in the LICENSE file.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def disp_to_depth(disp, min_depth, max_depth):
|
| 15 |
+
"""Convert network's sigmoid output into depth prediction
|
| 16 |
+
The formula for this conversion is given in the 'additional considerations'
|
| 17 |
+
section of the paper.
|
| 18 |
+
"""
|
| 19 |
+
min_disp = 1 / max_depth
|
| 20 |
+
max_disp = 1 / min_depth
|
| 21 |
+
scaled_disp = min_disp + (max_disp - min_disp) * disp
|
| 22 |
+
depth = 1 / scaled_disp
|
| 23 |
+
return scaled_disp, depth
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def transformation_from_parameters(axisangle, translation, invert=False):
|
| 27 |
+
"""Convert the network's (axisangle, translation) output into a 4x4 matrix
|
| 28 |
+
"""
|
| 29 |
+
R = rot_from_axisangle(axisangle)
|
| 30 |
+
t = translation.clone()
|
| 31 |
+
|
| 32 |
+
if invert:
|
| 33 |
+
R = R.transpose(1, 2)
|
| 34 |
+
t *= -1
|
| 35 |
+
|
| 36 |
+
T = get_translation_matrix(t)
|
| 37 |
+
|
| 38 |
+
if invert:
|
| 39 |
+
M = torch.matmul(R, T)
|
| 40 |
+
else:
|
| 41 |
+
M = torch.matmul(T, R)
|
| 42 |
+
|
| 43 |
+
return M
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def get_translation_matrix(translation_vector):
|
| 47 |
+
"""Convert a translation vector into a 4x4 transformation matrix
|
| 48 |
+
"""
|
| 49 |
+
T = torch.zeros(translation_vector.shape[0], 4, 4).to(device=translation_vector.device)
|
| 50 |
+
|
| 51 |
+
t = translation_vector.contiguous().view(-1, 3, 1)
|
| 52 |
+
|
| 53 |
+
T[:, 0, 0] = 1
|
| 54 |
+
T[:, 1, 1] = 1
|
| 55 |
+
T[:, 2, 2] = 1
|
| 56 |
+
T[:, 3, 3] = 1
|
| 57 |
+
T[:, :3, 3, None] = t
|
| 58 |
+
|
| 59 |
+
return T
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def rot_from_axisangle(vec):
|
| 63 |
+
"""Convert an axisangle rotation into a 4x4 transformation matrix
|
| 64 |
+
(adapted from https://github.com/Wallacoloo/printipi)
|
| 65 |
+
Input 'vec' has to be Bx1x3
|
| 66 |
+
"""
|
| 67 |
+
angle = torch.norm(vec, 2, 2, True)
|
| 68 |
+
axis = vec / (angle + 1e-7)
|
| 69 |
+
|
| 70 |
+
ca = torch.cos(angle)
|
| 71 |
+
sa = torch.sin(angle)
|
| 72 |
+
C = 1 - ca
|
| 73 |
+
|
| 74 |
+
x = axis[..., 0].unsqueeze(1)
|
| 75 |
+
y = axis[..., 1].unsqueeze(1)
|
| 76 |
+
z = axis[..., 2].unsqueeze(1)
|
| 77 |
+
|
| 78 |
+
xs = x * sa
|
| 79 |
+
ys = y * sa
|
| 80 |
+
zs = z * sa
|
| 81 |
+
xC = x * C
|
| 82 |
+
yC = y * C
|
| 83 |
+
zC = z * C
|
| 84 |
+
xyC = x * yC
|
| 85 |
+
yzC = y * zC
|
| 86 |
+
zxC = z * xC
|
| 87 |
+
|
| 88 |
+
rot = torch.zeros((vec.shape[0], 4, 4)).to(device=vec.device)
|
| 89 |
+
|
| 90 |
+
rot[:, 0, 0] = torch.squeeze(x * xC + ca)
|
| 91 |
+
rot[:, 0, 1] = torch.squeeze(xyC - zs)
|
| 92 |
+
rot[:, 0, 2] = torch.squeeze(zxC + ys)
|
| 93 |
+
rot[:, 1, 0] = torch.squeeze(xyC + zs)
|
| 94 |
+
rot[:, 1, 1] = torch.squeeze(y * yC + ca)
|
| 95 |
+
rot[:, 1, 2] = torch.squeeze(yzC - xs)
|
| 96 |
+
rot[:, 2, 0] = torch.squeeze(zxC - ys)
|
| 97 |
+
rot[:, 2, 1] = torch.squeeze(yzC + xs)
|
| 98 |
+
rot[:, 2, 2] = torch.squeeze(z * zC + ca)
|
| 99 |
+
rot[:, 3, 3] = 1
|
| 100 |
+
|
| 101 |
+
return rot
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class ConvBlock(nn.Module):
|
| 105 |
+
"""Layer to perform a convolution followed by ELU
|
| 106 |
+
"""
|
| 107 |
+
def __init__(self, in_channels, out_channels):
|
| 108 |
+
super(ConvBlock, self).__init__()
|
| 109 |
+
|
| 110 |
+
self.conv = Conv3x3(in_channels, out_channels)
|
| 111 |
+
self.nonlin = nn.ELU(inplace=True)
|
| 112 |
+
|
| 113 |
+
def forward(self, x):
|
| 114 |
+
out = self.conv(x)
|
| 115 |
+
out = self.nonlin(out)
|
| 116 |
+
return out
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class Conv3x3(nn.Module):
|
| 120 |
+
"""Layer to pad and convolve input
|
| 121 |
+
"""
|
| 122 |
+
def __init__(self, in_channels, out_channels, use_refl=True):
|
| 123 |
+
super(Conv3x3, self).__init__()
|
| 124 |
+
|
| 125 |
+
if use_refl:
|
| 126 |
+
self.pad = nn.ReflectionPad2d(1)
|
| 127 |
+
else:
|
| 128 |
+
self.pad = nn.ZeroPad2d(1)
|
| 129 |
+
self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)
|
| 130 |
+
|
| 131 |
+
def forward(self, x):
|
| 132 |
+
out = self.pad(x)
|
| 133 |
+
out = self.conv(out)
|
| 134 |
+
return out
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class BackprojectDepth(nn.Module):
|
| 138 |
+
"""Layer to transform a depth image into a point cloud
|
| 139 |
+
"""
|
| 140 |
+
def __init__(self, batch_size, height, width, shift_rays_half_pixel=0):
|
| 141 |
+
super(BackprojectDepth, self).__init__()
|
| 142 |
+
|
| 143 |
+
self.batch_size = batch_size
|
| 144 |
+
self.height = height
|
| 145 |
+
self.width = width
|
| 146 |
+
|
| 147 |
+
meshgrid = np.meshgrid(range(self.width), range(self.height), indexing='xy')
|
| 148 |
+
id_coords = np.stack(meshgrid, axis=0).astype(np.float32)
|
| 149 |
+
id_coords = torch.from_numpy(id_coords)
|
| 150 |
+
|
| 151 |
+
ones = torch.ones(self.batch_size, 1, self.height * self.width)
|
| 152 |
+
|
| 153 |
+
pix_coords = torch.unsqueeze(torch.stack(
|
| 154 |
+
[id_coords[0].view(-1), id_coords[1].view(-1)], 0), 0)
|
| 155 |
+
pix_coords = pix_coords.repeat(batch_size, 1, 1)
|
| 156 |
+
pix_coords = torch.cat([pix_coords + shift_rays_half_pixel,
|
| 157 |
+
ones], 1)
|
| 158 |
+
self.register_buffer("pix_coords", pix_coords)
|
| 159 |
+
self.register_buffer("id_coords", id_coords)
|
| 160 |
+
self.register_buffer("ones", ones)
|
| 161 |
+
# self.pix_coords = pix_coords
|
| 162 |
+
# self.ones = ones
|
| 163 |
+
|
| 164 |
+
def forward(self, depth, inv_K):
|
| 165 |
+
cam_points = torch.matmul(inv_K[:, :3, :3], self.pix_coords.to(depth.device))
|
| 166 |
+
cam_points = depth.view(self.batch_size, 1, -1) * cam_points
|
| 167 |
+
cam_points = torch.cat([cam_points, self.ones.to(depth.device)], 1)
|
| 168 |
+
|
| 169 |
+
return cam_points
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class Project3D(nn.Module):
|
| 173 |
+
"""Layer which projects 3D points into a camera with intrinsics K and at position T
|
| 174 |
+
"""
|
| 175 |
+
def __init__(self, batch_size, height, width, eps=1e-7):
|
| 176 |
+
super(Project3D, self).__init__()
|
| 177 |
+
|
| 178 |
+
self.batch_size = batch_size
|
| 179 |
+
self.height = height
|
| 180 |
+
self.width = width
|
| 181 |
+
self.eps = eps
|
| 182 |
+
|
| 183 |
+
def forward(self, points, K, T=None):
|
| 184 |
+
if T is None:
|
| 185 |
+
P = K
|
| 186 |
+
else:
|
| 187 |
+
P = torch.matmul(K, T)
|
| 188 |
+
P = P[:, :3, :]
|
| 189 |
+
|
| 190 |
+
cam_points = torch.matmul(P, points)
|
| 191 |
+
|
| 192 |
+
pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
|
| 193 |
+
pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
|
| 194 |
+
pix_coords = pix_coords.permute(0, 2, 3, 1)
|
| 195 |
+
pix_coords[..., 0] /= self.width - 1
|
| 196 |
+
pix_coords[..., 1] /= self.height - 1
|
| 197 |
+
pix_coords = (pix_coords - 0.5) * 2
|
| 198 |
+
return pix_coords
|
| 199 |
+
|
| 200 |
+
|
| 201 |
+
class Project3DSimple(nn.Module):
|
| 202 |
+
"""Layer which projects 3D points into a camera with intrinsics K and at position T
|
| 203 |
+
"""
|
| 204 |
+
def __init__(self, batch_size, height, width, eps=1e-7):
|
| 205 |
+
super(Project3DSimple, self).__init__()
|
| 206 |
+
|
| 207 |
+
self.batch_size = batch_size
|
| 208 |
+
self.height = height
|
| 209 |
+
self.width = width
|
| 210 |
+
self.eps = eps
|
| 211 |
+
|
| 212 |
+
def forward(self, points, K):
|
| 213 |
+
K = K[:, :3, :]
|
| 214 |
+
|
| 215 |
+
cam_points = torch.matmul(K, points)
|
| 216 |
+
|
| 217 |
+
pix_coords = cam_points[:, :2, :] / (cam_points[:, 2, :].unsqueeze(1) + self.eps)
|
| 218 |
+
pix_coords = pix_coords.view(self.batch_size, 2, self.height, self.width)
|
| 219 |
+
pix_coords = pix_coords.permute(0, 2, 3, 1)
|
| 220 |
+
return pix_coords
|
| 221 |
+
|
| 222 |
+
def upsample(x):
|
| 223 |
+
"""Upsample input tensor by a factor of 2
|
| 224 |
+
"""
|
| 225 |
+
return F.interpolate(x, scale_factor=2, mode="nearest")
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def get_smooth_loss(disp, img):
|
| 229 |
+
"""Computes the smoothness loss for a disparity image
|
| 230 |
+
The color image is used for edge-aware smoothness
|
| 231 |
+
"""
|
| 232 |
+
grad_disp_x = torch.abs(disp[:, :, :, :-1] - disp[:, :, :, 1:])
|
| 233 |
+
grad_disp_y = torch.abs(disp[:, :, :-1, :] - disp[:, :, 1:, :])
|
| 234 |
+
|
| 235 |
+
grad_img_x = torch.mean(torch.abs(img[:, :, :, :-1] - img[:, :, :, 1:]), 1, keepdim=True)
|
| 236 |
+
grad_img_y = torch.mean(torch.abs(img[:, :, :-1, :] - img[:, :, 1:, :]), 1, keepdim=True)
|
| 237 |
+
|
| 238 |
+
grad_disp_x *= torch.exp(-grad_img_x)
|
| 239 |
+
grad_disp_y *= torch.exp(-grad_img_y)
|
| 240 |
+
|
| 241 |
+
return grad_disp_x.mean() + grad_disp_y.mean()
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class SSIM(nn.Module):
|
| 245 |
+
"""Layer to compute the SSIM loss between a pair of images
|
| 246 |
+
"""
|
| 247 |
+
def __init__(self):
|
| 248 |
+
super(SSIM, self).__init__()
|
| 249 |
+
self.mu_x_pool = nn.AvgPool2d(3, 1)
|
| 250 |
+
self.mu_y_pool = nn.AvgPool2d(3, 1)
|
| 251 |
+
self.sig_x_pool = nn.AvgPool2d(3, 1)
|
| 252 |
+
self.sig_y_pool = nn.AvgPool2d(3, 1)
|
| 253 |
+
self.sig_xy_pool = nn.AvgPool2d(3, 1)
|
| 254 |
+
|
| 255 |
+
self.refl = nn.ReflectionPad2d(1)
|
| 256 |
+
|
| 257 |
+
self.C1 = 0.01 ** 2
|
| 258 |
+
self.C2 = 0.03 ** 2
|
| 259 |
+
|
| 260 |
+
def forward(self, x, y):
|
| 261 |
+
x = self.refl(x)
|
| 262 |
+
y = self.refl(y)
|
| 263 |
+
|
| 264 |
+
mu_x = self.mu_x_pool(x)
|
| 265 |
+
mu_y = self.mu_y_pool(y)
|
| 266 |
+
|
| 267 |
+
sigma_x = self.sig_x_pool(x ** 2) - mu_x ** 2
|
| 268 |
+
sigma_y = self.sig_y_pool(y ** 2) - mu_y ** 2
|
| 269 |
+
sigma_xy = self.sig_xy_pool(x * y) - mu_x * mu_y
|
| 270 |
+
|
| 271 |
+
SSIM_n = (2 * mu_x * mu_y + self.C1) * (2 * sigma_xy + self.C2)
|
| 272 |
+
SSIM_d = (mu_x ** 2 + mu_y ** 2 + self.C1) * (sigma_x + sigma_y + self.C2)
|
| 273 |
+
|
| 274 |
+
return torch.clamp((1 - SSIM_n / SSIM_d) / 2, 0, 1)
|
| 275 |
+
|
| 276 |
+
|
| 277 |
+
def compute_depth_errors(gt, pred):
|
| 278 |
+
"""Computation of error metrics between predicted and ground truth depths
|
| 279 |
+
"""
|
| 280 |
+
thresh = torch.max((gt / pred), (pred / gt))
|
| 281 |
+
a1 = (thresh < 1.25 ).float().mean()
|
| 282 |
+
a2 = (thresh < 1.25 ** 2).float().mean()
|
| 283 |
+
a3 = (thresh < 1.25 ** 3).float().mean()
|
| 284 |
+
|
| 285 |
+
rmse = (gt - pred) ** 2
|
| 286 |
+
rmse = torch.sqrt(rmse.mean())
|
| 287 |
+
|
| 288 |
+
rmse_log = (torch.log(gt) - torch.log(pred)) ** 2
|
| 289 |
+
rmse_log = torch.sqrt(rmse_log.mean())
|
| 290 |
+
|
| 291 |
+
abs_rel = torch.mean(torch.abs(gt - pred) / gt)
|
| 292 |
+
|
| 293 |
+
sq_rel = torch.mean((gt - pred) ** 2 / gt)
|
| 294 |
+
|
| 295 |
+
return abs_rel, sq_rel, rmse, rmse_log, a1, a2, a3
|
flash3d/networks/resnet_encoder.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright Niantic 2019. Patent Pending. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# This software is licensed under the terms of the Monodepth2 licence
|
| 4 |
+
# which allows for non-commercial use only, the full terms of which are made
|
| 5 |
+
# available in the LICENSE file.
|
| 6 |
+
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torchvision.models as models
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
RESNETS = {18: (models.resnet18, models.ResNet18_Weights.IMAGENET1K_V1),
|
| 15 |
+
50: (models.resnet50, models.ResNet50_Weights.IMAGENET1K_V2)}
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ResNetMultiImageInput(models.ResNet):
|
| 19 |
+
"""Constructs a resnet model with varying number of input images.
|
| 20 |
+
Adapted from https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
|
| 21 |
+
"""
|
| 22 |
+
def __init__(self, block, layers, num_classes=1000, num_input_images=1):
|
| 23 |
+
super(ResNetMultiImageInput, self).__init__(block, layers)
|
| 24 |
+
self.inplanes = 64
|
| 25 |
+
self.conv1 = nn.Conv2d(
|
| 26 |
+
num_input_images * 3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
| 27 |
+
self.bn1 = nn.BatchNorm2d(64)
|
| 28 |
+
self.relu = nn.ReLU(inplace=True)
|
| 29 |
+
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
| 30 |
+
self.layer1 = self._make_layer(block, 64, layers[0])
|
| 31 |
+
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
| 32 |
+
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
| 33 |
+
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
| 34 |
+
|
| 35 |
+
for m in self.modules():
|
| 36 |
+
if isinstance(m, nn.Conv2d):
|
| 37 |
+
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
|
| 38 |
+
elif isinstance(m, nn.BatchNorm2d):
|
| 39 |
+
nn.init.constant_(m.weight, 1)
|
| 40 |
+
nn.init.constant_(m.bias, 0)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def resnet_multiimage_input(num_layers, pretrained=False, num_input_images=1):
|
| 44 |
+
"""Constructs a ResNet model.
|
| 45 |
+
Args:
|
| 46 |
+
num_layers (int): Number of resnet layers. Must be 18 or 50
|
| 47 |
+
pretrained (bool): If True, returns a model pre-trained on ImageNet
|
| 48 |
+
num_input_images (int): Number of frames stacked as input
|
| 49 |
+
"""
|
| 50 |
+
assert num_layers in [18, 50], "Can only run with 18 or 50 layer resnet"
|
| 51 |
+
blocks = {18: [2, 2, 2, 2], 50: [3, 4, 6, 3]}[num_layers]
|
| 52 |
+
block_type = {18: models.resnet.BasicBlock, 50: models.resnet.Bottleneck}[num_layers]
|
| 53 |
+
model = ResNetMultiImageInput(block_type, blocks, num_input_images=num_input_images)
|
| 54 |
+
model, weigths = RESNETS[num_layers]
|
| 55 |
+
|
| 56 |
+
if pretrained:
|
| 57 |
+
loaded = torch.hub.load_state_dict_from_url(weigths.url)
|
| 58 |
+
loaded['conv1.weight'] = torch.cat(
|
| 59 |
+
[loaded['conv1.weight']] * num_input_images, 1) / num_input_images
|
| 60 |
+
model.load_state_dict(loaded)
|
| 61 |
+
return model
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ResnetEncoder(nn.Module):
|
| 65 |
+
"""Pytorch module for a resnet encoder
|
| 66 |
+
"""
|
| 67 |
+
def __init__(self, num_layers, pretrained, bn_order, num_input_images=1):
|
| 68 |
+
super(ResnetEncoder, self).__init__()
|
| 69 |
+
|
| 70 |
+
self.num_ch_enc = np.array([64, 64, 128, 256, 512])
|
| 71 |
+
self.bn_order = bn_order
|
| 72 |
+
|
| 73 |
+
if num_layers not in RESNETS:
|
| 74 |
+
raise ValueError("{} is not a valid number of resnet layers".format(num_layers))
|
| 75 |
+
|
| 76 |
+
if num_input_images > 1:
|
| 77 |
+
self.encoder = resnet_multiimage_input(num_layers, pretrained, num_input_images)
|
| 78 |
+
else:
|
| 79 |
+
model, weights = RESNETS[num_layers]
|
| 80 |
+
self.encoder = model(weights=weights)
|
| 81 |
+
|
| 82 |
+
if num_layers > 34:
|
| 83 |
+
self.num_ch_enc[1:] *= 4
|
| 84 |
+
|
| 85 |
+
def forward(self, input_image):
|
| 86 |
+
encoder = self.encoder
|
| 87 |
+
features = []
|
| 88 |
+
x = (input_image - 0.45) / 0.225
|
| 89 |
+
x = encoder.conv1(x)
|
| 90 |
+
|
| 91 |
+
if self.bn_order == "pre_bn":
|
| 92 |
+
# Concatenating pre-norm features allows us to
|
| 93 |
+
# keep the scale and shift of RGB colours
|
| 94 |
+
# and recover them at output
|
| 95 |
+
features.append(x)
|
| 96 |
+
x = encoder.bn1(x)
|
| 97 |
+
x = encoder.relu(x)
|
| 98 |
+
features.append(encoder.layer1(encoder.maxpool(x)))
|
| 99 |
+
elif self.bn_order == "monodepth":
|
| 100 |
+
# Batchnorm gets rid of constants due to colour shift
|
| 101 |
+
# will make the network not able to recover absolute colour shift
|
| 102 |
+
# of the input image
|
| 103 |
+
# used in old models
|
| 104 |
+
x = encoder.bn1(x)
|
| 105 |
+
x = encoder.relu(x)
|
| 106 |
+
features.append(x)
|
| 107 |
+
features.append(encoder.layer1(encoder.maxpool(x)))
|
| 108 |
+
else:
|
| 109 |
+
assert False
|
| 110 |
+
|
| 111 |
+
features.append(encoder.layer2(features[-1]))
|
| 112 |
+
features.append(encoder.layer3(features[-1]))
|
| 113 |
+
features.append(encoder.layer4(features[-1]))
|
| 114 |
+
|
| 115 |
+
return features
|
flash3d/networks/unidepth.py
ADDED
|
@@ -0,0 +1,577 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import List, Tuple
|
| 4 |
+
from math import ceil
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
import torchvision.transforms.functional as TF
|
| 9 |
+
from einops import rearrange
|
| 10 |
+
|
| 11 |
+
from unidepth.models.unidepthv1 import UniDepthV1
|
| 12 |
+
from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
|
| 13 |
+
from unidepth.utils.geometric import (
|
| 14 |
+
generate_rays,
|
| 15 |
+
spherical_zbuffer_to_euclidean,
|
| 16 |
+
flat_interpolate,
|
| 17 |
+
)
|
| 18 |
+
from unidepth.layers import (
|
| 19 |
+
MLP,
|
| 20 |
+
AttentionBlock,
|
| 21 |
+
NystromBlock,
|
| 22 |
+
PositionEmbeddingSine,
|
| 23 |
+
ConvUpsample,
|
| 24 |
+
)
|
| 25 |
+
from unidepth.utils.sht import rsh_cart_8
|
| 26 |
+
|
| 27 |
+
from networks.gaussian_decoder import get_splits_and_inits
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# inference helpers
|
| 31 |
+
def _paddings(image_shape, network_shape):
|
| 32 |
+
cur_h, cur_w = image_shape
|
| 33 |
+
h, w = network_shape
|
| 34 |
+
pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
|
| 35 |
+
pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
|
| 36 |
+
return pad_left, pad_right, pad_top, pad_bottom
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _shapes(image_shape, network_shape):
|
| 40 |
+
h, w = image_shape
|
| 41 |
+
input_ratio = w / h
|
| 42 |
+
output_ratio = network_shape[1] / network_shape[0]
|
| 43 |
+
if output_ratio > input_ratio:
|
| 44 |
+
ratio = network_shape[0] / h
|
| 45 |
+
elif output_ratio <= input_ratio:
|
| 46 |
+
ratio = network_shape[1] / w
|
| 47 |
+
return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
|
| 51 |
+
(pad_left, pad_right, pad_top, pad_bottom) = pads
|
| 52 |
+
rgbs = F.interpolate(
|
| 53 |
+
rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
|
| 54 |
+
)
|
| 55 |
+
rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
|
| 56 |
+
if intrinsics is not None:
|
| 57 |
+
intrinsics = intrinsics.clone()
|
| 58 |
+
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
|
| 59 |
+
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
|
| 60 |
+
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
|
| 61 |
+
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
|
| 62 |
+
return rgbs, intrinsics
|
| 63 |
+
return rgbs, None
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
|
| 67 |
+
|
| 68 |
+
(pad_left, pad_right, pad_top, pad_bottom) = pads
|
| 69 |
+
# pred mean, trim paddings, and upsample to input dim
|
| 70 |
+
predictions = sum(
|
| 71 |
+
[
|
| 72 |
+
F.interpolate(
|
| 73 |
+
x,
|
| 74 |
+
size=shapes,
|
| 75 |
+
mode="bilinear",
|
| 76 |
+
align_corners=False,
|
| 77 |
+
antialias=True,
|
| 78 |
+
)
|
| 79 |
+
for x in predictions
|
| 80 |
+
]
|
| 81 |
+
) / len(predictions)
|
| 82 |
+
|
| 83 |
+
shapes = predictions.shape[2:]
|
| 84 |
+
predictions = predictions[
|
| 85 |
+
..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
|
| 86 |
+
]
|
| 87 |
+
|
| 88 |
+
predictions = F.interpolate(
|
| 89 |
+
predictions,
|
| 90 |
+
size=original_shapes,
|
| 91 |
+
mode="bilinear",
|
| 92 |
+
align_corners=False,
|
| 93 |
+
antialias=True,
|
| 94 |
+
)
|
| 95 |
+
|
| 96 |
+
if intrinsics is not None:
|
| 97 |
+
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
|
| 98 |
+
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
|
| 99 |
+
intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
|
| 100 |
+
intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
|
| 101 |
+
|
| 102 |
+
return predictions, intrinsics
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def scale_intrinsics_xy(intrinsics, x_ratio, y_ratio):
|
| 106 |
+
intrinsics = intrinsics.clone()
|
| 107 |
+
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * x_ratio
|
| 108 |
+
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * y_ratio
|
| 109 |
+
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * x_ratio
|
| 110 |
+
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * y_ratio
|
| 111 |
+
return intrinsics
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def scale_intrinsics(intrinsics, ratio):
|
| 115 |
+
intrinsics = intrinsics.clone()
|
| 116 |
+
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
|
| 117 |
+
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
|
| 118 |
+
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio
|
| 119 |
+
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio
|
| 120 |
+
return intrinsics
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def unidepthv1_forward(model, rgbs, intrinsics, skip_camera,
|
| 124 |
+
return_raw_preds=False):
|
| 125 |
+
B, _, H, W = rgbs.shape
|
| 126 |
+
|
| 127 |
+
rgbs = TF.normalize(
|
| 128 |
+
rgbs,
|
| 129 |
+
mean=IMAGENET_DATASET_MEAN,
|
| 130 |
+
std=IMAGENET_DATASET_STD,
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
(h, w), ratio = _shapes((H, W), model.image_shape)
|
| 134 |
+
pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), model.image_shape)
|
| 135 |
+
rgbs, gt_intrinsics = _preprocess(
|
| 136 |
+
rgbs,
|
| 137 |
+
intrinsics,
|
| 138 |
+
(h, w),
|
| 139 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 140 |
+
ratio,
|
| 141 |
+
model.image_shape,
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
encoder_outputs, cls_tokens = model.pixel_encoder(rgbs)
|
| 145 |
+
if "dino" in model.pixel_encoder.__class__.__name__.lower():
|
| 146 |
+
encoder_outputs = [
|
| 147 |
+
(x + y.unsqueeze(1)).contiguous()
|
| 148 |
+
for x, y in zip(encoder_outputs, cls_tokens)
|
| 149 |
+
]
|
| 150 |
+
|
| 151 |
+
# get data for decoder and adapt to given camera
|
| 152 |
+
inputs = {}
|
| 153 |
+
inputs["encoder_outputs"] = encoder_outputs
|
| 154 |
+
inputs["cls_tokens"] = cls_tokens
|
| 155 |
+
inputs["image"] = rgbs
|
| 156 |
+
if gt_intrinsics is not None:
|
| 157 |
+
rays, angles = generate_rays(
|
| 158 |
+
gt_intrinsics, model.image_shape, noisy=False
|
| 159 |
+
)
|
| 160 |
+
inputs["rays"] = rays
|
| 161 |
+
inputs["angles"] = angles
|
| 162 |
+
inputs["K"] = gt_intrinsics
|
| 163 |
+
model.pixel_decoder.test_fixed_camera = True
|
| 164 |
+
model.pixel_decoder.skip_camera = skip_camera
|
| 165 |
+
|
| 166 |
+
# decode all
|
| 167 |
+
pred_intrinsics, predictions, features, rays = model.pixel_decoder(inputs, {})
|
| 168 |
+
|
| 169 |
+
pads = (pad_left, pad_right, pad_top, pad_bottom)
|
| 170 |
+
|
| 171 |
+
# undo the reshaping and get original image size (slow)
|
| 172 |
+
predictions, pred_intrinsics = _postprocess(
|
| 173 |
+
predictions,
|
| 174 |
+
pred_intrinsics,
|
| 175 |
+
model.image_shape,
|
| 176 |
+
pads,
|
| 177 |
+
ratio,
|
| 178 |
+
(H, W),
|
| 179 |
+
)
|
| 180 |
+
|
| 181 |
+
if return_raw_preds:
|
| 182 |
+
return inputs, predictions
|
| 183 |
+
|
| 184 |
+
# final 3D points backprojection
|
| 185 |
+
intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
|
| 186 |
+
angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
|
| 187 |
+
angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
|
| 188 |
+
points_3d = torch.cat((angles, predictions), dim=1)
|
| 189 |
+
points_3d = spherical_zbuffer_to_euclidean(
|
| 190 |
+
points_3d.permute(0, 2, 3, 1)
|
| 191 |
+
).permute(0, 3, 1, 2)
|
| 192 |
+
|
| 193 |
+
# output data
|
| 194 |
+
outputs = {
|
| 195 |
+
"intrinsics": intrinsics,
|
| 196 |
+
"points": points_3d,
|
| 197 |
+
"depth": predictions[:, -1:],
|
| 198 |
+
"depth_feats": features,
|
| 199 |
+
"rays": rays,
|
| 200 |
+
"padding": pads
|
| 201 |
+
}
|
| 202 |
+
model.pixel_decoder.test_fixed_camera = False
|
| 203 |
+
model.pixel_decoder.skip_camera = False
|
| 204 |
+
return inputs, outputs
|
| 205 |
+
|
| 206 |
+
class UniDepthDepth(nn.Module):
|
| 207 |
+
def __init__(
|
| 208 |
+
self,
|
| 209 |
+
cfg,
|
| 210 |
+
return_raw_preds=False
|
| 211 |
+
):
|
| 212 |
+
super().__init__()
|
| 213 |
+
|
| 214 |
+
self.cfg = cfg
|
| 215 |
+
self.return_raw_preds = return_raw_preds
|
| 216 |
+
|
| 217 |
+
if "cnvnxtl" in cfg.model.name:
|
| 218 |
+
self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl")
|
| 219 |
+
elif "vit" in cfg.model.name:
|
| 220 |
+
self.depth_prediction_model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
|
| 221 |
+
|
| 222 |
+
self.skip_camera = True
|
| 223 |
+
|
| 224 |
+
def get_depth(self, img, intrinsics):
|
| 225 |
+
depth_inputs, outputs = unidepthv1_forward(
|
| 226 |
+
self.depth_prediction_model,
|
| 227 |
+
img,
|
| 228 |
+
intrinsics,
|
| 229 |
+
self.skip_camera,
|
| 230 |
+
return_raw_preds=self.return_raw_preds)
|
| 231 |
+
return outputs
|
| 232 |
+
|
| 233 |
+
def forward(self, inputs):
|
| 234 |
+
input_img = inputs["color_aug", 0, 0]
|
| 235 |
+
# here we need the intrinsics of the source image to condition on
|
| 236 |
+
# the depth prediction. needs to account for padding
|
| 237 |
+
if ("K_src", 0) in inputs:
|
| 238 |
+
intrinsics = inputs[("K_src", 0)]
|
| 239 |
+
else:
|
| 240 |
+
intrinsics = None
|
| 241 |
+
|
| 242 |
+
depth_inputs, outputs = unidepthv1_forward(
|
| 243 |
+
self.depth_prediction_model,
|
| 244 |
+
input_img,
|
| 245 |
+
intrinsics,
|
| 246 |
+
self.skip_camera,
|
| 247 |
+
return_raw_preds=self.return_raw_preds)
|
| 248 |
+
|
| 249 |
+
return depth_inputs, outputs
|
| 250 |
+
|
| 251 |
+
class UniDepthUnprojector(nn.Module):
|
| 252 |
+
def __init__(
|
| 253 |
+
self,
|
| 254 |
+
cfg
|
| 255 |
+
):
|
| 256 |
+
super().__init__()
|
| 257 |
+
|
| 258 |
+
self.cfg = cfg
|
| 259 |
+
|
| 260 |
+
if cfg.model.name == "unidepth_unprojector_cnvnxtl":
|
| 261 |
+
model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-cnvnxtl")
|
| 262 |
+
elif cfg.model.name == "unidepth_unprojector_vit":
|
| 263 |
+
model = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
|
| 264 |
+
self.unidepth = model
|
| 265 |
+
|
| 266 |
+
self.skip_camera = True
|
| 267 |
+
|
| 268 |
+
self.register_buffer("gauss_opacity", torch.ones(1, 1, 1).float())
|
| 269 |
+
self.register_buffer("gauss_scaling", torch.ones(3, 1, 1).float())
|
| 270 |
+
self.register_buffer("gauss_rotation", torch.ones(4, 1, 1).float() * 0.5)
|
| 271 |
+
self.register_buffer("gauss_features_rest", torch.zeros(9, 1, 1).float())
|
| 272 |
+
self.register_buffer("gauss_offset", torch.zeros(3, 1, 1).float())
|
| 273 |
+
|
| 274 |
+
self.all_params = nn.ParameterDict({
|
| 275 |
+
"opacity_scaling": nn.Parameter(torch.tensor(cfg.model.opacity_bias).float()),
|
| 276 |
+
"scale_scaling": nn.Parameter(torch.tensor(cfg.model.scale_bias).float()),
|
| 277 |
+
"colour_scaling": nn.Parameter(torch.tensor(self.cfg.model.colour_scale).float())})
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
self.scaling_activation = torch.exp
|
| 281 |
+
self.opacity_activation = torch.sigmoid
|
| 282 |
+
self.relu = nn.ReLU()
|
| 283 |
+
|
| 284 |
+
def get_parameter_groups(self):
|
| 285 |
+
# tune scalars for size, opacity and colour modulation
|
| 286 |
+
return [{'params': self.all_params.parameters()}]
|
| 287 |
+
|
| 288 |
+
def forward(self, inputs):
|
| 289 |
+
model = self.unidepth
|
| 290 |
+
input_img = inputs["color_aug", 0, 0]
|
| 291 |
+
# here we need the intrinsics of the source image to condition on
|
| 292 |
+
# the depth prediction. needs to account for padding
|
| 293 |
+
intrinsics = inputs[("K_src", 0)]
|
| 294 |
+
b, c, h, w = inputs["color_aug", 0, 0].shape
|
| 295 |
+
|
| 296 |
+
with torch.no_grad():
|
| 297 |
+
_, depth_outs = unidepthv1_forward(model, input_img, intrinsics, self.skip_camera)
|
| 298 |
+
|
| 299 |
+
outs = {}
|
| 300 |
+
|
| 301 |
+
outs[("gauss_opacity", 0)] = self.gauss_opacity.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
|
| 302 |
+
* self.opacity_activation(self.all_params["opacity_scaling"])
|
| 303 |
+
if not self.cfg.model.scale_with_depth:
|
| 304 |
+
outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
|
| 305 |
+
* self.scaling_activation(self.all_params["scale_scaling"])
|
| 306 |
+
else:
|
| 307 |
+
outs[("gauss_scaling", 0)] = self.gauss_scaling.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w) \
|
| 308 |
+
* self.scaling_activation(self.all_params["scale_scaling"]) * depth_outs["depth"] / 10.0
|
| 309 |
+
outs[("gauss_rotation", 0)] = self.gauss_rotation.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
|
| 310 |
+
outs[("gauss_offset", 0)] = self.gauss_offset.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
|
| 311 |
+
outs[("gauss_features_rest", 0)] = self.gauss_features_rest.unsqueeze(0).expand(depth_outs["depth"].shape[0], -1, h, w)
|
| 312 |
+
# rendering adds 0.5 to go from rendered colours to output
|
| 313 |
+
outs[("gauss_features_dc", 0)] = (input_img - 0.5)* self.relu(self.all_params["colour_scaling"])
|
| 314 |
+
|
| 315 |
+
outs[("depth", 0)] = depth_outs["depth"]
|
| 316 |
+
|
| 317 |
+
return outs
|
| 318 |
+
|
| 319 |
+
class UniDepthSplatter(nn.Module):
|
| 320 |
+
def __init__(
|
| 321 |
+
self,
|
| 322 |
+
cfg
|
| 323 |
+
):
|
| 324 |
+
super().__init__()
|
| 325 |
+
|
| 326 |
+
self.cfg = cfg
|
| 327 |
+
|
| 328 |
+
config_path = Path("/work/eldar/src/UniDepth")
|
| 329 |
+
with open(config_path / "configs/config_v1_cnvnxtl.json") as f:
|
| 330 |
+
config = json.load(f)
|
| 331 |
+
self.unidepth = UniDepthDepth(self.cfg)
|
| 332 |
+
|
| 333 |
+
hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
|
| 334 |
+
expansion = config["model"]["expansion"]
|
| 335 |
+
depth = config["model"]["pixel_decoder"]["depths"]
|
| 336 |
+
num_heads = config["model"]["num_heads"]
|
| 337 |
+
dropout = config["model"]["pixel_decoder"]["dropout"]
|
| 338 |
+
layer_scale = 1.0
|
| 339 |
+
self.splat_decoder = GaussSplatHead(
|
| 340 |
+
cfg,
|
| 341 |
+
hidden_dim=hidden_dim,
|
| 342 |
+
num_heads=num_heads,
|
| 343 |
+
expansion=expansion,
|
| 344 |
+
depths=depth,
|
| 345 |
+
camera_dim=81,
|
| 346 |
+
dropout=dropout,
|
| 347 |
+
layer_scale=layer_scale,
|
| 348 |
+
)
|
| 349 |
+
|
| 350 |
+
self.skip_camera = True
|
| 351 |
+
|
| 352 |
+
def get_parameter_groups(self):
|
| 353 |
+
base_lr = self.cfg.optimiser.learning_rate
|
| 354 |
+
return [
|
| 355 |
+
{'params': self.unidepth.parameters(), "lr": base_lr * 0.05},
|
| 356 |
+
{'params': self.splat_decoder.parameters()}
|
| 357 |
+
]
|
| 358 |
+
|
| 359 |
+
def forward(self, inputs):
|
| 360 |
+
gauss_head = self.splat_decoder
|
| 361 |
+
|
| 362 |
+
depth_inputs, depth_outs = self.unidepth(inputs)
|
| 363 |
+
depth_feats = depth_outs["depth_feats"]
|
| 364 |
+
rays = depth_outs["rays"]
|
| 365 |
+
padding = depth_outs["padding"]
|
| 366 |
+
|
| 367 |
+
B, _, H, W = depth_inputs["image"].shape
|
| 368 |
+
|
| 369 |
+
# TODO remove hardcoded shapes
|
| 370 |
+
common_shape = (28, 38)
|
| 371 |
+
gauss_head.set_shapes(common_shape)
|
| 372 |
+
gauss_head.set_original_shapes((H, W))
|
| 373 |
+
|
| 374 |
+
depth_feats = rearrange(depth_feats, "b c h w -> b (h w) c")
|
| 375 |
+
outs = gauss_head(
|
| 376 |
+
latents_16=depth_feats,
|
| 377 |
+
rays_hr=rays,
|
| 378 |
+
)
|
| 379 |
+
for k, v in outs.items():
|
| 380 |
+
pred, _ = _postprocess([v], None, self.unidepth.depth_prediction_model.image_shape,
|
| 381 |
+
padding, None, inputs["color_aug", 0, 0].shape[2:4])
|
| 382 |
+
outs[k] = pred
|
| 383 |
+
outs[("depth", 0)] = depth_outs["depth"]
|
| 384 |
+
|
| 385 |
+
return outs
|
| 386 |
+
|
| 387 |
+
|
| 388 |
+
class GaussSplatHead(nn.Module):
|
| 389 |
+
def __init__(
|
| 390 |
+
self,
|
| 391 |
+
cfg,
|
| 392 |
+
hidden_dim: int,
|
| 393 |
+
num_heads: int = 8,
|
| 394 |
+
expansion: int = 4,
|
| 395 |
+
depths: int | list[int] = 4,
|
| 396 |
+
camera_dim: int = 256,
|
| 397 |
+
dropout: float = 0.0,
|
| 398 |
+
layer_scale: float = 1.0,
|
| 399 |
+
) -> None:
|
| 400 |
+
super().__init__()
|
| 401 |
+
|
| 402 |
+
self.cfg = cfg
|
| 403 |
+
|
| 404 |
+
if isinstance(depths, int):
|
| 405 |
+
depths = [depths] * 3
|
| 406 |
+
assert len(depths) == 3
|
| 407 |
+
|
| 408 |
+
self.project_rays16 = MLP(
|
| 409 |
+
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim
|
| 410 |
+
)
|
| 411 |
+
self.project_rays8 = MLP(
|
| 412 |
+
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2
|
| 413 |
+
)
|
| 414 |
+
self.project_rays4 = MLP(
|
| 415 |
+
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4
|
| 416 |
+
)
|
| 417 |
+
|
| 418 |
+
self.layers_8 = nn.ModuleList([])
|
| 419 |
+
self.layers_4 = nn.ModuleList([])
|
| 420 |
+
layers_16 = nn.ModuleList([])
|
| 421 |
+
|
| 422 |
+
self.up8 = ConvUpsample(
|
| 423 |
+
hidden_dim, expansion=expansion, layer_scale=layer_scale
|
| 424 |
+
)
|
| 425 |
+
self.up4 = ConvUpsample(
|
| 426 |
+
hidden_dim // 2, expansion=expansion, layer_scale=layer_scale
|
| 427 |
+
)
|
| 428 |
+
self.up2 = ConvUpsample(
|
| 429 |
+
hidden_dim // 4, expansion=expansion, layer_scale=layer_scale
|
| 430 |
+
)
|
| 431 |
+
|
| 432 |
+
split_dimensions, scale, bias = get_splits_and_inits(cfg)
|
| 433 |
+
start = 1
|
| 434 |
+
self.split_dimensions = split_dimensions[start:]
|
| 435 |
+
scale = scale[start:]
|
| 436 |
+
bias = bias[start:]
|
| 437 |
+
|
| 438 |
+
self.num_output_channels = sum(self.split_dimensions)
|
| 439 |
+
|
| 440 |
+
self.out2 = nn.Conv2d(hidden_dim // 8, self.num_output_channels, 3, padding=1)
|
| 441 |
+
# self.out4 = nn.Conv2d(hidden_dim // 4, self.num_output_channels, 3, padding=1)
|
| 442 |
+
# self.out8 = nn.Conv2d(hidden_dim // 2, self.num_output_channels, 3, padding=1)
|
| 443 |
+
|
| 444 |
+
start_channels = 0
|
| 445 |
+
for out_channel, b, s in zip(self.split_dimensions, bias, scale):
|
| 446 |
+
nn.init.xavier_uniform_(
|
| 447 |
+
self.out2.weight[start_channels:start_channels+out_channel,
|
| 448 |
+
:, :, :], s)
|
| 449 |
+
nn.init.constant_(
|
| 450 |
+
self.out2.bias[start_channels:start_channels+out_channel], b)
|
| 451 |
+
start_channels += out_channel
|
| 452 |
+
|
| 453 |
+
for i, (blk_lst, depth) in enumerate(
|
| 454 |
+
zip([layers_16, self.layers_8, self.layers_4], depths)
|
| 455 |
+
):
|
| 456 |
+
if i == 0:
|
| 457 |
+
continue
|
| 458 |
+
attn_cls = AttentionBlock if i == 0 else NystromBlock
|
| 459 |
+
for _ in range(depth):
|
| 460 |
+
blk_lst.append(
|
| 461 |
+
attn_cls(
|
| 462 |
+
hidden_dim // (2**i),
|
| 463 |
+
num_heads=num_heads // (2**i),
|
| 464 |
+
expansion=expansion,
|
| 465 |
+
dropout=dropout,
|
| 466 |
+
layer_scale=layer_scale,
|
| 467 |
+
)
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
self.scaling_activation = torch.exp
|
| 471 |
+
self.opacity_activation = torch.sigmoid
|
| 472 |
+
self.rotation_activation = torch.nn.functional.normalize
|
| 473 |
+
self.scaling_lambda = cfg.model.scale_lambda
|
| 474 |
+
self.sigmoid = nn.Sigmoid()
|
| 475 |
+
|
| 476 |
+
def set_original_shapes(self, shapes: Tuple[int, int]):
|
| 477 |
+
self.original_shapes = shapes
|
| 478 |
+
|
| 479 |
+
def set_shapes(self, shapes: Tuple[int, int]):
|
| 480 |
+
self.shapes = shapes
|
| 481 |
+
|
| 482 |
+
def forward(
|
| 483 |
+
self, latents_16: torch.Tensor, rays_hr: torch.Tensor
|
| 484 |
+
) -> torch.Tensor:
|
| 485 |
+
shapes = self.shapes
|
| 486 |
+
|
| 487 |
+
# camera_embedding
|
| 488 |
+
# torch.cuda.synchronize()
|
| 489 |
+
# start = time()
|
| 490 |
+
rays_embedding_16 = F.normalize(
|
| 491 |
+
flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1
|
| 492 |
+
)
|
| 493 |
+
rays_embedding_8 = F.normalize(
|
| 494 |
+
flat_interpolate(
|
| 495 |
+
rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes]
|
| 496 |
+
),
|
| 497 |
+
dim=-1,
|
| 498 |
+
)
|
| 499 |
+
rays_embedding_4 = F.normalize(
|
| 500 |
+
flat_interpolate(
|
| 501 |
+
rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes]
|
| 502 |
+
),
|
| 503 |
+
dim=-1,
|
| 504 |
+
)
|
| 505 |
+
rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16))
|
| 506 |
+
rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8))
|
| 507 |
+
rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4))
|
| 508 |
+
|
| 509 |
+
# Block 16 - Out 8
|
| 510 |
+
latents_8 = self.up8(
|
| 511 |
+
rearrange(
|
| 512 |
+
latents_16 + rays_embedding_16,
|
| 513 |
+
"b (h w) c -> b c h w",
|
| 514 |
+
h=shapes[0],
|
| 515 |
+
w=shapes[1],
|
| 516 |
+
).contiguous()
|
| 517 |
+
)
|
| 518 |
+
# out8 = self.out8(
|
| 519 |
+
# rearrange(
|
| 520 |
+
# latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2
|
| 521 |
+
# )
|
| 522 |
+
# )
|
| 523 |
+
|
| 524 |
+
# Block 8 - Out 4
|
| 525 |
+
for layer in self.layers_8:
|
| 526 |
+
latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
|
| 527 |
+
latents_4 = self.up4(
|
| 528 |
+
rearrange(
|
| 529 |
+
latents_8 + rays_embedding_8,
|
| 530 |
+
"b (h w) c -> b c h w",
|
| 531 |
+
h=shapes[0] * 2,
|
| 532 |
+
w=shapes[1] * 2,
|
| 533 |
+
).contiguous()
|
| 534 |
+
)
|
| 535 |
+
# out4 = self.out4(
|
| 536 |
+
# rearrange(
|
| 537 |
+
# latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4
|
| 538 |
+
# )
|
| 539 |
+
# )
|
| 540 |
+
|
| 541 |
+
# Block 4 - Out 2
|
| 542 |
+
for layer in self.layers_4:
|
| 543 |
+
latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
|
| 544 |
+
latents_2 = self.up2(
|
| 545 |
+
rearrange(
|
| 546 |
+
latents_4 + rays_embedding_4,
|
| 547 |
+
"b (h w) c -> b c h w",
|
| 548 |
+
h=shapes[0] * 4,
|
| 549 |
+
w=shapes[1] * 4,
|
| 550 |
+
).contiguous()
|
| 551 |
+
)
|
| 552 |
+
out2 = self.out2(
|
| 553 |
+
rearrange(
|
| 554 |
+
latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8
|
| 555 |
+
)
|
| 556 |
+
)
|
| 557 |
+
|
| 558 |
+
split_network_outputs = out2.split(self.split_dimensions, dim=1)
|
| 559 |
+
last = 5
|
| 560 |
+
offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:last]
|
| 561 |
+
|
| 562 |
+
out = {
|
| 563 |
+
("gauss_opacity", 0): self.opacity_activation(opacity),
|
| 564 |
+
("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda,
|
| 565 |
+
("gauss_rotation", 0): self.rotation_activation(rotation),
|
| 566 |
+
("gauss_features_dc", 0): feat_dc
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
if self.cfg.model.max_sh_degree > 0:
|
| 570 |
+
features_rest = split_network_outputs[last]
|
| 571 |
+
out[("gauss_features_rest", 0)] = features_rest
|
| 572 |
+
|
| 573 |
+
if self.cfg.model.predict_offset:
|
| 574 |
+
out[("gauss_offset", 0)] = offset
|
| 575 |
+
|
| 576 |
+
return out
|
| 577 |
+
# return out8, out4, out2, proj_latents_16
|
flash3d/networks/unidepth_extension.py
ADDED
|
@@ -0,0 +1,205 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from einops import rearrange
|
| 5 |
+
|
| 6 |
+
from .unidepth import UniDepthDepth
|
| 7 |
+
from unidepth.models import UniDepthV1
|
| 8 |
+
from .resnet_encoder import ResnetEncoder
|
| 9 |
+
from .gaussian_decoder import GaussianDecoder
|
| 10 |
+
from .depth_decoder import DepthDecoder
|
| 11 |
+
|
| 12 |
+
from networks.layers import disp_to_depth
|
| 13 |
+
from networks.gaussian_decoder import get_splits_and_inits
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class UniDepthExtended(nn.Module):
|
| 17 |
+
def __init__(self,cfg):
|
| 18 |
+
super().__init__()
|
| 19 |
+
|
| 20 |
+
self.cfg = cfg
|
| 21 |
+
|
| 22 |
+
self.unidepth = UniDepthDepth(cfg)
|
| 23 |
+
# self.unidepth = UniDepthV1.from_pretrained("lpiccinelli/unidepth-v1-vitl14")
|
| 24 |
+
|
| 25 |
+
self.parameters_to_train = []
|
| 26 |
+
if self.cfg.model.splat_branch == "resnet":
|
| 27 |
+
self.encoder = ResnetEncoder(cfg.model.num_layers,
|
| 28 |
+
cfg.model.weights_init == "pretrained",
|
| 29 |
+
cfg.model.resnet_bn_order
|
| 30 |
+
)
|
| 31 |
+
# change encoder to take depth as conditioning
|
| 32 |
+
if self.cfg.model.depth_cond:
|
| 33 |
+
self.encoder.encoder.conv1 = nn.Conv2d(
|
| 34 |
+
4,
|
| 35 |
+
self.encoder.encoder.conv1.out_channels,
|
| 36 |
+
kernel_size = self.encoder.encoder.conv1.kernel_size,
|
| 37 |
+
padding = self.encoder.encoder.conv1.padding,
|
| 38 |
+
stride = self.encoder.encoder.conv1.stride
|
| 39 |
+
)
|
| 40 |
+
self.parameters_to_train += [{"params": self.encoder.parameters()}]
|
| 41 |
+
|
| 42 |
+
# use depth branch only for more gaussians
|
| 43 |
+
if cfg.model.gaussians_per_pixel > 1:
|
| 44 |
+
models ={}
|
| 45 |
+
models["depth"] = DepthDecoder(cfg, self.encoder.num_ch_enc)
|
| 46 |
+
self.parameters_to_train +=[{"params": models["depth"].parameters()}]
|
| 47 |
+
for i in range(cfg.model.gaussians_per_pixel):
|
| 48 |
+
models["gauss_decoder_"+str(i)] = GaussianDecoder(cfg, self.encoder.num_ch_enc)
|
| 49 |
+
self.parameters_to_train += [{"params": models["gauss_decoder_"+str(i)].parameters()}]
|
| 50 |
+
if cfg.model.one_gauss_decoder:
|
| 51 |
+
break
|
| 52 |
+
self.models = nn.ModuleDict(models)
|
| 53 |
+
else:
|
| 54 |
+
self.gauss_decoder = GaussianDecoder(cfg, self.encoder.num_ch_enc)
|
| 55 |
+
self.parameters_to_train += [{"params": self.gauss_decoder.parameters()}]
|
| 56 |
+
|
| 57 |
+
elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl":
|
| 58 |
+
self.splat_branch = UniDepthDepth(cfg,
|
| 59 |
+
return_raw_preds=True)
|
| 60 |
+
# modify the head to output the channels for Gaussian parameters
|
| 61 |
+
self.init_ouput_head_splat_branch()
|
| 62 |
+
self.parameters_to_train +=[{"params": self.splat_branch.parameters()}]
|
| 63 |
+
|
| 64 |
+
self.scaling_activation = torch.exp
|
| 65 |
+
self.opacity_activation = torch.sigmoid
|
| 66 |
+
self.rotation_activation = torch.nn.functional.normalize
|
| 67 |
+
|
| 68 |
+
def init_ouput_head_splat_branch(self):
|
| 69 |
+
split_dimensions, scale, bias = get_splits_and_inits(self.cfg)
|
| 70 |
+
# the first dim in the output is for depth - we don't use that in this branch
|
| 71 |
+
self.split_dimensions = split_dimensions[1:]
|
| 72 |
+
scale = scale[1:]
|
| 73 |
+
bias = bias[1:]
|
| 74 |
+
|
| 75 |
+
self.num_output_channels = sum(self.split_dimensions)
|
| 76 |
+
|
| 77 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2 = \
|
| 78 |
+
nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.in_channels,
|
| 79 |
+
self.num_output_channels,
|
| 80 |
+
kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.kernel_size,
|
| 81 |
+
padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.padding)
|
| 82 |
+
|
| 83 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4 = \
|
| 84 |
+
nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.in_channels,
|
| 85 |
+
self.num_output_channels,
|
| 86 |
+
kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.kernel_size,
|
| 87 |
+
padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.padding)
|
| 88 |
+
|
| 89 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8 = \
|
| 90 |
+
nn.Conv2d(self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.in_channels,
|
| 91 |
+
self.num_output_channels,
|
| 92 |
+
kernel_size = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.kernel_size,
|
| 93 |
+
padding = self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.padding)
|
| 94 |
+
|
| 95 |
+
start_channels = 0
|
| 96 |
+
for out_channel, b, s in zip(split_dimensions, bias, scale):
|
| 97 |
+
nn.init.xavier_uniform_(
|
| 98 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.weight[start_channels:start_channels+out_channel,
|
| 99 |
+
:, :, :], s)
|
| 100 |
+
nn.init.constant_(
|
| 101 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out2.bias[start_channels:start_channels+out_channel], b)
|
| 102 |
+
start_channels += out_channel
|
| 103 |
+
|
| 104 |
+
start_channels = 0
|
| 105 |
+
for out_channel, b, s in zip(split_dimensions, bias, scale):
|
| 106 |
+
nn.init.xavier_uniform_(
|
| 107 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.weight[start_channels:start_channels+out_channel,
|
| 108 |
+
:, :, :], s)
|
| 109 |
+
nn.init.constant_(
|
| 110 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out4.bias[start_channels:start_channels+out_channel], b)
|
| 111 |
+
start_channels += out_channel
|
| 112 |
+
|
| 113 |
+
start_channels = 0
|
| 114 |
+
for out_channel, b, s in zip(split_dimensions, bias, scale):
|
| 115 |
+
nn.init.xavier_uniform_(
|
| 116 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.weight[start_channels:start_channels+out_channel,
|
| 117 |
+
:, :, :], s)
|
| 118 |
+
nn.init.constant_(
|
| 119 |
+
self.splat_branch.depth_prediction_model.pixel_decoder.depth_layer.out8.bias[start_channels:start_channels+out_channel], b)
|
| 120 |
+
start_channels += out_channel
|
| 121 |
+
|
| 122 |
+
def get_parameter_groups(self):
|
| 123 |
+
# only the resnet encoder and gaussian parameter decoder are optimisable
|
| 124 |
+
return self.parameters_to_train
|
| 125 |
+
|
| 126 |
+
def forward(self, inputs):
|
| 127 |
+
if ('unidepth', 0, 0) in inputs.keys() and inputs[('unidepth', 0, 0)] is not None:
|
| 128 |
+
depth_outs = dict()
|
| 129 |
+
depth_outs["depth"] = inputs[('unidepth', 0, 0)]
|
| 130 |
+
else:
|
| 131 |
+
with torch.no_grad():
|
| 132 |
+
# if self.training and self.cfg.dataset.pad_border_aug > 0:
|
| 133 |
+
# pad = self.cfg.dataset.pad_border_aug
|
| 134 |
+
# input = inputs["color_aug", 0, 0][:,:,pad:-pad, pad:-pad]
|
| 135 |
+
# intrincs = inputs[("K_tgt", 0)]
|
| 136 |
+
# else:
|
| 137 |
+
# input = inputs["color_aug", 0, 0]
|
| 138 |
+
# intrincs = inputs[("K_src", 0)]
|
| 139 |
+
_, depth_outs = self.unidepth(inputs)
|
| 140 |
+
# depth_outs = self.unidepth.infer(input, intrincs)
|
| 141 |
+
# if self.training and self.cfg.dataset.pad_border_aug > 0:
|
| 142 |
+
# depth_outs["depth"] = F.pad(depth_outs["depth"], (pad,pad,pad,pad), mode="replicate")
|
| 143 |
+
|
| 144 |
+
outputs_gauss = {}
|
| 145 |
+
|
| 146 |
+
K = depth_outs["intrinsics"]
|
| 147 |
+
outputs_gauss[("K_src", 0)] = K
|
| 148 |
+
outputs_gauss[("inv_K_src", 0)] = torch.linalg.inv(K)
|
| 149 |
+
|
| 150 |
+
if self.cfg.model.splat_branch == "resnet":
|
| 151 |
+
if self.cfg.model.depth_cond:
|
| 152 |
+
# division by 20 is to put depth in a similar range to RGB
|
| 153 |
+
resnet_input = torch.cat([inputs["color_aug", 0, 0],
|
| 154 |
+
depth_outs["depth"] / 20.0], dim=1)
|
| 155 |
+
else:
|
| 156 |
+
resnet_input = inputs["color_aug", 0, 0]
|
| 157 |
+
resnet_features = self.encoder(resnet_input)
|
| 158 |
+
if self.cfg.model.gaussians_per_pixel > 1:
|
| 159 |
+
pred_depth = dict()
|
| 160 |
+
depth = self.models["depth"](resnet_features)
|
| 161 |
+
if self.cfg.model.depth_type == "disp":
|
| 162 |
+
for key, v in depth.items():
|
| 163 |
+
_, pred_depth[("depth", key[1])] = disp_to_depth(v, self.cfg.model.min_depth, self.cfg.model.max_depth)
|
| 164 |
+
elif self.cfg.model.depth_type in ["depth", "depth_inc"]:
|
| 165 |
+
pred_depth = depth
|
| 166 |
+
pred_depth[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "(b n) ... -> b n ...", n=self.cfg.model.gaussians_per_pixel - 1)
|
| 167 |
+
if self.cfg.model.depth_type in ["depth_inc", "disp_inc"]:
|
| 168 |
+
pred_depth[("depth", 0)] = torch.cumsum(torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1), dim=1)
|
| 169 |
+
else:
|
| 170 |
+
pred_depth[("depth", 0)] = torch.cat((depth_outs["depth"][:,None,...], pred_depth[("depth", 0)]), dim=1)
|
| 171 |
+
outputs_gauss[("depth", 0)] = rearrange(pred_depth[("depth", 0)], "b n c ... -> (b n) c ...", n = self.cfg.model.gaussians_per_pixel)
|
| 172 |
+
gauss_outs = dict()
|
| 173 |
+
for i in range(self.cfg.model.gaussians_per_pixel):
|
| 174 |
+
outs = self.models["gauss_decoder_"+str(i)](resnet_features)
|
| 175 |
+
if not self.cfg.model.one_gauss_decoder:
|
| 176 |
+
for key, v in outs.items():
|
| 177 |
+
gauss_outs[key] = outs[key][:,None,...] if i==0 else torch.cat([gauss_outs[key], outs[key][:,None,...]], dim=1)
|
| 178 |
+
else:
|
| 179 |
+
gauss_outs |= outs
|
| 180 |
+
for key, v in gauss_outs.items():
|
| 181 |
+
gauss_outs[key] = rearrange(gauss_outs[key], 'b n ... -> (b n) ...')
|
| 182 |
+
outputs_gauss |= gauss_outs
|
| 183 |
+
else:
|
| 184 |
+
outputs_gauss[("depth", 0)] = depth_outs["depth"]
|
| 185 |
+
outputs_gauss |= self.gauss_decoder(resnet_features)
|
| 186 |
+
elif self.cfg.model.splat_branch == "unidepth_vit" or self.cfg.model.splat_branch == "unidepth_cnvnxtl":
|
| 187 |
+
split_network_outputs = self.splat_branch(inputs)[1].split(self.split_dimensions, dim=1)
|
| 188 |
+
offset, opacity, scaling, rotation, feat_dc = split_network_outputs[:5]
|
| 189 |
+
|
| 190 |
+
outputs_gauss |= {
|
| 191 |
+
("gauss_opacity", 0): self.opacity_activation(opacity),
|
| 192 |
+
("gauss_scaling", 0): self.scaling_activation(scaling),
|
| 193 |
+
("gauss_rotation", 0): self.rotation_activation(rotation),
|
| 194 |
+
("gauss_features_dc", 0): feat_dc
|
| 195 |
+
}
|
| 196 |
+
|
| 197 |
+
if self.cfg.model.max_sh_degree > 0:
|
| 198 |
+
features_rest = split_network_outputs[5]
|
| 199 |
+
outputs_gauss[("gauss_features_rest", 0)] = features_rest
|
| 200 |
+
|
| 201 |
+
assert self.cfg.model.predict_offset
|
| 202 |
+
outputs_gauss[("gauss_offset", 0)] = offset
|
| 203 |
+
|
| 204 |
+
return outputs_gauss
|
| 205 |
+
|
flash3d/unidepth/layers/__init__.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .activation import SwiGLU, GEGLU
|
| 2 |
+
from .convnext import CvnxtBlock
|
| 3 |
+
from .attention import AttentionBlock, AttentionDecoderBlock
|
| 4 |
+
from .nystrom_attention import NystromBlock
|
| 5 |
+
from .positional_encoding import PositionEmbeddingSine
|
| 6 |
+
from .upsample import ConvUpsample, ConvUpsampleShuffle
|
| 7 |
+
from .mlp import MLP
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"SwiGLU",
|
| 12 |
+
"GEGLU",
|
| 13 |
+
"CvnxtBlock",
|
| 14 |
+
"AttentionBlock",
|
| 15 |
+
"NystromBlock",
|
| 16 |
+
"PositionEmbeddingSine",
|
| 17 |
+
"ConvUpsample",
|
| 18 |
+
"MLP",
|
| 19 |
+
"ConvUpsampleShuffle",
|
| 20 |
+
"AttentionDecoderBlock",
|
| 21 |
+
]
|
flash3d/unidepth/layers/activation.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class SwiGLU(nn.Module):
|
| 7 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 8 |
+
x, gates = x.chunk(2, dim=-1)
|
| 9 |
+
return x * F.silu(gates)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class GEGLU(nn.Module):
|
| 13 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 14 |
+
x, gates = x.chunk(2, dim=-1)
|
| 15 |
+
return x * F.gelu(gates)
|
flash3d/unidepth/layers/attention.py
ADDED
|
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from functools import partial
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
|
| 13 |
+
from .layer_scale import LayerScale
|
| 14 |
+
from .mlp import MLP
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SimpleAttention(nn.Module):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
dim: int,
|
| 21 |
+
num_heads: int = 4,
|
| 22 |
+
dropout: float = 0.0,
|
| 23 |
+
cosine: bool = False,
|
| 24 |
+
context_dim: int | None = None,
|
| 25 |
+
):
|
| 26 |
+
super().__init__()
|
| 27 |
+
self.dropout = dropout
|
| 28 |
+
self.num_heads = num_heads
|
| 29 |
+
self.hidden_dim = dim
|
| 30 |
+
context_dim = context_dim or dim
|
| 31 |
+
|
| 32 |
+
self.kv = nn.Linear(context_dim, dim * 2, bias=False)
|
| 33 |
+
self.q = nn.Linear(dim, dim, bias=False)
|
| 34 |
+
self.norm_attnx = nn.LayerNorm(dim)
|
| 35 |
+
self.norm_attnctx = nn.LayerNorm(context_dim)
|
| 36 |
+
self.cosine = cosine
|
| 37 |
+
self.out = nn.Linear(dim, dim)
|
| 38 |
+
|
| 39 |
+
def forward(
|
| 40 |
+
self,
|
| 41 |
+
x: torch.Tensor,
|
| 42 |
+
attn_bias: torch.Tensor | None = None,
|
| 43 |
+
context: torch.Tensor | None = None,
|
| 44 |
+
pos_embed: torch.Tensor | None = None,
|
| 45 |
+
pos_embed_context: torch.Tensor | None = None,
|
| 46 |
+
rope: nn.Module | None = None,
|
| 47 |
+
) -> torch.Tensor:
|
| 48 |
+
context = x if context is None else context
|
| 49 |
+
x = self.norm_attnx(x)
|
| 50 |
+
context = self.norm_attnctx(context)
|
| 51 |
+
k, v = rearrange(
|
| 52 |
+
self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
|
| 53 |
+
).unbind(dim=-1)
|
| 54 |
+
q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
|
| 55 |
+
|
| 56 |
+
if rope is not None:
|
| 57 |
+
q = rope(q)
|
| 58 |
+
k = rope(k)
|
| 59 |
+
else:
|
| 60 |
+
if pos_embed is not None:
|
| 61 |
+
pos_embed = rearrange(
|
| 62 |
+
pos_embed, "b n (h d) -> b h n d", h=self.num_heads
|
| 63 |
+
)
|
| 64 |
+
q = q + pos_embed
|
| 65 |
+
if pos_embed_context is not None:
|
| 66 |
+
pos_embed_context = rearrange(
|
| 67 |
+
pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
|
| 68 |
+
)
|
| 69 |
+
k = k + pos_embed_context
|
| 70 |
+
|
| 71 |
+
if self.cosine:
|
| 72 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
| 73 |
+
x = F.scaled_dot_product_attention(
|
| 74 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
| 75 |
+
)
|
| 76 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
| 77 |
+
x = self.out(x)
|
| 78 |
+
return x
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
class AttentionBlock(nn.Module):
|
| 82 |
+
def __init__(
|
| 83 |
+
self,
|
| 84 |
+
dim: int,
|
| 85 |
+
num_heads: int = 4,
|
| 86 |
+
expansion: int = 4,
|
| 87 |
+
dropout: float = 0.0,
|
| 88 |
+
cosine: bool = False,
|
| 89 |
+
gated: bool = False,
|
| 90 |
+
layer_scale: float = 1.0,
|
| 91 |
+
context_dim: int | None = None,
|
| 92 |
+
):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.dropout = dropout
|
| 95 |
+
self.num_heads = num_heads
|
| 96 |
+
self.hidden_dim = dim
|
| 97 |
+
context_dim = context_dim or dim
|
| 98 |
+
self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
|
| 99 |
+
self.kv = nn.Linear(context_dim, dim * 2)
|
| 100 |
+
self.q = nn.Linear(dim, dim)
|
| 101 |
+
self.norm_attnx = nn.LayerNorm(dim)
|
| 102 |
+
self.norm_attnctx = nn.LayerNorm(context_dim)
|
| 103 |
+
self.cosine = cosine
|
| 104 |
+
self.out = nn.Linear(dim, dim)
|
| 105 |
+
self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
| 106 |
+
self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
| 107 |
+
|
| 108 |
+
def attn(
|
| 109 |
+
self,
|
| 110 |
+
x: torch.Tensor,
|
| 111 |
+
attn_bias: torch.Tensor | None = None,
|
| 112 |
+
context: torch.Tensor | None = None,
|
| 113 |
+
pos_embed: torch.Tensor | None = None,
|
| 114 |
+
pos_embed_context: torch.Tensor | None = None,
|
| 115 |
+
rope: nn.Module | None = None,
|
| 116 |
+
) -> torch.Tensor:
|
| 117 |
+
x = self.norm_attnx(x)
|
| 118 |
+
context = self.norm_attnctx(context)
|
| 119 |
+
k, v = rearrange(
|
| 120 |
+
self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
|
| 121 |
+
).unbind(dim=-1)
|
| 122 |
+
q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
|
| 123 |
+
|
| 124 |
+
if rope is not None:
|
| 125 |
+
q = rope(q)
|
| 126 |
+
k = rope(k)
|
| 127 |
+
else:
|
| 128 |
+
if pos_embed is not None:
|
| 129 |
+
pos_embed = rearrange(
|
| 130 |
+
pos_embed, "b n (h d) -> b h n d", h=self.num_heads
|
| 131 |
+
)
|
| 132 |
+
q = q + pos_embed
|
| 133 |
+
if pos_embed_context is not None:
|
| 134 |
+
pos_embed_context = rearrange(
|
| 135 |
+
pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
|
| 136 |
+
)
|
| 137 |
+
k = k + pos_embed_context
|
| 138 |
+
|
| 139 |
+
if self.cosine:
|
| 140 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
| 141 |
+
|
| 142 |
+
x = F.scaled_dot_product_attention(
|
| 143 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
| 144 |
+
)
|
| 145 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
| 146 |
+
x = self.out(x)
|
| 147 |
+
return x
|
| 148 |
+
|
| 149 |
+
def forward(
|
| 150 |
+
self,
|
| 151 |
+
x: torch.Tensor,
|
| 152 |
+
attn_bias: torch.Tensor | None = None,
|
| 153 |
+
context: torch.Tensor | None = None,
|
| 154 |
+
pos_embed: torch.Tensor | None = None,
|
| 155 |
+
pos_embed_context: torch.Tensor | None = None,
|
| 156 |
+
rope: nn.Module | None = None,
|
| 157 |
+
) -> torch.Tensor:
|
| 158 |
+
context = x if context is None else context
|
| 159 |
+
x = (
|
| 160 |
+
self.ls1(
|
| 161 |
+
self.attn(
|
| 162 |
+
x,
|
| 163 |
+
rope=rope,
|
| 164 |
+
attn_bias=attn_bias,
|
| 165 |
+
context=context,
|
| 166 |
+
pos_embed=pos_embed,
|
| 167 |
+
pos_embed_context=pos_embed_context,
|
| 168 |
+
)
|
| 169 |
+
)
|
| 170 |
+
+ x
|
| 171 |
+
)
|
| 172 |
+
x = self.ls2(self.mlp(x)) + x
|
| 173 |
+
return x
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
class AttentionDecoderBlock(nn.Module):
|
| 177 |
+
def __init__(
|
| 178 |
+
self,
|
| 179 |
+
dim: int,
|
| 180 |
+
num_heads: int = 4,
|
| 181 |
+
expansion: int = 4,
|
| 182 |
+
dropout: float = 0.0,
|
| 183 |
+
cosine: bool = False,
|
| 184 |
+
gated: bool = False,
|
| 185 |
+
layer_scale: float = 1.0,
|
| 186 |
+
context_dim: int | None = None,
|
| 187 |
+
single_head_ca: bool = True,
|
| 188 |
+
):
|
| 189 |
+
super().__init__()
|
| 190 |
+
self.dropout = dropout
|
| 191 |
+
self.num_heads = num_heads
|
| 192 |
+
self.hidden_dim = dim
|
| 193 |
+
self.single_head_ca = single_head_ca
|
| 194 |
+
context_dim = context_dim or dim
|
| 195 |
+
self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
|
| 196 |
+
self.kv_ca = nn.Linear(context_dim, dim * 2)
|
| 197 |
+
self.q_ca = nn.Linear(dim, dim)
|
| 198 |
+
self.kv_sa = nn.Linear(dim, dim * 2)
|
| 199 |
+
self.q_sa = nn.Linear(dim, dim)
|
| 200 |
+
self.norm_x_sa = nn.LayerNorm(dim)
|
| 201 |
+
self.norm_x_ca = nn.LayerNorm(dim)
|
| 202 |
+
self.norm_ctx_ca = nn.LayerNorm(context_dim)
|
| 203 |
+
self.cosine = cosine
|
| 204 |
+
self.out_ca = nn.Linear(dim, dim)
|
| 205 |
+
self.out_sa = nn.Linear(dim, dim)
|
| 206 |
+
self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
| 207 |
+
self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
| 208 |
+
self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
|
| 209 |
+
|
| 210 |
+
def cross_attn(
|
| 211 |
+
self,
|
| 212 |
+
x: torch.Tensor,
|
| 213 |
+
attn_bias: torch.Tensor | None = None,
|
| 214 |
+
context: torch.Tensor | None = None,
|
| 215 |
+
pos_embed: torch.Tensor | None = None,
|
| 216 |
+
pos_embed_context: torch.Tensor | None = None,
|
| 217 |
+
rope: nn.Module | None = None,
|
| 218 |
+
) -> torch.Tensor:
|
| 219 |
+
num_heads = 1 if self.single_head_ca else self.num_heads
|
| 220 |
+
x = self.norm_x_ca(x)
|
| 221 |
+
context = self.norm_ctx_ca(context)
|
| 222 |
+
k, v = rearrange(
|
| 223 |
+
self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2
|
| 224 |
+
).unbind(dim=-1)
|
| 225 |
+
q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads)
|
| 226 |
+
|
| 227 |
+
if rope is not None:
|
| 228 |
+
q = rope(q)
|
| 229 |
+
k = rope(k)
|
| 230 |
+
else:
|
| 231 |
+
if pos_embed is not None:
|
| 232 |
+
pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads)
|
| 233 |
+
q = q + pos_embed
|
| 234 |
+
if pos_embed_context is not None:
|
| 235 |
+
pos_embed_context = rearrange(
|
| 236 |
+
pos_embed_context, "b n (h d) -> b h n d", h=num_heads
|
| 237 |
+
)
|
| 238 |
+
k = k + pos_embed_context
|
| 239 |
+
|
| 240 |
+
if self.cosine:
|
| 241 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
| 242 |
+
x = F.scaled_dot_product_attention(
|
| 243 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
| 244 |
+
)
|
| 245 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
| 246 |
+
x = self.out_ca(x)
|
| 247 |
+
return x
|
| 248 |
+
|
| 249 |
+
def self_attn(
|
| 250 |
+
self,
|
| 251 |
+
x: torch.Tensor,
|
| 252 |
+
attn_bias: torch.Tensor | None = None,
|
| 253 |
+
pos_embed: torch.Tensor | None = None,
|
| 254 |
+
rope: nn.Module | None = None,
|
| 255 |
+
) -> torch.Tensor:
|
| 256 |
+
x = self.norm_x_sa(x)
|
| 257 |
+
k, v = rearrange(
|
| 258 |
+
self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
|
| 259 |
+
).unbind(dim=-1)
|
| 260 |
+
q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads)
|
| 261 |
+
|
| 262 |
+
if rope is not None:
|
| 263 |
+
q = rope(q)
|
| 264 |
+
k = rope(k)
|
| 265 |
+
elif pos_embed is not None:
|
| 266 |
+
pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
|
| 267 |
+
q = q + pos_embed
|
| 268 |
+
|
| 269 |
+
if self.cosine:
|
| 270 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
| 271 |
+
x = F.scaled_dot_product_attention(
|
| 272 |
+
q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
|
| 273 |
+
)
|
| 274 |
+
x = rearrange(x, "b h n d -> b n (h d)")
|
| 275 |
+
x = self.out_sa(x)
|
| 276 |
+
return x
|
| 277 |
+
|
| 278 |
+
def forward(
|
| 279 |
+
self,
|
| 280 |
+
x: torch.Tensor,
|
| 281 |
+
attn_bias: torch.Tensor | None = None,
|
| 282 |
+
context: torch.Tensor | None = None,
|
| 283 |
+
pos_embed: torch.Tensor | None = None,
|
| 284 |
+
pos_embed_context: torch.Tensor | None = None,
|
| 285 |
+
rope: nn.Module | None = None,
|
| 286 |
+
) -> torch.Tensor:
|
| 287 |
+
context = x if context is None else context
|
| 288 |
+
x = (
|
| 289 |
+
self.ls1(
|
| 290 |
+
self.cross_attn(
|
| 291 |
+
x,
|
| 292 |
+
rope=rope,
|
| 293 |
+
attn_bias=attn_bias,
|
| 294 |
+
context=context,
|
| 295 |
+
pos_embed=pos_embed,
|
| 296 |
+
pos_embed_context=pos_embed_context,
|
| 297 |
+
)
|
| 298 |
+
)
|
| 299 |
+
+ x
|
| 300 |
+
)
|
| 301 |
+
x = (
|
| 302 |
+
self.ls2(
|
| 303 |
+
self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed)
|
| 304 |
+
)
|
| 305 |
+
+ x
|
| 306 |
+
)
|
| 307 |
+
x = self.ls3(self.mlp(x)) + x
|
| 308 |
+
return x
|
flash3d/unidepth/layers/convnext.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class CvnxtBlock(nn.Module):
|
| 6 |
+
def __init__(
|
| 7 |
+
self,
|
| 8 |
+
dim,
|
| 9 |
+
kernel_size=7,
|
| 10 |
+
layer_scale=1.0,
|
| 11 |
+
expansion=4,
|
| 12 |
+
dilation=1,
|
| 13 |
+
):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.dwconv = nn.Conv2d(
|
| 16 |
+
dim,
|
| 17 |
+
dim,
|
| 18 |
+
kernel_size=kernel_size,
|
| 19 |
+
padding="same",
|
| 20 |
+
groups=dim,
|
| 21 |
+
dilation=dilation,
|
| 22 |
+
) # depthwise conv
|
| 23 |
+
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
| 24 |
+
self.pwconv1 = nn.Linear(
|
| 25 |
+
dim, expansion * dim
|
| 26 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 27 |
+
self.act = nn.GELU()
|
| 28 |
+
self.pwconv2 = nn.Linear(expansion * dim, dim)
|
| 29 |
+
self.gamma = (
|
| 30 |
+
nn.Parameter(layer_scale * torch.ones((dim))) if layer_scale > 0.0 else 1.0
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
def forward(self, x):
|
| 34 |
+
input = x
|
| 35 |
+
x = self.dwconv(x)
|
| 36 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 37 |
+
x = self.norm(x)
|
| 38 |
+
x = self.pwconv1(x)
|
| 39 |
+
x = self.act(x)
|
| 40 |
+
x = self.pwconv2(x)
|
| 41 |
+
|
| 42 |
+
x = self.gamma * x
|
| 43 |
+
x = input + x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 44 |
+
return x
|
flash3d/unidepth/layers/drop_path.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False):
|
| 6 |
+
if drop_prob == 0.0 or not training:
|
| 7 |
+
return x
|
| 8 |
+
keep_prob = 1 - drop_prob
|
| 9 |
+
shape = (x.shape[0],) + (1,) * (
|
| 10 |
+
x.ndim - 1
|
| 11 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 12 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 13 |
+
if keep_prob > 0.0:
|
| 14 |
+
random_tensor.div_(keep_prob)
|
| 15 |
+
output = x * random_tensor
|
| 16 |
+
return output
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class DropPath(nn.Module):
|
| 20 |
+
def __init__(self, drop_prob=None):
|
| 21 |
+
super(DropPath, self).__init__()
|
| 22 |
+
self.drop_prob = drop_prob
|
| 23 |
+
|
| 24 |
+
def forward(self, x):
|
| 25 |
+
return drop_path(x, self.drop_prob, self.training)
|
flash3d/unidepth/layers/layer_scale.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class LayerScale(nn.Module):
|
| 6 |
+
def __init__(
|
| 7 |
+
self,
|
| 8 |
+
dim: int,
|
| 9 |
+
init_values: float | torch.Tensor = 1e-5,
|
| 10 |
+
inplace: bool = False,
|
| 11 |
+
) -> None:
|
| 12 |
+
super().__init__()
|
| 13 |
+
self.inplace = inplace
|
| 14 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 15 |
+
|
| 16 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 17 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
flash3d/unidepth/layers/mlp.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from unidepth.utils.misc import default
|
| 5 |
+
from .activation import SwiGLU
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class MLP(nn.Module):
|
| 9 |
+
def __init__(
|
| 10 |
+
self,
|
| 11 |
+
input_dim: int,
|
| 12 |
+
expansion: int = 4,
|
| 13 |
+
dropout: float = 0.0,
|
| 14 |
+
gated: bool = False,
|
| 15 |
+
output_dim: int | None = None,
|
| 16 |
+
):
|
| 17 |
+
super().__init__()
|
| 18 |
+
if gated:
|
| 19 |
+
expansion = int(expansion * 2 / 3)
|
| 20 |
+
hidden_dim = int(input_dim * expansion)
|
| 21 |
+
output_dim = default(output_dim, input_dim)
|
| 22 |
+
self.norm = nn.LayerNorm(input_dim)
|
| 23 |
+
self.proj1 = nn.Linear(input_dim, hidden_dim)
|
| 24 |
+
self.proj2 = nn.Linear(hidden_dim, output_dim)
|
| 25 |
+
self.act = nn.GELU() if not gated else SwiGLU()
|
| 26 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
|
| 27 |
+
|
| 28 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 29 |
+
x = self.norm(x)
|
| 30 |
+
x = self.proj1(x)
|
| 31 |
+
x = self.act(x)
|
| 32 |
+
x = self.proj2(x)
|
| 33 |
+
x = self.dropout(x)
|
| 34 |
+
return x
|
flash3d/unidepth/layers/nystrom_attention.py
ADDED
|
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from einops import rearrange
|
| 7 |
+
from xformers.components.attention import NystromAttention
|
| 8 |
+
|
| 9 |
+
from .attention import AttentionBlock
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class NystromBlock(AttentionBlock):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
dim: int,
|
| 16 |
+
num_heads: int = 4,
|
| 17 |
+
expansion: int = 4,
|
| 18 |
+
dropout: float = 0.0,
|
| 19 |
+
cosine: bool = False,
|
| 20 |
+
gated: bool = False,
|
| 21 |
+
layer_scale: float = 1.0,
|
| 22 |
+
context_dim: int | None = None,
|
| 23 |
+
):
|
| 24 |
+
super().__init__(
|
| 25 |
+
dim=dim,
|
| 26 |
+
num_heads=num_heads,
|
| 27 |
+
expansion=expansion,
|
| 28 |
+
dropout=dropout,
|
| 29 |
+
cosine=cosine,
|
| 30 |
+
gated=gated,
|
| 31 |
+
layer_scale=layer_scale,
|
| 32 |
+
context_dim=context_dim,
|
| 33 |
+
)
|
| 34 |
+
self.attention_fn = NystromAttention(
|
| 35 |
+
num_landmarks=128, num_heads=num_heads, dropout=dropout
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
def attn(
|
| 39 |
+
self,
|
| 40 |
+
x: torch.Tensor,
|
| 41 |
+
attn_bias: torch.Tensor | None = None,
|
| 42 |
+
context: torch.Tensor | None = None,
|
| 43 |
+
pos_embed: torch.Tensor | None = None,
|
| 44 |
+
pos_embed_context: torch.Tensor | None = None,
|
| 45 |
+
rope: nn.Module | None = None,
|
| 46 |
+
) -> torch.Tensor:
|
| 47 |
+
x = self.norm_attnx(x)
|
| 48 |
+
context = self.norm_attnctx(context)
|
| 49 |
+
k, v = rearrange(
|
| 50 |
+
self.kv(context), "b n (kv h d) -> b n h d kv", h=self.num_heads, kv=2
|
| 51 |
+
).unbind(dim=-1)
|
| 52 |
+
q = rearrange(self.q(x), "b n (h d) -> b n h d", h=self.num_heads)
|
| 53 |
+
|
| 54 |
+
if rope is not None:
|
| 55 |
+
q = rope(q)
|
| 56 |
+
k = rope(k)
|
| 57 |
+
else:
|
| 58 |
+
if pos_embed is not None:
|
| 59 |
+
pos_embed = rearrange(
|
| 60 |
+
pos_embed, "b n (h d) -> b n h d", h=self.num_heads
|
| 61 |
+
)
|
| 62 |
+
q = q + pos_embed
|
| 63 |
+
if pos_embed_context is not None:
|
| 64 |
+
pos_embed_context = rearrange(
|
| 65 |
+
pos_embed_context, "b n (h d) -> b n h d", h=self.num_heads
|
| 66 |
+
)
|
| 67 |
+
k = k + pos_embed_context
|
| 68 |
+
|
| 69 |
+
if self.cosine:
|
| 70 |
+
q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
|
| 71 |
+
x = self.attention_fn(q, k, v, key_padding_mask=attn_bias)
|
| 72 |
+
x = rearrange(x, "b n h d -> b n (h d)")
|
| 73 |
+
x = self.out(x)
|
| 74 |
+
return x
|
flash3d/unidepth/layers/positional_encoding.py
ADDED
|
@@ -0,0 +1,228 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from math import pi
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from einops import rearrange, repeat
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PositionEmbeddingSine(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
| 18 |
+
):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.num_pos_feats = num_pos_feats
|
| 21 |
+
self.temperature = temperature
|
| 22 |
+
self.normalize = normalize
|
| 23 |
+
if scale is not None and normalize is False:
|
| 24 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 25 |
+
if scale is None:
|
| 26 |
+
scale = 2 * pi
|
| 27 |
+
self.scale = scale
|
| 28 |
+
|
| 29 |
+
def forward(
|
| 30 |
+
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
|
| 31 |
+
) -> torch.Tensor:
|
| 32 |
+
if mask is None:
|
| 33 |
+
mask = torch.zeros(
|
| 34 |
+
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
|
| 35 |
+
)
|
| 36 |
+
not_mask = ~mask
|
| 37 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
| 38 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
| 39 |
+
if self.normalize:
|
| 40 |
+
eps = 1e-6
|
| 41 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
| 42 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
| 43 |
+
|
| 44 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 45 |
+
dim_t = self.temperature ** (
|
| 46 |
+
2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 50 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 51 |
+
pos_x = torch.stack(
|
| 52 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
| 53 |
+
).flatten(3)
|
| 54 |
+
pos_y = torch.stack(
|
| 55 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
| 56 |
+
).flatten(3)
|
| 57 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 58 |
+
return pos
|
| 59 |
+
|
| 60 |
+
def __repr__(self, _repr_indent=4):
|
| 61 |
+
head = "Positional encoding " + self.__class__.__name__
|
| 62 |
+
body = [
|
| 63 |
+
"num_pos_feats: {}".format(self.num_pos_feats),
|
| 64 |
+
"temperature: {}".format(self.temperature),
|
| 65 |
+
"normalize: {}".format(self.normalize),
|
| 66 |
+
"scale: {}".format(self.scale),
|
| 67 |
+
]
|
| 68 |
+
# _repr_indent = 4
|
| 69 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
| 70 |
+
return "\n".join(lines)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LearnedSinusoidalPosEmb(nn.Module):
|
| 74 |
+
def __init__(self, dim):
|
| 75 |
+
super().__init__()
|
| 76 |
+
assert (dim % 2) == 0
|
| 77 |
+
half_dim = dim // 2
|
| 78 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
x = rearrange(x, "b -> b 1")
|
| 82 |
+
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
|
| 83 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
| 84 |
+
fouriered = torch.cat((x, fouriered), dim=-1)
|
| 85 |
+
return fouriered
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def generate_fourier_features(x, max_freq=64, num_bands=16):
|
| 89 |
+
x = x.unsqueeze(-1)
|
| 90 |
+
device, dtype, orig_x = x.device, x.dtype, x
|
| 91 |
+
|
| 92 |
+
scales = torch.linspace(
|
| 93 |
+
-max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype
|
| 94 |
+
)
|
| 95 |
+
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
|
| 96 |
+
|
| 97 |
+
x = x * scales * pi
|
| 98 |
+
x = torch.cat([x.sin(), x.cos()], dim=-1)
|
| 99 |
+
x = torch.cat((x, orig_x), dim=-1)
|
| 100 |
+
return x.flatten(-2)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def broadcat(tensors, dim=-1):
|
| 104 |
+
num_tensors = len(tensors)
|
| 105 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
| 106 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
| 107 |
+
shape_len = list(shape_lens)[0]
|
| 108 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 109 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
| 110 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 111 |
+
assert all(
|
| 112 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
| 113 |
+
), "invalid dimensions for broadcastable concatentation"
|
| 114 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
| 115 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
| 116 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 117 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
| 118 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
| 119 |
+
return torch.cat(tensors, dim=dim)
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def rotate_half(x):
|
| 123 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 124 |
+
x1, x2 = x.unbind(dim=-1)
|
| 125 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 126 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 130 |
+
def __init__(
|
| 131 |
+
self,
|
| 132 |
+
dim,
|
| 133 |
+
pt_seq_len,
|
| 134 |
+
ft_seq_len=None,
|
| 135 |
+
custom_freqs=None,
|
| 136 |
+
freqs_for="lang",
|
| 137 |
+
theta=10000,
|
| 138 |
+
max_freq=10,
|
| 139 |
+
num_freqs=1,
|
| 140 |
+
):
|
| 141 |
+
super().__init__()
|
| 142 |
+
if custom_freqs:
|
| 143 |
+
freqs = custom_freqs
|
| 144 |
+
elif freqs_for == "lang":
|
| 145 |
+
freqs = 1.0 / (
|
| 146 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
| 147 |
+
)
|
| 148 |
+
elif freqs_for == "pixel":
|
| 149 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 150 |
+
elif freqs_for == "constant":
|
| 151 |
+
freqs = torch.ones(num_freqs).float()
|
| 152 |
+
else:
|
| 153 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
| 154 |
+
|
| 155 |
+
if ft_seq_len is None:
|
| 156 |
+
ft_seq_len = pt_seq_len
|
| 157 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 158 |
+
|
| 159 |
+
freqs_h = torch.einsum("..., f -> ... f", t, freqs)
|
| 160 |
+
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
| 161 |
+
|
| 162 |
+
freqs_w = torch.einsum("..., f -> ... f", t, freqs)
|
| 163 |
+
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
| 164 |
+
|
| 165 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
|
| 166 |
+
|
| 167 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
| 168 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
| 169 |
+
|
| 170 |
+
print("======== shape of rope freq", self.freqs_cos.shape, "========")
|
| 171 |
+
|
| 172 |
+
def forward(self, t, start_index=0):
|
| 173 |
+
rot_dim = self.freqs_cos.shape[-1]
|
| 174 |
+
end_index = start_index + rot_dim
|
| 175 |
+
assert (
|
| 176 |
+
rot_dim <= t.shape[-1]
|
| 177 |
+
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
| 178 |
+
t_left, t, t_right = (
|
| 179 |
+
t[..., :start_index],
|
| 180 |
+
t[..., start_index:end_index],
|
| 181 |
+
t[..., end_index:],
|
| 182 |
+
)
|
| 183 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
| 184 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
| 188 |
+
def __init__(
|
| 189 |
+
self,
|
| 190 |
+
dim,
|
| 191 |
+
pt_seq_len,
|
| 192 |
+
ft_seq_len=None,
|
| 193 |
+
custom_freqs=None,
|
| 194 |
+
freqs_for="lang",
|
| 195 |
+
theta=10000,
|
| 196 |
+
max_freq=10,
|
| 197 |
+
num_freqs=1,
|
| 198 |
+
):
|
| 199 |
+
super().__init__()
|
| 200 |
+
if custom_freqs:
|
| 201 |
+
freqs = custom_freqs
|
| 202 |
+
elif freqs_for == "lang":
|
| 203 |
+
freqs = 1.0 / (
|
| 204 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
| 205 |
+
)
|
| 206 |
+
elif freqs_for == "pixel":
|
| 207 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 208 |
+
elif freqs_for == "constant":
|
| 209 |
+
freqs = torch.ones(num_freqs).float()
|
| 210 |
+
else:
|
| 211 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
| 212 |
+
|
| 213 |
+
if ft_seq_len is None:
|
| 214 |
+
ft_seq_len = pt_seq_len
|
| 215 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 216 |
+
|
| 217 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 218 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
| 219 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 220 |
+
|
| 221 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
| 222 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
| 223 |
+
|
| 224 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
| 225 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
| 226 |
+
|
| 227 |
+
def forward(self, t):
|
| 228 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
flash3d/unidepth/layers/upsample.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
|
| 10 |
+
from .convnext import CvnxtBlock
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class ConvUpsample(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
hidden_dim,
|
| 17 |
+
num_layers: int = 2,
|
| 18 |
+
expansion: int = 4,
|
| 19 |
+
layer_scale: float = 1.0,
|
| 20 |
+
kernel_size: int = 7,
|
| 21 |
+
**kwargs
|
| 22 |
+
):
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.convs = nn.ModuleList([])
|
| 25 |
+
for _ in range(num_layers):
|
| 26 |
+
self.convs.append(
|
| 27 |
+
CvnxtBlock(
|
| 28 |
+
hidden_dim,
|
| 29 |
+
kernel_size=kernel_size,
|
| 30 |
+
expansion=expansion,
|
| 31 |
+
layer_scale=layer_scale,
|
| 32 |
+
)
|
| 33 |
+
)
|
| 34 |
+
self.up = nn.Sequential(
|
| 35 |
+
nn.Conv2d(hidden_dim, hidden_dim // 2, kernel_size=1, padding=0),
|
| 36 |
+
nn.UpsamplingBilinear2d(scale_factor=2),
|
| 37 |
+
nn.Conv2d(hidden_dim // 2, hidden_dim // 2, kernel_size=3, padding=1),
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
def forward(self, x: torch.Tensor):
|
| 41 |
+
for conv in self.convs:
|
| 42 |
+
x = conv(x)
|
| 43 |
+
x = self.up(x)
|
| 44 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
| 45 |
+
return x
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class ConvUpsampleShuffle(nn.Module):
|
| 49 |
+
def __init__(
|
| 50 |
+
self, hidden_dim, expansion: int = 4, layer_scale: float = 1.0, **kwargs
|
| 51 |
+
):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.conv1 = CvnxtBlock(
|
| 54 |
+
hidden_dim, expansion=expansion, layer_scale=layer_scale
|
| 55 |
+
)
|
| 56 |
+
self.conv2 = CvnxtBlock(
|
| 57 |
+
hidden_dim, expansion=expansion, layer_scale=layer_scale
|
| 58 |
+
)
|
| 59 |
+
self.up = nn.Sequential(
|
| 60 |
+
nn.PixelShuffle(2),
|
| 61 |
+
nn.Conv2d(hidden_dim // 4, hidden_dim // 2, kernel_size=3, padding=1),
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def forward(self, x: torch.Tensor):
|
| 65 |
+
x = self.conv1(x)
|
| 66 |
+
x = self.conv2(x)
|
| 67 |
+
x = self.up(x)
|
| 68 |
+
x = rearrange(x, "b c h w -> b (h w) c")
|
| 69 |
+
return x
|
flash3d/unidepth/models/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .unidepthv1 import UniDepthV1
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"UniDepthV1",
|
| 5 |
+
]
|
flash3d/unidepth/models/backbones/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .convnext2 import ConvNeXtV2
|
| 2 |
+
from .convnext import ConvNeXt
|
| 3 |
+
from .dinov2 import _make_dinov2_model
|
| 4 |
+
|
| 5 |
+
__all__ = [
|
| 6 |
+
"ConvNeXt",
|
| 7 |
+
"ConvNeXtV2",
|
| 8 |
+
"_make_dinov2_model",
|
| 9 |
+
]
|
flash3d/unidepth/models/backbones/convnext.py
ADDED
|
@@ -0,0 +1,590 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import OrderedDict
|
| 2 |
+
from functools import partial
|
| 3 |
+
from typing import Callable, Optional, Tuple, Union, Sequence
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from torch.utils.checkpoint import checkpoint
|
| 8 |
+
|
| 9 |
+
from timm.layers import (
|
| 10 |
+
trunc_normal_,
|
| 11 |
+
AvgPool2dSame,
|
| 12 |
+
DropPath,
|
| 13 |
+
Mlp,
|
| 14 |
+
GlobalResponseNormMlp,
|
| 15 |
+
LayerNorm2d,
|
| 16 |
+
LayerNorm,
|
| 17 |
+
create_conv2d,
|
| 18 |
+
get_act_layer,
|
| 19 |
+
make_divisible,
|
| 20 |
+
to_ntuple,
|
| 21 |
+
)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_num_layer_for_convnext(var_name):
|
| 25 |
+
"""
|
| 26 |
+
Divide [3, 3, 27, 3] layers into 12 groups; each group is three
|
| 27 |
+
consecutive blocks, including possible neighboring downsample layers;
|
| 28 |
+
adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
|
| 29 |
+
"""
|
| 30 |
+
if var_name.startswith("downsample_layers"):
|
| 31 |
+
stage_id = int(var_name.split(".")[1])
|
| 32 |
+
if stage_id == 0:
|
| 33 |
+
layer_id = 0
|
| 34 |
+
elif stage_id == 1 or stage_id == 2:
|
| 35 |
+
layer_id = stage_id + 1
|
| 36 |
+
elif stage_id == 3:
|
| 37 |
+
layer_id = 12
|
| 38 |
+
|
| 39 |
+
elif var_name.startswith("stages"):
|
| 40 |
+
stage_id = int(var_name.split(".")[1])
|
| 41 |
+
block_id = int(var_name.split(".")[3])
|
| 42 |
+
if stage_id == 0 or stage_id == 1:
|
| 43 |
+
layer_id = stage_id + 1
|
| 44 |
+
elif stage_id == 2:
|
| 45 |
+
layer_id = 3 + block_id // 3
|
| 46 |
+
elif stage_id == 3:
|
| 47 |
+
layer_id = 12
|
| 48 |
+
|
| 49 |
+
elif var_name.startswith("stem"):
|
| 50 |
+
return 0
|
| 51 |
+
else:
|
| 52 |
+
layer_id = 12
|
| 53 |
+
return layer_id + 1
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None):
|
| 57 |
+
parameter_group_names = {}
|
| 58 |
+
parameter_group_vars = {}
|
| 59 |
+
skip = set()
|
| 60 |
+
if skip_list is not None:
|
| 61 |
+
skip = skip_list
|
| 62 |
+
if hasattr(model, "no_weight_decay"):
|
| 63 |
+
skip.update(model.no_weight_decay())
|
| 64 |
+
num_layers = 12
|
| 65 |
+
layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
|
| 66 |
+
for name, param in model.named_parameters():
|
| 67 |
+
if not param.requires_grad:
|
| 68 |
+
continue # frozen weights
|
| 69 |
+
if len(param.shape) == 1 or name.endswith(".bias") or name in skip:
|
| 70 |
+
group_name = "no_decay"
|
| 71 |
+
this_wd = 0.0
|
| 72 |
+
else:
|
| 73 |
+
group_name = "decay"
|
| 74 |
+
this_wd = wd
|
| 75 |
+
|
| 76 |
+
layer_id = get_num_layer_for_convnext(name)
|
| 77 |
+
group_name = "layer_%d_%s" % (layer_id, group_name)
|
| 78 |
+
|
| 79 |
+
if group_name not in parameter_group_names:
|
| 80 |
+
scale = layer_scale[layer_id]
|
| 81 |
+
cur_lr = lr * scale
|
| 82 |
+
|
| 83 |
+
parameter_group_names[group_name] = {
|
| 84 |
+
"weight_decay": this_wd,
|
| 85 |
+
"weight_decay_init": this_wd,
|
| 86 |
+
"weight_decay_base": this_wd,
|
| 87 |
+
"params": [],
|
| 88 |
+
"lr_init": cur_lr,
|
| 89 |
+
"lr_base": lr,
|
| 90 |
+
"lr": cur_lr,
|
| 91 |
+
}
|
| 92 |
+
parameter_group_vars[group_name] = {
|
| 93 |
+
"weight_decay": this_wd,
|
| 94 |
+
"weight_decay_init": this_wd,
|
| 95 |
+
"weight_decay_base": this_wd,
|
| 96 |
+
"params": [],
|
| 97 |
+
"lr_init": cur_lr,
|
| 98 |
+
"lr_base": lr,
|
| 99 |
+
"lr": cur_lr,
|
| 100 |
+
}
|
| 101 |
+
if this_wd == 0.0:
|
| 102 |
+
parameter_group_names[group_name]["weight_decay_final"] = 0.0
|
| 103 |
+
parameter_group_vars[group_name]["weight_decay_final"] = 0.0
|
| 104 |
+
parameter_group_vars[group_name]["params"].append(param)
|
| 105 |
+
parameter_group_names[group_name]["params"].append(name)
|
| 106 |
+
# from unidepth.utils import is_main_process
|
| 107 |
+
# import json
|
| 108 |
+
# if is_main_process():
|
| 109 |
+
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
| 110 |
+
return list(parameter_group_vars.values()), [
|
| 111 |
+
v["lr"] for k, v in parameter_group_vars.items()
|
| 112 |
+
]
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
class Downsample(nn.Module):
|
| 116 |
+
def __init__(self, in_chs, out_chs, stride=1, dilation=1):
|
| 117 |
+
super().__init__()
|
| 118 |
+
avg_stride = stride if dilation == 1 else 1
|
| 119 |
+
if stride > 1 or dilation > 1:
|
| 120 |
+
avg_pool_fn = (
|
| 121 |
+
AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
|
| 122 |
+
)
|
| 123 |
+
self.pool = avg_pool_fn(
|
| 124 |
+
2, avg_stride, ceil_mode=True, count_include_pad=False
|
| 125 |
+
)
|
| 126 |
+
else:
|
| 127 |
+
self.pool = nn.Identity()
|
| 128 |
+
|
| 129 |
+
if in_chs != out_chs:
|
| 130 |
+
self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
|
| 131 |
+
else:
|
| 132 |
+
self.conv = nn.Identity()
|
| 133 |
+
|
| 134 |
+
def forward(self, x):
|
| 135 |
+
x = self.pool(x)
|
| 136 |
+
x = self.conv(x)
|
| 137 |
+
return x
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class ConvNeXtBlock(nn.Module):
|
| 141 |
+
"""ConvNeXt Block
|
| 142 |
+
There are two equivalent implementations:
|
| 143 |
+
(1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
|
| 144 |
+
(2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
|
| 145 |
+
|
| 146 |
+
Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate
|
| 147 |
+
choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear
|
| 148 |
+
is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW.
|
| 149 |
+
"""
|
| 150 |
+
|
| 151 |
+
def __init__(
|
| 152 |
+
self,
|
| 153 |
+
in_chs: int,
|
| 154 |
+
out_chs: Optional[int] = None,
|
| 155 |
+
kernel_size: int = 7,
|
| 156 |
+
stride: int = 1,
|
| 157 |
+
dilation: Union[int, Tuple[int, int]] = (1, 1),
|
| 158 |
+
mlp_ratio: float = 4,
|
| 159 |
+
conv_mlp: bool = False,
|
| 160 |
+
conv_bias: bool = True,
|
| 161 |
+
use_grn: bool = False,
|
| 162 |
+
ls_init_value: Optional[float] = 1e-6,
|
| 163 |
+
act_layer: Union[str, Callable] = "gelu",
|
| 164 |
+
norm_layer: Optional[Callable] = None,
|
| 165 |
+
drop_path: float = 0.0,
|
| 166 |
+
):
|
| 167 |
+
"""
|
| 168 |
+
|
| 169 |
+
Args:
|
| 170 |
+
in_chs: Block input channels.
|
| 171 |
+
out_chs: Block output channels (same as in_chs if None).
|
| 172 |
+
kernel_size: Depthwise convolution kernel size.
|
| 173 |
+
stride: Stride of depthwise convolution.
|
| 174 |
+
dilation: Tuple specifying input and output dilation of block.
|
| 175 |
+
mlp_ratio: MLP expansion ratio.
|
| 176 |
+
conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
|
| 177 |
+
conv_bias: Apply bias for all convolution (linear) layers.
|
| 178 |
+
use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
|
| 179 |
+
ls_init_value: Layer-scale init values, layer-scale applied if not None.
|
| 180 |
+
act_layer: Activation layer.
|
| 181 |
+
norm_layer: Normalization layer (defaults to LN if not specified).
|
| 182 |
+
drop_path: Stochastic depth probability.
|
| 183 |
+
"""
|
| 184 |
+
super().__init__()
|
| 185 |
+
out_chs = out_chs or in_chs
|
| 186 |
+
dilation = to_ntuple(2)(dilation)
|
| 187 |
+
act_layer = get_act_layer(act_layer)
|
| 188 |
+
if not norm_layer:
|
| 189 |
+
norm_layer = LayerNorm2d if conv_mlp else LayerNorm
|
| 190 |
+
mlp_layer = partial(
|
| 191 |
+
GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp
|
| 192 |
+
)
|
| 193 |
+
self.use_conv_mlp = conv_mlp
|
| 194 |
+
self.conv_dw = create_conv2d(
|
| 195 |
+
in_chs,
|
| 196 |
+
out_chs,
|
| 197 |
+
kernel_size=kernel_size,
|
| 198 |
+
stride=stride,
|
| 199 |
+
dilation=dilation[0],
|
| 200 |
+
depthwise=True,
|
| 201 |
+
bias=conv_bias,
|
| 202 |
+
)
|
| 203 |
+
self.norm = norm_layer(out_chs)
|
| 204 |
+
self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
|
| 205 |
+
self.gamma = (
|
| 206 |
+
nn.Parameter(ls_init_value * torch.ones(out_chs))
|
| 207 |
+
if ls_init_value is not None
|
| 208 |
+
else None
|
| 209 |
+
)
|
| 210 |
+
if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
|
| 211 |
+
self.shortcut = Downsample(
|
| 212 |
+
in_chs, out_chs, stride=stride, dilation=dilation[0]
|
| 213 |
+
)
|
| 214 |
+
else:
|
| 215 |
+
self.shortcut = nn.Identity()
|
| 216 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 217 |
+
|
| 218 |
+
def forward(self, x):
|
| 219 |
+
shortcut = x
|
| 220 |
+
x = self.conv_dw(x.contiguous())
|
| 221 |
+
if self.use_conv_mlp:
|
| 222 |
+
x = self.norm(x)
|
| 223 |
+
x = self.mlp(x)
|
| 224 |
+
else:
|
| 225 |
+
x = x.permute(0, 2, 3, 1).contiguous()
|
| 226 |
+
x = self.norm(x)
|
| 227 |
+
x = self.mlp(x)
|
| 228 |
+
x = x.permute(0, 3, 1, 2).contiguous()
|
| 229 |
+
if self.gamma is not None:
|
| 230 |
+
x = x.mul(self.gamma.reshape(1, -1, 1, 1))
|
| 231 |
+
|
| 232 |
+
x = self.drop_path(x) + self.shortcut(shortcut)
|
| 233 |
+
return x.contiguous()
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
class ConvNeXtStage(nn.Module):
|
| 237 |
+
def __init__(
|
| 238 |
+
self,
|
| 239 |
+
in_chs,
|
| 240 |
+
out_chs,
|
| 241 |
+
kernel_size=7,
|
| 242 |
+
stride=2,
|
| 243 |
+
depth=2,
|
| 244 |
+
dilation=(1, 1),
|
| 245 |
+
drop_path_rates=None,
|
| 246 |
+
ls_init_value=1.0,
|
| 247 |
+
conv_mlp=False,
|
| 248 |
+
conv_bias=True,
|
| 249 |
+
use_grn=False,
|
| 250 |
+
act_layer="gelu",
|
| 251 |
+
norm_layer=None,
|
| 252 |
+
norm_layer_cl=None,
|
| 253 |
+
):
|
| 254 |
+
super().__init__()
|
| 255 |
+
self.grad_checkpointing = False
|
| 256 |
+
|
| 257 |
+
if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
|
| 258 |
+
ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
|
| 259 |
+
pad = (
|
| 260 |
+
"same" if dilation[1] > 1 else 0
|
| 261 |
+
) # same padding needed if dilation used
|
| 262 |
+
self.downsample = nn.Sequential(
|
| 263 |
+
norm_layer(in_chs),
|
| 264 |
+
create_conv2d(
|
| 265 |
+
in_chs,
|
| 266 |
+
out_chs,
|
| 267 |
+
kernel_size=ds_ks,
|
| 268 |
+
stride=stride,
|
| 269 |
+
dilation=dilation[0],
|
| 270 |
+
padding=pad,
|
| 271 |
+
bias=conv_bias,
|
| 272 |
+
),
|
| 273 |
+
)
|
| 274 |
+
in_chs = out_chs
|
| 275 |
+
else:
|
| 276 |
+
self.downsample = nn.Identity()
|
| 277 |
+
|
| 278 |
+
drop_path_rates = drop_path_rates or [0.0] * depth
|
| 279 |
+
stage_blocks = []
|
| 280 |
+
for i in range(depth):
|
| 281 |
+
stage_blocks.append(
|
| 282 |
+
ConvNeXtBlock(
|
| 283 |
+
in_chs=in_chs,
|
| 284 |
+
out_chs=out_chs,
|
| 285 |
+
kernel_size=kernel_size,
|
| 286 |
+
dilation=dilation[1],
|
| 287 |
+
drop_path=drop_path_rates[i],
|
| 288 |
+
ls_init_value=ls_init_value,
|
| 289 |
+
conv_mlp=conv_mlp,
|
| 290 |
+
conv_bias=conv_bias,
|
| 291 |
+
use_grn=use_grn,
|
| 292 |
+
act_layer=act_layer,
|
| 293 |
+
norm_layer=norm_layer if conv_mlp else norm_layer_cl,
|
| 294 |
+
)
|
| 295 |
+
)
|
| 296 |
+
in_chs = out_chs
|
| 297 |
+
self.blocks = nn.ModuleList(stage_blocks)
|
| 298 |
+
|
| 299 |
+
def forward(self, x):
|
| 300 |
+
xs = []
|
| 301 |
+
x = self.downsample(x)
|
| 302 |
+
for block in self.blocks:
|
| 303 |
+
if self.grad_checkpointing:
|
| 304 |
+
x = checkpoint(block, x)
|
| 305 |
+
else:
|
| 306 |
+
x = block(x)
|
| 307 |
+
xs.append(x)
|
| 308 |
+
return xs
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class ConvNeXt(nn.Module):
|
| 312 |
+
def __init__(
|
| 313 |
+
self,
|
| 314 |
+
in_chans: int = 3,
|
| 315 |
+
output_stride: int = 32,
|
| 316 |
+
depths: Tuple[int, ...] = (3, 3, 9, 3),
|
| 317 |
+
dims: Tuple[int, ...] = (96, 192, 384, 768),
|
| 318 |
+
kernel_sizes: Union[int, Tuple[int, ...]] = 7,
|
| 319 |
+
ls_init_value: Optional[float] = 1e-6,
|
| 320 |
+
stem_type: str = "patch",
|
| 321 |
+
patch_size: int = 4,
|
| 322 |
+
conv_mlp: bool = False,
|
| 323 |
+
conv_bias: bool = True,
|
| 324 |
+
use_grn: bool = False,
|
| 325 |
+
act_layer: Union[str, Callable] = "gelu",
|
| 326 |
+
norm_layer: Optional[Union[str, Callable]] = None,
|
| 327 |
+
norm_eps: Optional[float] = None,
|
| 328 |
+
drop_path_rate: float = 0.0,
|
| 329 |
+
output_idx=[],
|
| 330 |
+
use_checkpoint=False,
|
| 331 |
+
):
|
| 332 |
+
"""
|
| 333 |
+
Args:
|
| 334 |
+
in_chans: Number of input image channels.
|
| 335 |
+
num_classes: Number of classes for classification head.
|
| 336 |
+
global_pool: Global pooling type.
|
| 337 |
+
output_stride: Output stride of network, one of (8, 16, 32).
|
| 338 |
+
depths: Number of blocks at each stage.
|
| 339 |
+
dims: Feature dimension at each stage.
|
| 340 |
+
kernel_sizes: Depthwise convolution kernel-sizes for each stage.
|
| 341 |
+
ls_init_value: Init value for Layer Scale, disabled if None.
|
| 342 |
+
stem_type: Type of stem.
|
| 343 |
+
patch_size: Stem patch size for patch stem.
|
| 344 |
+
head_init_scale: Init scaling value for classifier weights and biases.
|
| 345 |
+
head_norm_first: Apply normalization before global pool + head.
|
| 346 |
+
head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False.
|
| 347 |
+
conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
|
| 348 |
+
conv_bias: Use bias layers w/ all convolutions.
|
| 349 |
+
use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
|
| 350 |
+
act_layer: Activation layer type.
|
| 351 |
+
norm_layer: Normalization layer type.
|
| 352 |
+
drop_rate: Head pre-classifier dropout rate.
|
| 353 |
+
drop_path_rate: Stochastic depth drop rate.
|
| 354 |
+
"""
|
| 355 |
+
super().__init__()
|
| 356 |
+
self.num_layers = len(depths)
|
| 357 |
+
self.depths = output_idx
|
| 358 |
+
self.embed_dims = [
|
| 359 |
+
int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
|
| 360 |
+
]
|
| 361 |
+
self.embed_dim = dims[0]
|
| 362 |
+
|
| 363 |
+
assert output_stride in (8, 16, 32)
|
| 364 |
+
kernel_sizes = to_ntuple(4)(kernel_sizes)
|
| 365 |
+
if norm_layer is None:
|
| 366 |
+
norm_layer = LayerNorm2d
|
| 367 |
+
norm_layer_cl = norm_layer if conv_mlp else LayerNorm
|
| 368 |
+
if norm_eps is not None:
|
| 369 |
+
norm_layer = partial(norm_layer, eps=norm_eps)
|
| 370 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
| 371 |
+
else:
|
| 372 |
+
assert (
|
| 373 |
+
conv_mlp
|
| 374 |
+
), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input"
|
| 375 |
+
norm_layer_cl = norm_layer
|
| 376 |
+
if norm_eps is not None:
|
| 377 |
+
norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
|
| 378 |
+
|
| 379 |
+
self.feature_info = []
|
| 380 |
+
|
| 381 |
+
assert stem_type in ("patch", "overlap", "overlap_tiered")
|
| 382 |
+
if stem_type == "patch":
|
| 383 |
+
# NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
|
| 384 |
+
self.stem = nn.Sequential(
|
| 385 |
+
nn.Conv2d(
|
| 386 |
+
in_chans,
|
| 387 |
+
dims[0],
|
| 388 |
+
kernel_size=patch_size,
|
| 389 |
+
stride=patch_size,
|
| 390 |
+
bias=conv_bias,
|
| 391 |
+
),
|
| 392 |
+
norm_layer(dims[0]),
|
| 393 |
+
)
|
| 394 |
+
stem_stride = patch_size
|
| 395 |
+
else:
|
| 396 |
+
mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0]
|
| 397 |
+
self.stem = nn.Sequential(
|
| 398 |
+
nn.Conv2d(
|
| 399 |
+
in_chans,
|
| 400 |
+
mid_chs,
|
| 401 |
+
kernel_size=3,
|
| 402 |
+
stride=2,
|
| 403 |
+
padding=1,
|
| 404 |
+
bias=conv_bias,
|
| 405 |
+
),
|
| 406 |
+
nn.Conv2d(
|
| 407 |
+
mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias
|
| 408 |
+
),
|
| 409 |
+
norm_layer(dims[0]),
|
| 410 |
+
)
|
| 411 |
+
stem_stride = 4
|
| 412 |
+
|
| 413 |
+
self.stages = nn.Sequential()
|
| 414 |
+
dp_rates = [
|
| 415 |
+
x.tolist()
|
| 416 |
+
for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
|
| 417 |
+
]
|
| 418 |
+
stages = []
|
| 419 |
+
prev_chs = dims[0]
|
| 420 |
+
curr_stride = stem_stride
|
| 421 |
+
dilation = 1
|
| 422 |
+
# 4 feature resolution stages, each consisting of multiple residual blocks
|
| 423 |
+
for i in range(4):
|
| 424 |
+
stride = 2 if curr_stride == 2 or i > 0 else 1
|
| 425 |
+
if curr_stride >= output_stride and stride > 1:
|
| 426 |
+
dilation *= stride
|
| 427 |
+
stride = 1
|
| 428 |
+
curr_stride *= stride
|
| 429 |
+
first_dilation = 1 if dilation in (1, 2) else 2
|
| 430 |
+
out_chs = dims[i]
|
| 431 |
+
stages.append(
|
| 432 |
+
ConvNeXtStage(
|
| 433 |
+
prev_chs,
|
| 434 |
+
out_chs,
|
| 435 |
+
kernel_size=kernel_sizes[i],
|
| 436 |
+
stride=stride,
|
| 437 |
+
dilation=(first_dilation, dilation),
|
| 438 |
+
depth=depths[i],
|
| 439 |
+
drop_path_rates=dp_rates[i],
|
| 440 |
+
ls_init_value=ls_init_value,
|
| 441 |
+
conv_mlp=conv_mlp,
|
| 442 |
+
conv_bias=conv_bias,
|
| 443 |
+
use_grn=use_grn,
|
| 444 |
+
act_layer=act_layer,
|
| 445 |
+
norm_layer=norm_layer,
|
| 446 |
+
norm_layer_cl=norm_layer_cl,
|
| 447 |
+
)
|
| 448 |
+
)
|
| 449 |
+
prev_chs = out_chs
|
| 450 |
+
# NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
|
| 451 |
+
self.feature_info += [
|
| 452 |
+
dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
|
| 453 |
+
]
|
| 454 |
+
self.stages = nn.ModuleList(stages)
|
| 455 |
+
self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
|
| 456 |
+
self.num_features = prev_chs
|
| 457 |
+
self.apply(self._init_weights)
|
| 458 |
+
self.set_grad_checkpointing(use_checkpoint)
|
| 459 |
+
|
| 460 |
+
def _init_weights(self, module):
|
| 461 |
+
if isinstance(module, nn.Conv2d):
|
| 462 |
+
trunc_normal_(module.weight, std=0.02)
|
| 463 |
+
if module.bias is not None:
|
| 464 |
+
nn.init.zeros_(module.bias)
|
| 465 |
+
elif isinstance(module, nn.Linear):
|
| 466 |
+
trunc_normal_(module.weight, std=0.02)
|
| 467 |
+
nn.init.zeros_(module.bias)
|
| 468 |
+
|
| 469 |
+
def forward(self, x, masks=None):
|
| 470 |
+
outs = []
|
| 471 |
+
x = self.stem(x)
|
| 472 |
+
if masks is not None:
|
| 473 |
+
masks = torch.nn.functional.interpolate(
|
| 474 |
+
masks.float(), size=x.shape[-2:], mode="nearest"
|
| 475 |
+
)
|
| 476 |
+
x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous()
|
| 477 |
+
for stage in self.stages:
|
| 478 |
+
xs = stage(x)
|
| 479 |
+
outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs])
|
| 480 |
+
x = xs[-1]
|
| 481 |
+
return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
|
| 482 |
+
|
| 483 |
+
@torch.jit.ignore
|
| 484 |
+
def group_matcher(self, coarse=False):
|
| 485 |
+
return dict(
|
| 486 |
+
stem=r"^stem",
|
| 487 |
+
blocks=(
|
| 488 |
+
r"^stages\.(\d+)"
|
| 489 |
+
if coarse
|
| 490 |
+
else [
|
| 491 |
+
(r"^stages\.(\d+)\.downsample", (0,)), # blocks
|
| 492 |
+
(r"^stages\.(\d+)\.blocks\.(\d+)", None),
|
| 493 |
+
(r"^norm_pre", (99999,)),
|
| 494 |
+
]
|
| 495 |
+
),
|
| 496 |
+
)
|
| 497 |
+
|
| 498 |
+
@torch.jit.ignore
|
| 499 |
+
def set_grad_checkpointing(self, enable=True):
|
| 500 |
+
for s in self.stages:
|
| 501 |
+
s.grad_checkpointing = enable
|
| 502 |
+
|
| 503 |
+
def freeze(self) -> None:
|
| 504 |
+
for module in self.modules():
|
| 505 |
+
module.eval()
|
| 506 |
+
for parameters in self.parameters():
|
| 507 |
+
parameters.requires_grad = False
|
| 508 |
+
|
| 509 |
+
def get_params(self, lr, wd, ld, *args, **kwargs):
|
| 510 |
+
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
|
| 511 |
+
return encoder_p, encoder_lr
|
| 512 |
+
|
| 513 |
+
def no_weight_decay(self):
|
| 514 |
+
return {"mask_token"}
|
| 515 |
+
|
| 516 |
+
@classmethod
|
| 517 |
+
def build(cls, config):
|
| 518 |
+
obj = globals()[config["model"]["encoder"]["name"]](config)
|
| 519 |
+
return obj
|
| 520 |
+
|
| 521 |
+
|
| 522 |
+
def checkpoint_filter_fn(state_dict, model):
|
| 523 |
+
"""Remap FB checkpoints -> timm"""
|
| 524 |
+
if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict:
|
| 525 |
+
return state_dict # non-FB checkpoint
|
| 526 |
+
if "model" in state_dict:
|
| 527 |
+
state_dict = state_dict["model"]
|
| 528 |
+
|
| 529 |
+
out_dict = {}
|
| 530 |
+
if "visual.trunk.stem.0.weight" in state_dict:
|
| 531 |
+
out_dict = {
|
| 532 |
+
k.replace("visual.trunk.", ""): v
|
| 533 |
+
for k, v in state_dict.items()
|
| 534 |
+
if k.startswith("visual.trunk.")
|
| 535 |
+
}
|
| 536 |
+
if "visual.head.proj.weight" in state_dict:
|
| 537 |
+
out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"]
|
| 538 |
+
out_dict["head.fc.bias"] = torch.zeros(
|
| 539 |
+
state_dict["visual.head.proj.weight"].shape[0]
|
| 540 |
+
)
|
| 541 |
+
elif "visual.head.mlp.fc1.weight" in state_dict:
|
| 542 |
+
out_dict["head.pre_logits.fc.weight"] = state_dict[
|
| 543 |
+
"visual.head.mlp.fc1.weight"
|
| 544 |
+
]
|
| 545 |
+
out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"]
|
| 546 |
+
out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"]
|
| 547 |
+
out_dict["head.fc.bias"] = torch.zeros(
|
| 548 |
+
state_dict["visual.head.mlp.fc2.weight"].shape[0]
|
| 549 |
+
)
|
| 550 |
+
return out_dict
|
| 551 |
+
|
| 552 |
+
import re
|
| 553 |
+
|
| 554 |
+
for k, v in state_dict.items():
|
| 555 |
+
k = k.replace("downsample_layers.0.", "stem.")
|
| 556 |
+
k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
|
| 557 |
+
k = re.sub(
|
| 558 |
+
r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
|
| 559 |
+
)
|
| 560 |
+
k = k.replace("dwconv", "conv_dw")
|
| 561 |
+
k = k.replace("pwconv", "mlp.fc")
|
| 562 |
+
if "grn" in k:
|
| 563 |
+
k = k.replace("grn.beta", "mlp.grn.bias")
|
| 564 |
+
k = k.replace("grn.gamma", "mlp.grn.weight")
|
| 565 |
+
v = v.reshape(v.shape[-1])
|
| 566 |
+
k = k.replace("head.", "head.fc.")
|
| 567 |
+
if k.startswith("norm."):
|
| 568 |
+
k = k.replace("norm", "head.norm")
|
| 569 |
+
if v.ndim == 2 and "head" not in k:
|
| 570 |
+
model_shape = model.state_dict()[k].shape
|
| 571 |
+
v = v.reshape(model_shape)
|
| 572 |
+
out_dict[k] = v
|
| 573 |
+
|
| 574 |
+
return out_dict
|
| 575 |
+
|
| 576 |
+
|
| 577 |
+
HF_URL = {
|
| 578 |
+
"convnext_xxlarge_pt": (
|
| 579 |
+
"laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
|
| 580 |
+
"open_clip_pytorch_model.bin",
|
| 581 |
+
),
|
| 582 |
+
"convnext_large_pt": (
|
| 583 |
+
"laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
|
| 584 |
+
"open_clip_pytorch_model.bin",
|
| 585 |
+
),
|
| 586 |
+
"convnext_large": (
|
| 587 |
+
"timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384",
|
| 588 |
+
"pytorch_model.bin",
|
| 589 |
+
),
|
| 590 |
+
}
|
flash3d/unidepth/models/backbones/convnext2.py
ADDED
|
@@ -0,0 +1,288 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
from timm.models.layers import trunc_normal_, DropPath
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def get_num_layer_for_convnext_single(var_name, depths):
|
| 8 |
+
"""
|
| 9 |
+
Each layer is assigned distinctive layer ids
|
| 10 |
+
"""
|
| 11 |
+
if var_name.startswith("downsample_layers"):
|
| 12 |
+
stage_id = int(var_name.split(".")[1])
|
| 13 |
+
layer_id = sum(depths[:stage_id]) + 1
|
| 14 |
+
return layer_id
|
| 15 |
+
|
| 16 |
+
elif var_name.startswith("stages"):
|
| 17 |
+
stage_id = int(var_name.split(".")[1])
|
| 18 |
+
block_id = int(var_name.split(".")[2])
|
| 19 |
+
layer_id = sum(depths[:stage_id]) + block_id + 1
|
| 20 |
+
return layer_id
|
| 21 |
+
|
| 22 |
+
else:
|
| 23 |
+
return sum(depths) + 1
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def get_num_layer_for_convnext(var_name):
|
| 27 |
+
"""
|
| 28 |
+
Divide [3, 3, 27, 3] layers into 12 groups; each group is three
|
| 29 |
+
consecutive blocks, including possible neighboring downsample layers;
|
| 30 |
+
adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
|
| 31 |
+
"""
|
| 32 |
+
num_max_layer = 12
|
| 33 |
+
if var_name.startswith("downsample_layers"):
|
| 34 |
+
stage_id = int(var_name.split(".")[1])
|
| 35 |
+
if stage_id == 0:
|
| 36 |
+
layer_id = 0
|
| 37 |
+
elif stage_id == 1 or stage_id == 2:
|
| 38 |
+
layer_id = stage_id + 1
|
| 39 |
+
elif stage_id == 3:
|
| 40 |
+
layer_id = 12
|
| 41 |
+
return layer_id
|
| 42 |
+
|
| 43 |
+
elif var_name.startswith("stages"):
|
| 44 |
+
stage_id = int(var_name.split(".")[1])
|
| 45 |
+
block_id = int(var_name.split(".")[2])
|
| 46 |
+
if stage_id == 0 or stage_id == 1:
|
| 47 |
+
layer_id = stage_id + 1
|
| 48 |
+
elif stage_id == 2:
|
| 49 |
+
layer_id = 3 + block_id // 3
|
| 50 |
+
elif stage_id == 3:
|
| 51 |
+
layer_id = 12
|
| 52 |
+
return layer_id
|
| 53 |
+
else:
|
| 54 |
+
return num_max_layer + 1
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
|
| 58 |
+
parameter_group_names = {}
|
| 59 |
+
parameter_group_vars = {}
|
| 60 |
+
skip = {}
|
| 61 |
+
if skip_list is not None:
|
| 62 |
+
skip = skip_list
|
| 63 |
+
elif hasattr(model, "no_weight_decay"):
|
| 64 |
+
skip = model.no_weight_decay()
|
| 65 |
+
num_layers = 12 # sum(model.depths)
|
| 66 |
+
layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
|
| 67 |
+
for name, param in model.named_parameters():
|
| 68 |
+
if not param.requires_grad:
|
| 69 |
+
continue # frozen weights
|
| 70 |
+
if (
|
| 71 |
+
len(param.shape) == 1
|
| 72 |
+
or name.endswith(".bias")
|
| 73 |
+
or name in skip
|
| 74 |
+
or name.endswith(".gamma")
|
| 75 |
+
or name.endswith(".beta")
|
| 76 |
+
):
|
| 77 |
+
group_name = "no_decay"
|
| 78 |
+
this_weight_decay = 0.0
|
| 79 |
+
else:
|
| 80 |
+
group_name = "decay"
|
| 81 |
+
this_weight_decay = wd
|
| 82 |
+
|
| 83 |
+
# layer_id = get_num_layer_for_convnext_single(name, model.depths)
|
| 84 |
+
layer_id = get_num_layer_for_convnext(name)
|
| 85 |
+
group_name = "layer_%d_%s" % (layer_id, group_name)
|
| 86 |
+
|
| 87 |
+
if group_name not in parameter_group_names:
|
| 88 |
+
scale = layer_scale[layer_id]
|
| 89 |
+
cur_lr = lr * scale
|
| 90 |
+
|
| 91 |
+
parameter_group_names[group_name] = {
|
| 92 |
+
"weight_decay": this_weight_decay,
|
| 93 |
+
"params": [],
|
| 94 |
+
"lr_scale": scale,
|
| 95 |
+
"lr": cur_lr,
|
| 96 |
+
}
|
| 97 |
+
parameter_group_vars[group_name] = {
|
| 98 |
+
"weight_decay": this_weight_decay,
|
| 99 |
+
"params": [],
|
| 100 |
+
"lr_scale": scale,
|
| 101 |
+
"lr": cur_lr,
|
| 102 |
+
}
|
| 103 |
+
parameter_group_vars[group_name]["params"].append(param)
|
| 104 |
+
parameter_group_names[group_name]["params"].append(name)
|
| 105 |
+
# if is_main_process():
|
| 106 |
+
# print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
|
| 107 |
+
return list(parameter_group_vars.values()), [
|
| 108 |
+
v["lr"] for k, v in parameter_group_vars.items()
|
| 109 |
+
]
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class LayerNorm(nn.Module):
|
| 113 |
+
"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
|
| 114 |
+
The ordering of the dimensions in the inputs. channels_last corresponds to inputs with
|
| 115 |
+
shape (batch_size, height, width, channels) while channels_first corresponds to inputs
|
| 116 |
+
with shape (batch_size, channels, height, width).
|
| 117 |
+
"""
|
| 118 |
+
|
| 119 |
+
def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
|
| 120 |
+
super().__init__()
|
| 121 |
+
self.weight = nn.Parameter(torch.ones(normalized_shape))
|
| 122 |
+
self.bias = nn.Parameter(torch.zeros(normalized_shape))
|
| 123 |
+
self.eps = eps
|
| 124 |
+
self.data_format = data_format
|
| 125 |
+
if self.data_format not in ["channels_last", "channels_first"]:
|
| 126 |
+
raise NotImplementedError
|
| 127 |
+
self.normalized_shape = (normalized_shape,)
|
| 128 |
+
|
| 129 |
+
def forward(self, x):
|
| 130 |
+
if self.data_format == "channels_last":
|
| 131 |
+
return F.layer_norm(
|
| 132 |
+
x, self.normalized_shape, self.weight, self.bias, self.eps
|
| 133 |
+
)
|
| 134 |
+
elif self.data_format == "channels_first":
|
| 135 |
+
u = x.mean(1, keepdim=True)
|
| 136 |
+
s = (x - u).pow(2).mean(1, keepdim=True)
|
| 137 |
+
x = (x - u) / torch.sqrt(s + self.eps)
|
| 138 |
+
x = self.weight[:, None, None] * x + self.bias[:, None, None]
|
| 139 |
+
return x
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
class GRN(nn.Module):
|
| 143 |
+
"""GRN (Global Response Normalization) layer"""
|
| 144 |
+
|
| 145 |
+
def __init__(self, dim):
|
| 146 |
+
super().__init__()
|
| 147 |
+
self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
| 148 |
+
self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
|
| 149 |
+
|
| 150 |
+
def forward(self, x):
|
| 151 |
+
Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
|
| 152 |
+
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
| 153 |
+
return self.gamma * (x * Nx) + self.beta + x
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
class Block(nn.Module):
|
| 157 |
+
"""ConvNeXtV2 Block.
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
dim (int): Number of input channels.
|
| 161 |
+
drop_path (float): Stochastic depth rate. Default: 0.0
|
| 162 |
+
"""
|
| 163 |
+
|
| 164 |
+
def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False):
|
| 165 |
+
super().__init__()
|
| 166 |
+
self.dwconv = nn.Conv2d(
|
| 167 |
+
dim, dim, kernel_size=7, padding=3, groups=dim
|
| 168 |
+
) # depthwise conv
|
| 169 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
| 170 |
+
self.pwconv1 = nn.Linear(
|
| 171 |
+
dim, mult * dim
|
| 172 |
+
) # pointwise/1x1 convs, implemented with linear layers
|
| 173 |
+
self.act = nn.GELU()
|
| 174 |
+
self.grn = GRN(mult * dim)
|
| 175 |
+
self.pwconv2 = nn.Linear(mult * dim, dim)
|
| 176 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 177 |
+
self.use_checkpoint = use_checkpoint
|
| 178 |
+
|
| 179 |
+
def forward(self, x):
|
| 180 |
+
input = x
|
| 181 |
+
x = self.dwconv(x)
|
| 182 |
+
x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
|
| 183 |
+
x = self.norm(x)
|
| 184 |
+
x = self.pwconv1(x)
|
| 185 |
+
x = self.act(x)
|
| 186 |
+
x = self.grn(x)
|
| 187 |
+
x = self.pwconv2(x)
|
| 188 |
+
x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
|
| 189 |
+
|
| 190 |
+
x = input + self.drop_path(x)
|
| 191 |
+
return x
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
class ConvNeXtV2(nn.Module):
|
| 195 |
+
"""ConvNeXt V2
|
| 196 |
+
|
| 197 |
+
Args:
|
| 198 |
+
in_chans (int): Number of input image channels. Default: 3
|
| 199 |
+
num_classes (int): Number of classes for classification head. Default: 1000
|
| 200 |
+
depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
|
| 201 |
+
dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768]
|
| 202 |
+
drop_path_rate (float): Stochastic depth rate. Default: 0.
|
| 203 |
+
head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1.
|
| 204 |
+
"""
|
| 205 |
+
|
| 206 |
+
def __init__(
|
| 207 |
+
self,
|
| 208 |
+
in_chans=3,
|
| 209 |
+
depths=[3, 3, 9, 3],
|
| 210 |
+
dims=96,
|
| 211 |
+
drop_path_rate=0.0,
|
| 212 |
+
output_idx=[],
|
| 213 |
+
use_checkpoint=False,
|
| 214 |
+
):
|
| 215 |
+
super().__init__()
|
| 216 |
+
self.num_layers = len(depths)
|
| 217 |
+
self.depths = output_idx
|
| 218 |
+
self.embed_dims = [
|
| 219 |
+
int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
|
| 220 |
+
]
|
| 221 |
+
self.embed_dim = dims[0]
|
| 222 |
+
|
| 223 |
+
self.downsample_layers = (
|
| 224 |
+
nn.ModuleList()
|
| 225 |
+
) # stem and 3 intermediate downsampling conv layers
|
| 226 |
+
stem = nn.Sequential(
|
| 227 |
+
nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
|
| 228 |
+
LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
|
| 229 |
+
)
|
| 230 |
+
self.downsample_layers.append(stem)
|
| 231 |
+
for i in range(3):
|
| 232 |
+
downsample_layer = nn.Sequential(
|
| 233 |
+
LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
|
| 234 |
+
nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
|
| 235 |
+
)
|
| 236 |
+
self.downsample_layers.append(downsample_layer)
|
| 237 |
+
|
| 238 |
+
self.stages = (
|
| 239 |
+
nn.ModuleList()
|
| 240 |
+
) # 4 feature resolution stages, each consisting of multiple residual blocks
|
| 241 |
+
self.out_norms = nn.ModuleList()
|
| 242 |
+
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
|
| 243 |
+
cur = 0
|
| 244 |
+
for i in range(4):
|
| 245 |
+
stage = nn.ModuleList(
|
| 246 |
+
[
|
| 247 |
+
Block(
|
| 248 |
+
dim=dims[i],
|
| 249 |
+
drop_path=dp_rates[cur + j],
|
| 250 |
+
use_checkpoint=use_checkpoint,
|
| 251 |
+
)
|
| 252 |
+
for j in range(depths[i])
|
| 253 |
+
]
|
| 254 |
+
)
|
| 255 |
+
self.stages.append(stage)
|
| 256 |
+
cur += depths[i]
|
| 257 |
+
|
| 258 |
+
self.apply(self._init_weights)
|
| 259 |
+
|
| 260 |
+
def _init_weights(self, m):
|
| 261 |
+
if isinstance(m, (nn.Conv2d, nn.Linear)):
|
| 262 |
+
trunc_normal_(m.weight, std=0.02)
|
| 263 |
+
nn.init.constant_(m.bias, 0)
|
| 264 |
+
|
| 265 |
+
def forward(self, x):
|
| 266 |
+
outs = []
|
| 267 |
+
for i in range(4):
|
| 268 |
+
x = self.downsample_layers[i](x)
|
| 269 |
+
for stage in self.stages[i]:
|
| 270 |
+
x = stage(x)
|
| 271 |
+
outs.append(x.permute(0, 2, 3, 1))
|
| 272 |
+
cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
|
| 273 |
+
return outs, cls_tokens
|
| 274 |
+
|
| 275 |
+
def get_params(self, lr, wd, ld, *args, **kwargs):
|
| 276 |
+
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
|
| 277 |
+
return encoder_p, encoder_lr
|
| 278 |
+
|
| 279 |
+
def freeze(self) -> None:
|
| 280 |
+
for module in self.modules():
|
| 281 |
+
module.eval()
|
| 282 |
+
for parameters in self.parameters():
|
| 283 |
+
parameters.requires_grad = False
|
| 284 |
+
|
| 285 |
+
@classmethod
|
| 286 |
+
def build(cls, config):
|
| 287 |
+
obj = globals()[config["model"]["encoder"]["name"]](config)
|
| 288 |
+
return obj
|
flash3d/unidepth/models/backbones/dinov2.py
ADDED
|
@@ -0,0 +1,552 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from functools import partial
|
| 2 |
+
import math
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Sequence, Tuple, Union, Callable
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from torch.utils.checkpoint import checkpoint
|
| 9 |
+
from torch.nn.init import trunc_normal_
|
| 10 |
+
|
| 11 |
+
from .metadinov2 import (
|
| 12 |
+
Mlp,
|
| 13 |
+
PatchEmbed,
|
| 14 |
+
SwiGLUFFNFused,
|
| 15 |
+
MemEffAttention,
|
| 16 |
+
NestedTensorBlock as Block,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
logger = logging.getLogger("dinov2")
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def named_apply(
|
| 24 |
+
fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
|
| 25 |
+
) -> nn.Module:
|
| 26 |
+
if not depth_first and include_root:
|
| 27 |
+
fn(module=module, name=name)
|
| 28 |
+
for child_name, child_module in module.named_children():
|
| 29 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
| 30 |
+
named_apply(
|
| 31 |
+
fn=fn,
|
| 32 |
+
module=child_module,
|
| 33 |
+
name=child_name,
|
| 34 |
+
depth_first=depth_first,
|
| 35 |
+
include_root=True,
|
| 36 |
+
)
|
| 37 |
+
if depth_first and include_root:
|
| 38 |
+
fn(module=module, name=name)
|
| 39 |
+
return module
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
|
| 43 |
+
parameter_group_names = {}
|
| 44 |
+
parameter_group_vars = {}
|
| 45 |
+
skip = {}
|
| 46 |
+
if skip_list is not None:
|
| 47 |
+
skip = skip_list
|
| 48 |
+
elif hasattr(model, "no_weight_decay"):
|
| 49 |
+
skip = model.no_weight_decay()
|
| 50 |
+
|
| 51 |
+
num_layers = model.n_blocks
|
| 52 |
+
layer_scale = list(ld ** (num_layers - i) for i in range(num_layers))
|
| 53 |
+
|
| 54 |
+
for name, param in model.named_parameters():
|
| 55 |
+
if not param.requires_grad:
|
| 56 |
+
continue
|
| 57 |
+
|
| 58 |
+
if len(param.shape) == 1: # norm
|
| 59 |
+
group_name = "no_decay"
|
| 60 |
+
this_wd = 0.0
|
| 61 |
+
# layer scale, bias beta?
|
| 62 |
+
elif (
|
| 63 |
+
name in skip
|
| 64 |
+
or name.endswith(".gamma")
|
| 65 |
+
or name.endswith(".beta")
|
| 66 |
+
or name.endswith(".bias")
|
| 67 |
+
):
|
| 68 |
+
group_name = "no_decay"
|
| 69 |
+
this_wd = 0.0
|
| 70 |
+
elif "cls_token" in name or "pos_embed" in name or "mask_token" in name:
|
| 71 |
+
group_name = "no_decay"
|
| 72 |
+
this_wd = 0.0
|
| 73 |
+
else:
|
| 74 |
+
group_name = "decay"
|
| 75 |
+
this_wd = wd
|
| 76 |
+
|
| 77 |
+
if name.startswith("blocks"):
|
| 78 |
+
layer_id = int(name.split(".")[1])
|
| 79 |
+
elif name.startswith("patch_embed"):
|
| 80 |
+
layer_id = 0
|
| 81 |
+
else:
|
| 82 |
+
layer_id = 0
|
| 83 |
+
|
| 84 |
+
group_name = f"layer_{layer_id}_{group_name}"
|
| 85 |
+
|
| 86 |
+
if group_name not in parameter_group_names:
|
| 87 |
+
scale = layer_scale[layer_id]
|
| 88 |
+
cur_lr = lr * scale
|
| 89 |
+
|
| 90 |
+
parameter_group_names[group_name] = {
|
| 91 |
+
"weight_decay": this_wd,
|
| 92 |
+
"params": [],
|
| 93 |
+
"lr_init": cur_lr,
|
| 94 |
+
"lr_base": lr,
|
| 95 |
+
"lr": cur_lr,
|
| 96 |
+
}
|
| 97 |
+
parameter_group_vars[group_name] = {
|
| 98 |
+
"weight_decay": this_wd,
|
| 99 |
+
"params": [],
|
| 100 |
+
"lr_init": cur_lr,
|
| 101 |
+
"lr_base": lr,
|
| 102 |
+
"lr": cur_lr,
|
| 103 |
+
}
|
| 104 |
+
parameter_group_vars[group_name]["params"].append(param)
|
| 105 |
+
parameter_group_names[group_name]["params"].append(name)
|
| 106 |
+
return list(parameter_group_vars.values()), [
|
| 107 |
+
v["lr"] for k, v in parameter_group_vars.items()
|
| 108 |
+
]
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
class BlockChunk(nn.ModuleList):
|
| 112 |
+
def forward(self, x):
|
| 113 |
+
for b in self:
|
| 114 |
+
x = b(x)
|
| 115 |
+
return x
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class DinoVisionTransformer(nn.Module):
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
img_size=224,
|
| 122 |
+
patch_size=16,
|
| 123 |
+
in_chans=3,
|
| 124 |
+
embed_dim=768,
|
| 125 |
+
depth=12,
|
| 126 |
+
num_heads=12,
|
| 127 |
+
mlp_ratio=4.0,
|
| 128 |
+
qkv_bias=True,
|
| 129 |
+
ffn_bias=True,
|
| 130 |
+
proj_bias=True,
|
| 131 |
+
drop_path_rate=0.0,
|
| 132 |
+
drop_path_uniform=False,
|
| 133 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
| 134 |
+
embed_layer=PatchEmbed,
|
| 135 |
+
act_layer=nn.GELU,
|
| 136 |
+
block_fn=Block,
|
| 137 |
+
ffn_layer="mlp",
|
| 138 |
+
block_chunks=1,
|
| 139 |
+
output_idx=[5, 12, 18, 24],
|
| 140 |
+
checkpoint: bool = False,
|
| 141 |
+
num_register_tokens=0,
|
| 142 |
+
interpolate_antialias=False,
|
| 143 |
+
interpolate_offset=0.1,
|
| 144 |
+
):
|
| 145 |
+
"""
|
| 146 |
+
Args:
|
| 147 |
+
img_size (int, tuple): input image size
|
| 148 |
+
patch_size (int, tuple): patch size
|
| 149 |
+
in_chans (int): number of input channels
|
| 150 |
+
embed_dim (int): embedding dimension
|
| 151 |
+
depth (int): depth of transformer
|
| 152 |
+
num_heads (int): number of attention heads
|
| 153 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
| 154 |
+
qkv_bias (bool): enable bias for qkv if True
|
| 155 |
+
proj_bias (bool): enable bias for proj in attn if True
|
| 156 |
+
ffn_bias (bool): enable bias for ffn if True
|
| 157 |
+
drop_path_rate (float): stochastic depth rate
|
| 158 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
| 159 |
+
weight_init (str): weight init scheme
|
| 160 |
+
init_values (float): layer-scale init values
|
| 161 |
+
embed_layer (nn.Module): patch embedding layer
|
| 162 |
+
act_layer (nn.Module): MLP activation layer
|
| 163 |
+
block_fn (nn.Module): transformer block class
|
| 164 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
| 165 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
| 166 |
+
"""
|
| 167 |
+
super().__init__()
|
| 168 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
| 169 |
+
|
| 170 |
+
self.num_features = self.embed_dim = (
|
| 171 |
+
embed_dim # num_features for consistency with other models
|
| 172 |
+
)
|
| 173 |
+
self.embed_dims = [embed_dim] * output_idx[-1]
|
| 174 |
+
self.num_tokens = 1
|
| 175 |
+
self.n_blocks = depth
|
| 176 |
+
self.num_heads = num_heads
|
| 177 |
+
self.patch_size = patch_size
|
| 178 |
+
self.depths = output_idx
|
| 179 |
+
self.checkpoint = checkpoint
|
| 180 |
+
self.num_register_tokens = num_register_tokens
|
| 181 |
+
self.interpolate_antialias = interpolate_antialias
|
| 182 |
+
self.interpolate_offset = interpolate_offset
|
| 183 |
+
|
| 184 |
+
self.patch_embed = embed_layer(
|
| 185 |
+
img_size=img_size,
|
| 186 |
+
patch_size=patch_size,
|
| 187 |
+
in_chans=in_chans,
|
| 188 |
+
embed_dim=embed_dim,
|
| 189 |
+
)
|
| 190 |
+
num_patches = self.patch_embed.num_patches
|
| 191 |
+
|
| 192 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 193 |
+
self.pos_embed = nn.Parameter(
|
| 194 |
+
torch.zeros(1, num_patches + self.num_tokens, embed_dim)
|
| 195 |
+
)
|
| 196 |
+
assert num_register_tokens >= 0
|
| 197 |
+
self.register_tokens = nn.Parameter(
|
| 198 |
+
torch.zeros(1, max(1, num_register_tokens), embed_dim)
|
| 199 |
+
)
|
| 200 |
+
|
| 201 |
+
if drop_path_uniform is True:
|
| 202 |
+
dpr = [drop_path_rate] * depth
|
| 203 |
+
else:
|
| 204 |
+
dpr = [
|
| 205 |
+
x.item() for x in torch.linspace(0, drop_path_rate, depth)
|
| 206 |
+
] # stochastic depth decay rule
|
| 207 |
+
|
| 208 |
+
if ffn_layer == "mlp":
|
| 209 |
+
logger.info("using MLP layer as FFN")
|
| 210 |
+
ffn_layer = Mlp
|
| 211 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
| 212 |
+
logger.info("using SwiGLU layer as FFN")
|
| 213 |
+
ffn_layer = SwiGLUFFNFused
|
| 214 |
+
elif ffn_layer == "identity":
|
| 215 |
+
logger.info("using Identity layer as FFN")
|
| 216 |
+
|
| 217 |
+
def f(*args, **kwargs):
|
| 218 |
+
return nn.Identity()
|
| 219 |
+
|
| 220 |
+
ffn_layer = f
|
| 221 |
+
else:
|
| 222 |
+
raise NotImplementedError
|
| 223 |
+
|
| 224 |
+
blocks_list = [
|
| 225 |
+
block_fn(
|
| 226 |
+
dim=embed_dim,
|
| 227 |
+
num_heads=num_heads,
|
| 228 |
+
mlp_ratio=mlp_ratio,
|
| 229 |
+
qkv_bias=qkv_bias,
|
| 230 |
+
proj_bias=proj_bias,
|
| 231 |
+
ffn_bias=ffn_bias,
|
| 232 |
+
drop_path=dpr[i],
|
| 233 |
+
norm_layer=norm_layer,
|
| 234 |
+
act_layer=act_layer,
|
| 235 |
+
ffn_layer=ffn_layer,
|
| 236 |
+
init_values=init_values,
|
| 237 |
+
)
|
| 238 |
+
for i in range(depth)
|
| 239 |
+
]
|
| 240 |
+
if block_chunks > 0:
|
| 241 |
+
self.chunked_blocks = True
|
| 242 |
+
chunked_blocks = []
|
| 243 |
+
chunksize = depth // block_chunks
|
| 244 |
+
for i in range(0, depth, chunksize):
|
| 245 |
+
# this is to keep the block index consistent if we chunk the block list
|
| 246 |
+
chunked_blocks.append(
|
| 247 |
+
[nn.Identity()] * i + blocks_list[i : i + chunksize]
|
| 248 |
+
)
|
| 249 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
| 250 |
+
else:
|
| 251 |
+
self.chunked_blocks = False
|
| 252 |
+
self.blocks = nn.ModuleList(blocks_list)
|
| 253 |
+
|
| 254 |
+
# self.norm = norm_layer(embed_dim)
|
| 255 |
+
self.head = nn.Identity()
|
| 256 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
| 257 |
+
self.init_weights()
|
| 258 |
+
|
| 259 |
+
def init_weights(self):
|
| 260 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
| 261 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
| 262 |
+
if self.num_register_tokens:
|
| 263 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
| 264 |
+
named_apply(init_weights_vit_timm, self)
|
| 265 |
+
|
| 266 |
+
def interpolate_pos_encoding(self, x, w, h):
|
| 267 |
+
previous_dtype = x.dtype
|
| 268 |
+
npatch = x.shape[1] - 1
|
| 269 |
+
N = self.pos_embed.shape[1] - 1
|
| 270 |
+
if npatch == N and w == h:
|
| 271 |
+
return self.pos_embed
|
| 272 |
+
pos_embed = self.pos_embed.float()
|
| 273 |
+
class_pos_embed = pos_embed[:, 0]
|
| 274 |
+
patch_pos_embed = pos_embed[:, 1:]
|
| 275 |
+
dim = x.shape[-1]
|
| 276 |
+
w0 = w // self.patch_size
|
| 277 |
+
h0 = h // self.patch_size
|
| 278 |
+
# we add a small number to avoid floating point error in the interpolation
|
| 279 |
+
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
| 280 |
+
w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
|
| 281 |
+
|
| 282 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 283 |
+
patch_pos_embed.reshape(
|
| 284 |
+
1, int(math.sqrt(N)), int(math.sqrt(N)), dim
|
| 285 |
+
).permute(0, 3, 1, 2),
|
| 286 |
+
scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
|
| 287 |
+
mode="bicubic",
|
| 288 |
+
antialias=self.interpolate_antialias,
|
| 289 |
+
)
|
| 290 |
+
|
| 291 |
+
assert (
|
| 292 |
+
int(w0) == patch_pos_embed.shape[-2]
|
| 293 |
+
and int(h0) == patch_pos_embed.shape[-1]
|
| 294 |
+
)
|
| 295 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
| 296 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
|
| 297 |
+
previous_dtype
|
| 298 |
+
)
|
| 299 |
+
|
| 300 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
| 301 |
+
B, nc, w, h = x.shape
|
| 302 |
+
x = self.patch_embed(x)
|
| 303 |
+
if masks is not None:
|
| 304 |
+
masks = masks.bool().view(B, -1, 1)
|
| 305 |
+
x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x)
|
| 306 |
+
|
| 307 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
| 308 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
| 309 |
+
|
| 310 |
+
if self.num_register_tokens:
|
| 311 |
+
x = torch.cat(
|
| 312 |
+
(x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
|
| 313 |
+
dim=1,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
return x
|
| 317 |
+
|
| 318 |
+
def forward_features(self, x, masks=None):
|
| 319 |
+
# if isinstance(x, list):
|
| 320 |
+
# return self.forward_features_list(x, masks)
|
| 321 |
+
shapes = [val // self.patch_size for val in x.shape[-2:]]
|
| 322 |
+
batch_size = x.shape[0]
|
| 323 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
| 324 |
+
output, cls_tokens = [], []
|
| 325 |
+
|
| 326 |
+
for i, blk in enumerate(self.blocks):
|
| 327 |
+
x = blk(x)
|
| 328 |
+
cls_token = x[:, :1]
|
| 329 |
+
|
| 330 |
+
out = x[:, self.num_register_tokens + 1 :]
|
| 331 |
+
# was like this before, add cls to dense features
|
| 332 |
+
# out = out + cls_token
|
| 333 |
+
|
| 334 |
+
output.append(out.view(batch_size, *shapes, -1))
|
| 335 |
+
cls_tokens.append(cls_token)
|
| 336 |
+
return (output, cls_tokens)
|
| 337 |
+
|
| 338 |
+
def get_params(self, lr, wd, ld, *args, **kwargs):
|
| 339 |
+
encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
|
| 340 |
+
return encoder_p, encoder_lr
|
| 341 |
+
|
| 342 |
+
def freeze(self) -> None:
|
| 343 |
+
for module in self.modules():
|
| 344 |
+
module.eval()
|
| 345 |
+
for parameters in self.parameters():
|
| 346 |
+
parameters.requires_grad = False
|
| 347 |
+
|
| 348 |
+
def train(self, mode=True):
|
| 349 |
+
super().train(mode)
|
| 350 |
+
self.mask_token.requires_grad = False
|
| 351 |
+
self.register_tokens.requires_grad = False
|
| 352 |
+
|
| 353 |
+
def forward(self, *args, is_training=False, **kwargs):
|
| 354 |
+
ret = self.forward_features(*args, **kwargs)
|
| 355 |
+
return ret
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
| 359 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
| 360 |
+
if isinstance(module, nn.Linear):
|
| 361 |
+
trunc_normal_(module.weight, std=0.02)
|
| 362 |
+
if module.bias is not None:
|
| 363 |
+
nn.init.zeros_(module.bias)
|
| 364 |
+
|
| 365 |
+
|
| 366 |
+
def vit_small(patch_size=16, **kwargs):
|
| 367 |
+
model = DinoVisionTransformer(
|
| 368 |
+
patch_size=patch_size,
|
| 369 |
+
embed_dim=384,
|
| 370 |
+
depth=12,
|
| 371 |
+
num_heads=6,
|
| 372 |
+
mlp_ratio=4,
|
| 373 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 374 |
+
**kwargs,
|
| 375 |
+
)
|
| 376 |
+
return model
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
| 380 |
+
model = DinoVisionTransformer(
|
| 381 |
+
patch_size=patch_size,
|
| 382 |
+
embed_dim=768,
|
| 383 |
+
depth=12,
|
| 384 |
+
num_heads=12,
|
| 385 |
+
mlp_ratio=4,
|
| 386 |
+
num_register_tokens=num_register_tokens,
|
| 387 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 388 |
+
**kwargs,
|
| 389 |
+
)
|
| 390 |
+
return model
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
| 394 |
+
model = DinoVisionTransformer(
|
| 395 |
+
patch_size=patch_size,
|
| 396 |
+
embed_dim=1024,
|
| 397 |
+
depth=24,
|
| 398 |
+
num_heads=16,
|
| 399 |
+
mlp_ratio=4,
|
| 400 |
+
num_register_tokens=num_register_tokens,
|
| 401 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 402 |
+
**kwargs,
|
| 403 |
+
)
|
| 404 |
+
return model
|
| 405 |
+
|
| 406 |
+
|
| 407 |
+
def vit_giant2(patch_size=16, **kwargs):
|
| 408 |
+
"""
|
| 409 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
| 410 |
+
"""
|
| 411 |
+
model = DinoVisionTransformer(
|
| 412 |
+
patch_size=patch_size,
|
| 413 |
+
embed_dim=1536,
|
| 414 |
+
depth=40,
|
| 415 |
+
num_heads=24,
|
| 416 |
+
mlp_ratio=4,
|
| 417 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
| 418 |
+
**kwargs,
|
| 419 |
+
)
|
| 420 |
+
return model
|
| 421 |
+
|
| 422 |
+
|
| 423 |
+
import torch
|
| 424 |
+
import torch.nn as nn
|
| 425 |
+
|
| 426 |
+
|
| 427 |
+
dependencies = ["torch"]
|
| 428 |
+
|
| 429 |
+
|
| 430 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
|
| 434 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
| 435 |
+
return f"dinov2_{compact_arch_name}{patch_size}"
|
| 436 |
+
|
| 437 |
+
|
| 438 |
+
def _make_dinov2_model(
|
| 439 |
+
*,
|
| 440 |
+
arch_name: str = "vit_large",
|
| 441 |
+
img_size: int = 518,
|
| 442 |
+
patch_size: int = 14,
|
| 443 |
+
init_values: float = 1.0,
|
| 444 |
+
ffn_layer: str = "mlp",
|
| 445 |
+
block_chunks: int = 0,
|
| 446 |
+
pretrained: str = "",
|
| 447 |
+
output_idx: Sequence[int] = [],
|
| 448 |
+
num_register_tokens: int = 0,
|
| 449 |
+
drop_path_rate: float = 0.0,
|
| 450 |
+
**kwargs,
|
| 451 |
+
):
|
| 452 |
+
model_name = _make_dinov2_model_name(arch_name, patch_size)
|
| 453 |
+
print("Instantiate:", model_name)
|
| 454 |
+
|
| 455 |
+
vit_kwargs = dict(
|
| 456 |
+
img_size=img_size,
|
| 457 |
+
patch_size=patch_size,
|
| 458 |
+
init_values=init_values,
|
| 459 |
+
ffn_layer=ffn_layer,
|
| 460 |
+
block_chunks=block_chunks,
|
| 461 |
+
output_idx=output_idx,
|
| 462 |
+
drop_path_rate=drop_path_rate,
|
| 463 |
+
num_register_tokens=num_register_tokens,
|
| 464 |
+
)
|
| 465 |
+
vit_kwargs.update(**kwargs)
|
| 466 |
+
model = eval(arch_name)(**vit_kwargs)
|
| 467 |
+
if pretrained == "":
|
| 468 |
+
url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}"
|
| 469 |
+
if num_register_tokens > 0:
|
| 470 |
+
url += "_reg4"
|
| 471 |
+
url += "_pretrain.pth"
|
| 472 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 473 |
+
url, map_location="cpu", progress=False
|
| 474 |
+
)
|
| 475 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 476 |
+
print(info)
|
| 477 |
+
elif pretrained is not None:
|
| 478 |
+
state_dict = torch.load(pretrained, map_location="cpu")
|
| 479 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 480 |
+
print(f"loading from {pretrained} with:", info)
|
| 481 |
+
return model
|
| 482 |
+
|
| 483 |
+
# def forward_features_list(self, x_list, masks_list):
|
| 484 |
+
# x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
| 485 |
+
# for blk in self.blocks:
|
| 486 |
+
# x = blk(x)
|
| 487 |
+
|
| 488 |
+
# all_x = x
|
| 489 |
+
# output = []
|
| 490 |
+
# for x, masks in zip(all_x, masks_list):
|
| 491 |
+
# x_norm = self.norm(x)
|
| 492 |
+
# output.append(
|
| 493 |
+
# {
|
| 494 |
+
# "x_norm_clstoken": x_norm[:, 0],
|
| 495 |
+
# "x_norm_patchtokens": x_norm[:, 1:],
|
| 496 |
+
# "x_prenorm": x,
|
| 497 |
+
# "masks": masks,
|
| 498 |
+
# }
|
| 499 |
+
# )
|
| 500 |
+
# return output
|
| 501 |
+
|
| 502 |
+
# def _get_intermediate_layers_not_chunked(self, x, n=1):
|
| 503 |
+
# x = self.prepare_tokens_with_masks(x)
|
| 504 |
+
# # If n is an int, take the n last blocks. If it's a list, take them
|
| 505 |
+
# output, total_block_len = [], len(self.blocks)
|
| 506 |
+
# blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 507 |
+
# for i, blk in enumerate(self.blocks):
|
| 508 |
+
# x = blk(x)
|
| 509 |
+
# if i in blocks_to_take:
|
| 510 |
+
# output.append(x)
|
| 511 |
+
# assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 512 |
+
# return output
|
| 513 |
+
|
| 514 |
+
# def _get_intermediate_layers_chunked(self, x, n=1):
|
| 515 |
+
# x = self.prepare_tokens_with_masks(x)
|
| 516 |
+
# output, i, total_block_len = [], 0, len(self.blocks[-1])
|
| 517 |
+
# # If n is an int, take the n last blocks. If it's a list, take them
|
| 518 |
+
# blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
| 519 |
+
# for block_chunk in self.blocks:
|
| 520 |
+
# for blk in block_chunk[i:]: # Passing the nn.Identity()
|
| 521 |
+
# x = blk(x)
|
| 522 |
+
# if i in blocks_to_take:
|
| 523 |
+
# output.append(x)
|
| 524 |
+
# i += 1
|
| 525 |
+
# assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
| 526 |
+
# return output
|
| 527 |
+
|
| 528 |
+
# def get_intermediate_layers(
|
| 529 |
+
# self,
|
| 530 |
+
# x: torch.Tensor,
|
| 531 |
+
# n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
| 532 |
+
# reshape: bool = False,
|
| 533 |
+
# return_class_token: bool = False,
|
| 534 |
+
# norm=True,
|
| 535 |
+
# ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
| 536 |
+
# if self.chunked_blocks:
|
| 537 |
+
# outputs = self._get_intermediate_layers_chunked(x, n)
|
| 538 |
+
# else:
|
| 539 |
+
# outputs = self._get_intermediate_layers_not_chunked(x, n)
|
| 540 |
+
# if norm:
|
| 541 |
+
# outputs = [self.norm(out) for out in outputs]
|
| 542 |
+
# class_tokens = [out[:, 0] for out in outputs]
|
| 543 |
+
# outputs = [out[:, 1:] for out in outputs]
|
| 544 |
+
# if reshape:
|
| 545 |
+
# B, _, w, h = x.shape
|
| 546 |
+
# outputs = [
|
| 547 |
+
# out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
| 548 |
+
# for out in outputs
|
| 549 |
+
# ]
|
| 550 |
+
# if return_class_token:
|
| 551 |
+
# return tuple(zip(outputs, class_tokens))
|
| 552 |
+
# return tuple(outputs)
|
flash3d/unidepth/models/backbones/metadinov2/__init__.py
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from .dino_head import DINOHead
|
| 8 |
+
from .mlp import Mlp
|
| 9 |
+
from .patch_embed import PatchEmbed
|
| 10 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
| 11 |
+
from .block import NestedTensorBlock
|
| 12 |
+
from .attention import MemEffAttention
|
flash3d/unidepth/models/backbones/metadinov2/attention.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger("dinov2")
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
try:
|
| 21 |
+
from xformers.ops import memory_efficient_attention, unbind, fmha
|
| 22 |
+
|
| 23 |
+
XFORMERS_AVAILABLE = True
|
| 24 |
+
except ImportError:
|
| 25 |
+
logger.warning("xFormers not available")
|
| 26 |
+
XFORMERS_AVAILABLE = False
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class Attention(nn.Module):
|
| 30 |
+
def __init__(
|
| 31 |
+
self,
|
| 32 |
+
dim: int,
|
| 33 |
+
num_heads: int = 8,
|
| 34 |
+
qkv_bias: bool = False,
|
| 35 |
+
proj_bias: bool = True,
|
| 36 |
+
attn_drop: float = 0.0,
|
| 37 |
+
proj_drop: float = 0.0,
|
| 38 |
+
) -> None:
|
| 39 |
+
super().__init__()
|
| 40 |
+
self.num_heads = num_heads
|
| 41 |
+
head_dim = dim // num_heads
|
| 42 |
+
self.scale = head_dim**-0.5
|
| 43 |
+
|
| 44 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
| 45 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
| 46 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
| 47 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
| 48 |
+
|
| 49 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 50 |
+
B, N, C = x.shape
|
| 51 |
+
qkv = (
|
| 52 |
+
self.qkv(x)
|
| 53 |
+
.reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 54 |
+
.permute(2, 0, 3, 1, 4)
|
| 55 |
+
)
|
| 56 |
+
|
| 57 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
| 58 |
+
attn = q @ k.transpose(-2, -1)
|
| 59 |
+
|
| 60 |
+
attn = attn.softmax(dim=-1)
|
| 61 |
+
attn = self.attn_drop(attn)
|
| 62 |
+
|
| 63 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
| 64 |
+
x = self.proj(x)
|
| 65 |
+
x = self.proj_drop(x)
|
| 66 |
+
return x
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class MemEffAttention(Attention):
|
| 70 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
| 71 |
+
if not XFORMERS_AVAILABLE:
|
| 72 |
+
assert attn_bias is None, "xFormers is required for nested tensors usage"
|
| 73 |
+
return super().forward(x)
|
| 74 |
+
|
| 75 |
+
B, N, C = x.shape
|
| 76 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
| 77 |
+
|
| 78 |
+
q, k, v = unbind(qkv, 2)
|
| 79 |
+
|
| 80 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
| 81 |
+
x = x.reshape([B, N, C])
|
| 82 |
+
|
| 83 |
+
x = self.proj(x)
|
| 84 |
+
x = self.proj_drop(x)
|
| 85 |
+
return x
|
flash3d/unidepth/models/backbones/metadinov2/block.py
ADDED
|
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
import logging
|
| 12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
| 13 |
+
|
| 14 |
+
import torch
|
| 15 |
+
import torch.nn as nn
|
| 16 |
+
|
| 17 |
+
from .attention import Attention, MemEffAttention
|
| 18 |
+
from .drop_path import DropPath
|
| 19 |
+
from .layer_scale import LayerScale
|
| 20 |
+
from .mlp import Mlp
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
logger = logging.getLogger("dinov2")
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
try:
|
| 27 |
+
from xformers.ops import fmha
|
| 28 |
+
from xformers.ops import scaled_index_add, index_select_cat
|
| 29 |
+
|
| 30 |
+
XFORMERS_AVAILABLE = True
|
| 31 |
+
except ImportError:
|
| 32 |
+
logger.warning("xFormers not available")
|
| 33 |
+
XFORMERS_AVAILABLE = False
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class Block(nn.Module):
|
| 37 |
+
def __init__(
|
| 38 |
+
self,
|
| 39 |
+
dim: int,
|
| 40 |
+
num_heads: int,
|
| 41 |
+
mlp_ratio: float = 4.0,
|
| 42 |
+
qkv_bias: bool = False,
|
| 43 |
+
proj_bias: bool = True,
|
| 44 |
+
ffn_bias: bool = True,
|
| 45 |
+
drop: float = 0.0,
|
| 46 |
+
attn_drop: float = 0.0,
|
| 47 |
+
init_values=None,
|
| 48 |
+
drop_path: float = 0.0,
|
| 49 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 50 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
| 51 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
| 52 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
| 53 |
+
) -> None:
|
| 54 |
+
super().__init__()
|
| 55 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
| 56 |
+
self.norm1 = norm_layer(dim)
|
| 57 |
+
self.attn = attn_class(
|
| 58 |
+
dim,
|
| 59 |
+
num_heads=num_heads,
|
| 60 |
+
qkv_bias=qkv_bias,
|
| 61 |
+
proj_bias=proj_bias,
|
| 62 |
+
attn_drop=attn_drop,
|
| 63 |
+
proj_drop=drop,
|
| 64 |
+
)
|
| 65 |
+
self.ls1 = (
|
| 66 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 67 |
+
)
|
| 68 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 69 |
+
|
| 70 |
+
self.norm2 = norm_layer(dim)
|
| 71 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
| 72 |
+
self.mlp = ffn_layer(
|
| 73 |
+
in_features=dim,
|
| 74 |
+
hidden_features=mlp_hidden_dim,
|
| 75 |
+
act_layer=act_layer,
|
| 76 |
+
drop=drop,
|
| 77 |
+
bias=ffn_bias,
|
| 78 |
+
)
|
| 79 |
+
self.ls2 = (
|
| 80 |
+
LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
| 81 |
+
)
|
| 82 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
| 83 |
+
|
| 84 |
+
self.sample_drop_ratio = drop_path
|
| 85 |
+
|
| 86 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 87 |
+
def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
|
| 88 |
+
return self.ls1(self.attn(self.norm1(x)))
|
| 89 |
+
|
| 90 |
+
def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
|
| 91 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 92 |
+
|
| 93 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
| 94 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
| 95 |
+
x = drop_add_residual_stochastic_depth(
|
| 96 |
+
x,
|
| 97 |
+
residual_func=attn_residual_func,
|
| 98 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 99 |
+
)
|
| 100 |
+
x = drop_add_residual_stochastic_depth(
|
| 101 |
+
x,
|
| 102 |
+
residual_func=ffn_residual_func,
|
| 103 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 104 |
+
)
|
| 105 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
| 106 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
| 107 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
| 108 |
+
else:
|
| 109 |
+
x = x + attn_residual_func(x)
|
| 110 |
+
x = x + ffn_residual_func(x)
|
| 111 |
+
return x
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def drop_add_residual_stochastic_depth(
|
| 115 |
+
x: torch.Tensor,
|
| 116 |
+
residual_func: Callable[[torch.Tensor], torch.Tensor],
|
| 117 |
+
sample_drop_ratio: float = 0.0,
|
| 118 |
+
) -> torch.Tensor:
|
| 119 |
+
# 1) extract subset using permutation
|
| 120 |
+
b, n, d = x.shape
|
| 121 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 122 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 123 |
+
x_subset = x[brange]
|
| 124 |
+
|
| 125 |
+
# 2) apply residual_func to get residual
|
| 126 |
+
residual = residual_func(x_subset)
|
| 127 |
+
|
| 128 |
+
x_flat = x.flatten(1)
|
| 129 |
+
residual = residual.flatten(1)
|
| 130 |
+
|
| 131 |
+
residual_scale_factor = b / sample_subset_size
|
| 132 |
+
|
| 133 |
+
# 3) add the residual
|
| 134 |
+
x_plus_residual = torch.index_add(
|
| 135 |
+
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
| 136 |
+
)
|
| 137 |
+
return x_plus_residual.view_as(x)
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
| 141 |
+
b, n, d = x.shape
|
| 142 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
| 143 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
| 144 |
+
residual_scale_factor = b / sample_subset_size
|
| 145 |
+
return brange, residual_scale_factor
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
| 149 |
+
if scaling_vector is None:
|
| 150 |
+
x_flat = x.flatten(1)
|
| 151 |
+
residual = residual.flatten(1)
|
| 152 |
+
x_plus_residual = torch.index_add(
|
| 153 |
+
x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
|
| 154 |
+
)
|
| 155 |
+
else:
|
| 156 |
+
x_plus_residual = scaled_index_add(
|
| 157 |
+
x,
|
| 158 |
+
brange,
|
| 159 |
+
residual.to(dtype=x.dtype),
|
| 160 |
+
scaling=scaling_vector,
|
| 161 |
+
alpha=residual_scale_factor,
|
| 162 |
+
)
|
| 163 |
+
return x_plus_residual
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
| 170 |
+
"""
|
| 171 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
| 172 |
+
"""
|
| 173 |
+
batch_sizes = (
|
| 174 |
+
[b.shape[0] for b in branges]
|
| 175 |
+
if branges is not None
|
| 176 |
+
else [x.shape[0] for x in x_list]
|
| 177 |
+
)
|
| 178 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
| 179 |
+
if all_shapes not in attn_bias_cache.keys():
|
| 180 |
+
seqlens = []
|
| 181 |
+
for b, x in zip(batch_sizes, x_list):
|
| 182 |
+
for _ in range(b):
|
| 183 |
+
seqlens.append(x.shape[1])
|
| 184 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
| 185 |
+
attn_bias._batch_sizes = batch_sizes
|
| 186 |
+
attn_bias_cache[all_shapes] = attn_bias
|
| 187 |
+
|
| 188 |
+
if branges is not None:
|
| 189 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
|
| 190 |
+
1, -1, x_list[0].shape[-1]
|
| 191 |
+
)
|
| 192 |
+
else:
|
| 193 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
| 194 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
| 195 |
+
|
| 196 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
def drop_add_residual_stochastic_depth_list(
|
| 200 |
+
x_list: List[torch.Tensor],
|
| 201 |
+
residual_func: Callable[[torch.Tensor, Any], torch.Tensor],
|
| 202 |
+
sample_drop_ratio: float = 0.0,
|
| 203 |
+
scaling_vector=None,
|
| 204 |
+
) -> torch.Tensor:
|
| 205 |
+
# 1) generate random set of indices for dropping samples in the batch
|
| 206 |
+
branges_scales = [
|
| 207 |
+
get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
|
| 208 |
+
]
|
| 209 |
+
branges = [s[0] for s in branges_scales]
|
| 210 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
| 211 |
+
|
| 212 |
+
# 2) get attention bias and index+concat the tensors
|
| 213 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
| 214 |
+
|
| 215 |
+
# 3) apply residual_func to get residual, and split the result
|
| 216 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
| 217 |
+
|
| 218 |
+
outputs = []
|
| 219 |
+
for x, brange, residual, residual_scale_factor in zip(
|
| 220 |
+
x_list, branges, residual_list, residual_scale_factors
|
| 221 |
+
):
|
| 222 |
+
outputs.append(
|
| 223 |
+
add_residual(
|
| 224 |
+
x, brange, residual, residual_scale_factor, scaling_vector
|
| 225 |
+
).view_as(x)
|
| 226 |
+
)
|
| 227 |
+
return outputs
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class NestedTensorBlock(Block):
|
| 231 |
+
def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
| 232 |
+
"""
|
| 233 |
+
x_list contains a list of tensors to nest together and run
|
| 234 |
+
"""
|
| 235 |
+
assert isinstance(self.attn, MemEffAttention)
|
| 236 |
+
|
| 237 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
| 238 |
+
|
| 239 |
+
def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
| 240 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
| 241 |
+
|
| 242 |
+
def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
| 243 |
+
return self.mlp(self.norm2(x))
|
| 244 |
+
|
| 245 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 246 |
+
x_list,
|
| 247 |
+
residual_func=attn_residual_func,
|
| 248 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 249 |
+
scaling_vector=(
|
| 250 |
+
self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
|
| 251 |
+
),
|
| 252 |
+
)
|
| 253 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
| 254 |
+
x_list,
|
| 255 |
+
residual_func=ffn_residual_func,
|
| 256 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
| 257 |
+
scaling_vector=(
|
| 258 |
+
self.ls2.gamma if isinstance(self.ls1, LayerScale) else None
|
| 259 |
+
),
|
| 260 |
+
)
|
| 261 |
+
return x_list
|
| 262 |
+
else:
|
| 263 |
+
|
| 264 |
+
def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
| 265 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
| 266 |
+
|
| 267 |
+
def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
| 268 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
| 269 |
+
|
| 270 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
| 271 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
| 272 |
+
x = x + ffn_residual_func(x)
|
| 273 |
+
return attn_bias.split(x)
|
| 274 |
+
|
| 275 |
+
def forward(self, x_or_x_list):
|
| 276 |
+
if isinstance(x_or_x_list, torch.Tensor):
|
| 277 |
+
return super().forward(x_or_x_list)
|
| 278 |
+
elif isinstance(x_or_x_list, list):
|
| 279 |
+
assert (
|
| 280 |
+
XFORMERS_AVAILABLE
|
| 281 |
+
), "Please install xFormers for nested tensors usage"
|
| 282 |
+
return self.forward_nested(x_or_x_list)
|
| 283 |
+
else:
|
| 284 |
+
raise AssertionError
|
flash3d/unidepth/models/backbones/metadinov2/dino_head.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
import torch.nn as nn
|
| 9 |
+
from torch.nn.init import trunc_normal_
|
| 10 |
+
from torch.nn.utils import weight_norm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class DINOHead(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
in_dim,
|
| 17 |
+
out_dim,
|
| 18 |
+
use_bn=False,
|
| 19 |
+
nlayers=3,
|
| 20 |
+
hidden_dim=2048,
|
| 21 |
+
bottleneck_dim=256,
|
| 22 |
+
mlp_bias=True,
|
| 23 |
+
):
|
| 24 |
+
super().__init__()
|
| 25 |
+
nlayers = max(nlayers, 1)
|
| 26 |
+
self.mlp = _build_mlp(
|
| 27 |
+
nlayers,
|
| 28 |
+
in_dim,
|
| 29 |
+
bottleneck_dim,
|
| 30 |
+
hidden_dim=hidden_dim,
|
| 31 |
+
use_bn=use_bn,
|
| 32 |
+
bias=mlp_bias,
|
| 33 |
+
)
|
| 34 |
+
self.apply(self._init_weights)
|
| 35 |
+
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
| 36 |
+
self.last_layer.weight_g.data.fill_(1)
|
| 37 |
+
|
| 38 |
+
def _init_weights(self, m):
|
| 39 |
+
if isinstance(m, nn.Linear):
|
| 40 |
+
trunc_normal_(m.weight, std=0.02)
|
| 41 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
| 42 |
+
nn.init.constant_(m.bias, 0)
|
| 43 |
+
|
| 44 |
+
def forward(self, x):
|
| 45 |
+
x = self.mlp(x)
|
| 46 |
+
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
| 47 |
+
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
| 48 |
+
x = self.last_layer(x)
|
| 49 |
+
return x
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _build_mlp(
|
| 53 |
+
nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
|
| 54 |
+
):
|
| 55 |
+
if nlayers == 1:
|
| 56 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
| 57 |
+
else:
|
| 58 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
| 59 |
+
if use_bn:
|
| 60 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 61 |
+
layers.append(nn.GELU())
|
| 62 |
+
for _ in range(nlayers - 2):
|
| 63 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
| 64 |
+
if use_bn:
|
| 65 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
| 66 |
+
layers.append(nn.GELU())
|
| 67 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
| 68 |
+
return nn.Sequential(*layers)
|
flash3d/unidepth/models/backbones/metadinov2/drop_path.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
| 16 |
+
if drop_prob == 0.0 or not training:
|
| 17 |
+
return x
|
| 18 |
+
keep_prob = 1 - drop_prob
|
| 19 |
+
shape = (x.shape[0],) + (1,) * (
|
| 20 |
+
x.ndim - 1
|
| 21 |
+
) # work with diff dim tensors, not just 2D ConvNets
|
| 22 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
| 23 |
+
if keep_prob > 0.0:
|
| 24 |
+
random_tensor.div_(keep_prob)
|
| 25 |
+
output = x * random_tensor
|
| 26 |
+
return output
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class DropPath(nn.Module):
|
| 30 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
| 31 |
+
|
| 32 |
+
def __init__(self, drop_prob=None):
|
| 33 |
+
super(DropPath, self).__init__()
|
| 34 |
+
self.drop_prob = drop_prob
|
| 35 |
+
|
| 36 |
+
def forward(self, x):
|
| 37 |
+
return drop_path(x, self.drop_prob, self.training)
|
flash3d/unidepth/models/backbones/metadinov2/layer_scale.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
| 8 |
+
|
| 9 |
+
from typing import Union
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
from torch import Tensor
|
| 13 |
+
import torch.nn as nn
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class LayerScale(nn.Module):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
dim: int,
|
| 20 |
+
init_values: Union[float, Tensor] = 1e-5,
|
| 21 |
+
inplace: bool = False,
|
| 22 |
+
) -> None:
|
| 23 |
+
super().__init__()
|
| 24 |
+
self.inplace = inplace
|
| 25 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
| 26 |
+
|
| 27 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 28 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
flash3d/unidepth/models/backbones/metadinov2/mlp.py
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from typing import Callable, Optional
|
| 13 |
+
|
| 14 |
+
from torch import Tensor, nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class Mlp(nn.Module):
|
| 18 |
+
def __init__(
|
| 19 |
+
self,
|
| 20 |
+
in_features: int,
|
| 21 |
+
hidden_features: Optional[int] = None,
|
| 22 |
+
out_features: Optional[int] = None,
|
| 23 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
| 24 |
+
drop: float = 0.0,
|
| 25 |
+
bias: bool = True,
|
| 26 |
+
) -> None:
|
| 27 |
+
super().__init__()
|
| 28 |
+
out_features = out_features or in_features
|
| 29 |
+
hidden_features = hidden_features or in_features
|
| 30 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
| 31 |
+
self.act = act_layer()
|
| 32 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 33 |
+
self.drop = nn.Dropout(drop)
|
| 34 |
+
|
| 35 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 36 |
+
x = self.fc1(x)
|
| 37 |
+
x = self.act(x)
|
| 38 |
+
x = self.drop(x)
|
| 39 |
+
x = self.fc2(x)
|
| 40 |
+
x = self.drop(x)
|
| 41 |
+
return x
|
flash3d/unidepth/models/backbones/metadinov2/patch_embed.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
# References:
|
| 8 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
| 9 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
| 10 |
+
|
| 11 |
+
from typing import Callable, Optional, Tuple, Union
|
| 12 |
+
|
| 13 |
+
from torch import Tensor
|
| 14 |
+
import torch.nn as nn
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def make_2tuple(x):
|
| 18 |
+
if isinstance(x, tuple):
|
| 19 |
+
assert len(x) == 2
|
| 20 |
+
return x
|
| 21 |
+
|
| 22 |
+
assert isinstance(x, int)
|
| 23 |
+
return (x, x)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class PatchEmbed(nn.Module):
|
| 27 |
+
"""
|
| 28 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
| 29 |
+
|
| 30 |
+
Args:
|
| 31 |
+
img_size: Image size.
|
| 32 |
+
patch_size: Patch token size.
|
| 33 |
+
in_chans: Number of input image channels.
|
| 34 |
+
embed_dim: Number of linear projection output channels.
|
| 35 |
+
norm_layer: Normalization layer.
|
| 36 |
+
"""
|
| 37 |
+
|
| 38 |
+
def __init__(
|
| 39 |
+
self,
|
| 40 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
| 41 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
| 42 |
+
in_chans: int = 3,
|
| 43 |
+
embed_dim: int = 768,
|
| 44 |
+
norm_layer: Optional[Callable] = None,
|
| 45 |
+
flatten_embedding: bool = True,
|
| 46 |
+
) -> None:
|
| 47 |
+
super().__init__()
|
| 48 |
+
|
| 49 |
+
image_HW = make_2tuple(img_size)
|
| 50 |
+
patch_HW = make_2tuple(patch_size)
|
| 51 |
+
patch_grid_size = (
|
| 52 |
+
image_HW[0] // patch_HW[0],
|
| 53 |
+
image_HW[1] // patch_HW[1],
|
| 54 |
+
)
|
| 55 |
+
|
| 56 |
+
self.img_size = image_HW
|
| 57 |
+
self.patch_size = patch_HW
|
| 58 |
+
self.patches_resolution = patch_grid_size
|
| 59 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
| 60 |
+
|
| 61 |
+
self.in_chans = in_chans
|
| 62 |
+
self.embed_dim = embed_dim
|
| 63 |
+
|
| 64 |
+
self.flatten_embedding = flatten_embedding
|
| 65 |
+
|
| 66 |
+
self.proj = nn.Conv2d(
|
| 67 |
+
in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
|
| 68 |
+
)
|
| 69 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 70 |
+
|
| 71 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 72 |
+
_, _, H, W = x.shape
|
| 73 |
+
patch_H, patch_W = self.patch_size
|
| 74 |
+
|
| 75 |
+
assert (
|
| 76 |
+
H % patch_H == 0
|
| 77 |
+
), f"Input image height {H} is not a multiple of patch height {patch_H}"
|
| 78 |
+
assert (
|
| 79 |
+
W % patch_W == 0
|
| 80 |
+
), f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
| 81 |
+
|
| 82 |
+
x = self.proj(x) # B C H W
|
| 83 |
+
H, W = x.size(2), x.size(3)
|
| 84 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
| 85 |
+
x = self.norm(x)
|
| 86 |
+
if not self.flatten_embedding:
|
| 87 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
| 88 |
+
return x
|
| 89 |
+
|
| 90 |
+
def flops(self) -> float:
|
| 91 |
+
Ho, Wo = self.patches_resolution
|
| 92 |
+
flops = (
|
| 93 |
+
Ho
|
| 94 |
+
* Wo
|
| 95 |
+
* self.embed_dim
|
| 96 |
+
* self.in_chans
|
| 97 |
+
* (self.patch_size[0] * self.patch_size[1])
|
| 98 |
+
)
|
| 99 |
+
if self.norm is not None:
|
| 100 |
+
flops += Ho * Wo * self.embed_dim
|
| 101 |
+
return flops
|
flash3d/unidepth/models/backbones/metadinov2/swiglu_ffn.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
| 2 |
+
# All rights reserved.
|
| 3 |
+
#
|
| 4 |
+
# This source code is licensed under the license found in the
|
| 5 |
+
# LICENSE file in the root directory of this source tree.
|
| 6 |
+
|
| 7 |
+
from typing import Callable, Optional
|
| 8 |
+
|
| 9 |
+
from torch import Tensor, nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class SwiGLUFFN(nn.Module):
|
| 14 |
+
def __init__(
|
| 15 |
+
self,
|
| 16 |
+
in_features: int,
|
| 17 |
+
hidden_features: Optional[int] = None,
|
| 18 |
+
out_features: Optional[int] = None,
|
| 19 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 20 |
+
drop: float = 0.0,
|
| 21 |
+
bias: bool = True,
|
| 22 |
+
) -> None:
|
| 23 |
+
super().__init__()
|
| 24 |
+
out_features = out_features or in_features
|
| 25 |
+
hidden_features = hidden_features or in_features
|
| 26 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
| 27 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
| 28 |
+
|
| 29 |
+
def forward(self, x: Tensor) -> Tensor:
|
| 30 |
+
x12 = self.w12(x)
|
| 31 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
| 32 |
+
hidden = F.silu(x1) * x2
|
| 33 |
+
return self.w3(hidden)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
try:
|
| 37 |
+
from xformers.ops import SwiGLU
|
| 38 |
+
|
| 39 |
+
XFORMERS_AVAILABLE = True
|
| 40 |
+
except ImportError:
|
| 41 |
+
SwiGLU = SwiGLUFFN
|
| 42 |
+
XFORMERS_AVAILABLE = False
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class SwiGLUFFNFused(SwiGLU):
|
| 46 |
+
def __init__(
|
| 47 |
+
self,
|
| 48 |
+
in_features: int,
|
| 49 |
+
hidden_features: Optional[int] = None,
|
| 50 |
+
out_features: Optional[int] = None,
|
| 51 |
+
act_layer: Callable[..., nn.Module] = None,
|
| 52 |
+
drop: float = 0.0,
|
| 53 |
+
bias: bool = True,
|
| 54 |
+
) -> None:
|
| 55 |
+
out_features = out_features or in_features
|
| 56 |
+
hidden_features = hidden_features or in_features
|
| 57 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
| 58 |
+
super().__init__(
|
| 59 |
+
in_features=in_features,
|
| 60 |
+
hidden_features=hidden_features,
|
| 61 |
+
out_features=out_features,
|
| 62 |
+
bias=bias,
|
| 63 |
+
)
|
flash3d/unidepth/models/encoder.py
ADDED
|
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from unidepth.models.backbones import ConvNeXtV2, _make_dinov2_model, ConvNeXt
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ModelWrap(nn.Module):
|
| 8 |
+
def __init__(self, model) -> None:
|
| 9 |
+
super().__init__()
|
| 10 |
+
self.backbone = model
|
| 11 |
+
|
| 12 |
+
def forward(self, x, *args, **kwargs):
|
| 13 |
+
features = []
|
| 14 |
+
for layer in self.backbone.features:
|
| 15 |
+
x = layer(x)
|
| 16 |
+
features.append(x)
|
| 17 |
+
return features
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def convnextv2_base(config, **kwargs):
|
| 21 |
+
model = ConvNeXtV2(
|
| 22 |
+
depths=[3, 3, 27, 3],
|
| 23 |
+
dims=[128, 256, 512, 1024],
|
| 24 |
+
output_idx=config.get("output_idx", [3, 6, 33, 36]),
|
| 25 |
+
use_checkpoint=config.get("use_checkpoint", False),
|
| 26 |
+
**kwargs,
|
| 27 |
+
)
|
| 28 |
+
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt"
|
| 29 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 30 |
+
url, map_location="cpu", progress=False
|
| 31 |
+
)["model"]
|
| 32 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 33 |
+
print(info)
|
| 34 |
+
return model
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def convnextv2_large(config, **kwargs):
|
| 38 |
+
model = ConvNeXtV2(
|
| 39 |
+
depths=[3, 3, 27, 3],
|
| 40 |
+
dims=[192, 384, 768, 1536],
|
| 41 |
+
output_idx=config.get("output_idx", [3, 6, 33, 36]),
|
| 42 |
+
use_checkpoint=config.get("use_checkpoint", False),
|
| 43 |
+
**kwargs,
|
| 44 |
+
)
|
| 45 |
+
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt"
|
| 46 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 47 |
+
url, map_location="cpu", progress=False
|
| 48 |
+
)["model"]
|
| 49 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 50 |
+
print(info)
|
| 51 |
+
return model
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
def convnextv2_large_mae(config, **kwargs):
|
| 55 |
+
model = ConvNeXtV2(
|
| 56 |
+
depths=[3, 3, 27, 3],
|
| 57 |
+
dims=[192, 384, 768, 1536],
|
| 58 |
+
output_idx=config.get("output_idx", [3, 6, 33, 36]),
|
| 59 |
+
use_checkpoint=config.get("use_checkpoint", False),
|
| 60 |
+
**kwargs,
|
| 61 |
+
)
|
| 62 |
+
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt"
|
| 63 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 64 |
+
url, map_location="cpu", progress=False
|
| 65 |
+
)["model"]
|
| 66 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 67 |
+
print(info)
|
| 68 |
+
return model
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def convnextv2_huge(config, **kwargs):
|
| 72 |
+
model = ConvNeXtV2(
|
| 73 |
+
depths=[3, 3, 27, 3],
|
| 74 |
+
dims=[352, 704, 1408, 2816],
|
| 75 |
+
output_idx=config.get("output_idx", [3, 6, 33, 36]),
|
| 76 |
+
use_checkpoint=config.get("use_checkpoint", False),
|
| 77 |
+
**kwargs,
|
| 78 |
+
)
|
| 79 |
+
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_huge_22k_512_ema.pt"
|
| 80 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 81 |
+
url, map_location="cpu", progress=False
|
| 82 |
+
)["model"]
|
| 83 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 84 |
+
print(info)
|
| 85 |
+
return model
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def convnextv2_huge_mae(config, **kwargs):
|
| 89 |
+
model = ConvNeXtV2(
|
| 90 |
+
depths=[3, 3, 27, 3],
|
| 91 |
+
dims=[352, 704, 1408, 2816],
|
| 92 |
+
output_idx=config.get("output_idx", [3, 6, 33, 36]),
|
| 93 |
+
use_checkpoint=config.get("use_checkpoint", False),
|
| 94 |
+
**kwargs,
|
| 95 |
+
)
|
| 96 |
+
url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_huge_1k_224_fcmae.pt"
|
| 97 |
+
state_dict = torch.hub.load_state_dict_from_url(
|
| 98 |
+
url, map_location="cpu", progress=False
|
| 99 |
+
)["model"]
|
| 100 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 101 |
+
print(info)
|
| 102 |
+
return model
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
def convnext_large_pt(config, **kwargs):
|
| 106 |
+
model = ConvNeXt(
|
| 107 |
+
depths=[3, 3, 27, 3],
|
| 108 |
+
dims=[192, 384, 768, 1536],
|
| 109 |
+
output_idx=config.get("output_idx", [3, 6, 33, 36]),
|
| 110 |
+
use_checkpoint=config.get("use_checkpoint", False),
|
| 111 |
+
**kwargs,
|
| 112 |
+
)
|
| 113 |
+
from unidepth.models.backbones.convnext import HF_URL, checkpoint_filter_fn
|
| 114 |
+
from huggingface_hub import hf_hub_download
|
| 115 |
+
from huggingface_hub.utils import disable_progress_bars
|
| 116 |
+
|
| 117 |
+
disable_progress_bars()
|
| 118 |
+
repo_id, filename = HF_URL["convnext_large_pt"]
|
| 119 |
+
state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename))
|
| 120 |
+
state_dict = checkpoint_filter_fn(state_dict, model)
|
| 121 |
+
info = model.load_state_dict(state_dict, strict=False)
|
| 122 |
+
print(info)
|
| 123 |
+
return model
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
def convnext_large(config, **kwargs):
|
| 127 |
+
model = ConvNeXt(
|
| 128 |
+
depths=[3, 3, 27, 3],
|
| 129 |
+
dims=[192, 384, 768, 1536],
|
| 130 |
+
output_idx=config.get("output_idx", [3, 6, 33, 36]),
|
| 131 |
+
use_checkpoint=config.get("use_checkpoint", False),
|
| 132 |
+
drop_path_rate=config.get("drop_path", 0.0),
|
| 133 |
+
**kwargs,
|
| 134 |
+
)
|
| 135 |
+
return model
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
def dinov2_vitb14(config, pretrained: bool = True, **kwargs):
|
| 139 |
+
"""
|
| 140 |
+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 141 |
+
"""
|
| 142 |
+
vit = _make_dinov2_model(
|
| 143 |
+
arch_name="vit_base",
|
| 144 |
+
pretrained=pretrained,
|
| 145 |
+
output_idx=config.get("output_idx", [3, 6, 9, 12]),
|
| 146 |
+
checkpoint=config.get("use_checkpoint", False),
|
| 147 |
+
drop_path_rate=config.get("drop_path", 0.0),
|
| 148 |
+
num_register_tokens=config.get("num_register_tokens", 0),
|
| 149 |
+
**kwargs,
|
| 150 |
+
)
|
| 151 |
+
return vit
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def dinov2_vitl14(config, pretrained: str = "", **kwargs):
|
| 155 |
+
"""
|
| 156 |
+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 157 |
+
"""
|
| 158 |
+
vit = _make_dinov2_model(
|
| 159 |
+
arch_name="vit_large",
|
| 160 |
+
pretrained=config["pretrained"],
|
| 161 |
+
output_idx=config.get("output_idx", [5, 12, 18, 24]),
|
| 162 |
+
checkpoint=config.get("use_checkpoint", False),
|
| 163 |
+
drop_path_rate=config.get("drop_path", 0.0),
|
| 164 |
+
num_register_tokens=config.get("num_register_tokens", 0),
|
| 165 |
+
**kwargs,
|
| 166 |
+
)
|
| 167 |
+
return vit
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def dinov2_vitg14(config, pretrained: bool = True, **kwargs):
|
| 171 |
+
"""
|
| 172 |
+
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
| 173 |
+
"""
|
| 174 |
+
vit = _make_dinov2_model(
|
| 175 |
+
arch_name="vit_giant2",
|
| 176 |
+
ffn_layer="swiglufused",
|
| 177 |
+
pretrained=pretrained,
|
| 178 |
+
output_idx=config.get("output_idx", [10, 20, 30, 40]),
|
| 179 |
+
checkpoint=config.get("use_checkpoint", False),
|
| 180 |
+
drop_path_rate=config.get("drop_path", 0.0),
|
| 181 |
+
num_register_tokens=config.get("num_register_tokens", 0),
|
| 182 |
+
**kwargs,
|
| 183 |
+
)
|
| 184 |
+
return vit
|
flash3d/unidepth/models/unidepthv1/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .unidepthv1 import UniDepthV1
|
| 2 |
+
|
| 3 |
+
__all__ = [
|
| 4 |
+
"UniDepthV1",
|
| 5 |
+
]
|
flash3d/unidepth/models/unidepthv1/decoder.py
ADDED
|
@@ -0,0 +1,542 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import List, Tuple
|
| 7 |
+
|
| 8 |
+
from einops import rearrange
|
| 9 |
+
from timm.models.layers import trunc_normal_
|
| 10 |
+
import torch
|
| 11 |
+
import torch.nn as nn
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
from unidepth.layers import (
|
| 15 |
+
MLP,
|
| 16 |
+
AttentionBlock,
|
| 17 |
+
NystromBlock,
|
| 18 |
+
PositionEmbeddingSine,
|
| 19 |
+
ConvUpsample,
|
| 20 |
+
)
|
| 21 |
+
from unidepth.utils.sht import rsh_cart_8
|
| 22 |
+
from unidepth.utils.geometric import (
|
| 23 |
+
generate_rays,
|
| 24 |
+
flat_interpolate,
|
| 25 |
+
)
|
| 26 |
+
from unidepth.utils.misc import max_stack
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ListAdapter(nn.Module):
|
| 30 |
+
def __init__(self, input_dims: List[int], hidden_dim: int):
|
| 31 |
+
super().__init__()
|
| 32 |
+
self.input_adapters = nn.ModuleList([])
|
| 33 |
+
self.num_chunks = len(input_dims)
|
| 34 |
+
for input_dim in input_dims:
|
| 35 |
+
self.input_adapters.append(
|
| 36 |
+
nn.Sequential(
|
| 37 |
+
nn.LayerNorm(input_dim), nn.Linear(input_dim, hidden_dim), nn.GELU()
|
| 38 |
+
)
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
def forward(self, x: torch.Tensor, splits: torch.Tensor) -> torch.Tensor:
|
| 42 |
+
xs = torch.split(x, splits.int().tolist(), dim=-1)
|
| 43 |
+
xs = [adapter(x) for x, adapter in zip(xs, self.input_adapters)]
|
| 44 |
+
return torch.cat(xs, dim=-1)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
class CameraHead(nn.Module):
|
| 48 |
+
def __init__(
|
| 49 |
+
self,
|
| 50 |
+
input_dim: int,
|
| 51 |
+
hidden_dim: int,
|
| 52 |
+
num_heads: int = 8,
|
| 53 |
+
expansion: int = 4,
|
| 54 |
+
depth: int = 4,
|
| 55 |
+
dropout: float = 0.0,
|
| 56 |
+
layer_scale: float = 1.0,
|
| 57 |
+
**kwargs,
|
| 58 |
+
):
|
| 59 |
+
super().__init__()
|
| 60 |
+
|
| 61 |
+
self.aggregate = AttentionBlock(
|
| 62 |
+
hidden_dim,
|
| 63 |
+
num_heads=1,
|
| 64 |
+
expansion=expansion,
|
| 65 |
+
dropout=dropout,
|
| 66 |
+
layer_scale=layer_scale,
|
| 67 |
+
)
|
| 68 |
+
self.latents_pos = nn.Parameter(
|
| 69 |
+
torch.randn(1, 4, hidden_dim), requires_grad=True
|
| 70 |
+
)
|
| 71 |
+
self.layers = nn.ModuleList([])
|
| 72 |
+
self.in_features = MLP(hidden_dim, expansion=2, dropout=dropout)
|
| 73 |
+
for _ in range(depth):
|
| 74 |
+
blk = AttentionBlock(
|
| 75 |
+
hidden_dim,
|
| 76 |
+
num_heads=num_heads,
|
| 77 |
+
expansion=expansion,
|
| 78 |
+
dropout=dropout,
|
| 79 |
+
layer_scale=layer_scale,
|
| 80 |
+
)
|
| 81 |
+
self.layers.append(blk)
|
| 82 |
+
self.out = MLP(hidden_dim, expansion=2, dropout=0.0, output_dim=1)
|
| 83 |
+
self.cls_project = nn.Sequential(
|
| 84 |
+
nn.LayerNorm(input_dim),
|
| 85 |
+
nn.Linear(input_dim, hidden_dim // 2),
|
| 86 |
+
nn.GELU(),
|
| 87 |
+
nn.Linear(hidden_dim // 2, hidden_dim),
|
| 88 |
+
)
|
| 89 |
+
|
| 90 |
+
def forward(self, features, cls_tokens, pos_embed) -> torch.Tensor:
|
| 91 |
+
features = features.unbind(dim=-1)
|
| 92 |
+
cls_tokens = self.cls_project(cls_tokens)
|
| 93 |
+
features_stack = torch.cat(features, dim=1)
|
| 94 |
+
features_stack = features_stack + pos_embed
|
| 95 |
+
latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
|
| 96 |
+
features_stack = self.in_features(features_stack)
|
| 97 |
+
features = torch.cat((features_stack, cls_tokens), dim=1)
|
| 98 |
+
cls_tokens = self.aggregate(cls_tokens, context=features, pos_embed=latents_pos)
|
| 99 |
+
for i, layer in enumerate(self.layers):
|
| 100 |
+
cls_tokens = layer(cls_tokens, pos_embed=latents_pos)
|
| 101 |
+
|
| 102 |
+
# project
|
| 103 |
+
x = self.out(cls_tokens).squeeze(-1)
|
| 104 |
+
camera_intrinsics = torch.zeros(
|
| 105 |
+
x.shape[0], 3, 3, device=x.device, requires_grad=False
|
| 106 |
+
)
|
| 107 |
+
camera_intrinsics[:, 0, 0] = x[:, 0].exp()
|
| 108 |
+
camera_intrinsics[:, 1, 1] = x[:, 1].exp()
|
| 109 |
+
camera_intrinsics[:, 0, 2] = x[:, 2].sigmoid()
|
| 110 |
+
camera_intrinsics[:, 1, 2] = x[:, 3].sigmoid()
|
| 111 |
+
camera_intrinsics[:, 2, 2] = 1.0
|
| 112 |
+
return camera_intrinsics
|
| 113 |
+
|
| 114 |
+
def set_shapes(self, shapes: Tuple[int, int]):
|
| 115 |
+
self.shapes = shapes
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class DepthHead(nn.Module):
|
| 119 |
+
def __init__(
|
| 120 |
+
self,
|
| 121 |
+
hidden_dim: int,
|
| 122 |
+
num_heads: int = 8,
|
| 123 |
+
expansion: int = 4,
|
| 124 |
+
depths: int | list[int] = 4,
|
| 125 |
+
camera_dim: int = 256,
|
| 126 |
+
num_resolutions: int = 4,
|
| 127 |
+
dropout: float = 0.0,
|
| 128 |
+
layer_scale: float = 1.0,
|
| 129 |
+
**kwargs,
|
| 130 |
+
) -> None:
|
| 131 |
+
super().__init__()
|
| 132 |
+
if isinstance(depths, int):
|
| 133 |
+
depths = [depths] * 3
|
| 134 |
+
assert len(depths) == 3
|
| 135 |
+
|
| 136 |
+
self.project_rays16 = MLP(
|
| 137 |
+
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim
|
| 138 |
+
)
|
| 139 |
+
self.project_rays8 = MLP(
|
| 140 |
+
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 2
|
| 141 |
+
)
|
| 142 |
+
self.project_rays4 = MLP(
|
| 143 |
+
camera_dim, expansion=expansion, dropout=dropout, output_dim=hidden_dim // 4
|
| 144 |
+
)
|
| 145 |
+
self.to_latents = MLP(hidden_dim, expansion=2, dropout=dropout)
|
| 146 |
+
|
| 147 |
+
self.features_channel_cat = nn.Linear(hidden_dim * num_resolutions, hidden_dim)
|
| 148 |
+
|
| 149 |
+
self.up8 = ConvUpsample(
|
| 150 |
+
hidden_dim, expansion=expansion, layer_scale=layer_scale
|
| 151 |
+
)
|
| 152 |
+
self.up4 = ConvUpsample(
|
| 153 |
+
hidden_dim // 2, expansion=expansion, layer_scale=layer_scale
|
| 154 |
+
)
|
| 155 |
+
self.up2 = ConvUpsample(
|
| 156 |
+
hidden_dim // 4, expansion=expansion, layer_scale=layer_scale
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
self.layers_16 = nn.ModuleList([])
|
| 160 |
+
self.layers_8 = nn.ModuleList([])
|
| 161 |
+
self.layers_4 = nn.ModuleList([])
|
| 162 |
+
self.aggregate_16 = AttentionBlock(
|
| 163 |
+
hidden_dim,
|
| 164 |
+
num_heads=1,
|
| 165 |
+
expansion=expansion,
|
| 166 |
+
dropout=dropout,
|
| 167 |
+
layer_scale=layer_scale,
|
| 168 |
+
context_dim=hidden_dim,
|
| 169 |
+
)
|
| 170 |
+
self.prompt_camera = AttentionBlock(
|
| 171 |
+
hidden_dim,
|
| 172 |
+
num_heads=1,
|
| 173 |
+
expansion=expansion,
|
| 174 |
+
dropout=dropout,
|
| 175 |
+
layer_scale=layer_scale,
|
| 176 |
+
context_dim=hidden_dim,
|
| 177 |
+
)
|
| 178 |
+
for i, (blk_lst, depth) in enumerate(
|
| 179 |
+
zip([self.layers_16, self.layers_8, self.layers_4], depths)
|
| 180 |
+
):
|
| 181 |
+
attn_cls = AttentionBlock if i == 0 else NystromBlock
|
| 182 |
+
for _ in range(depth):
|
| 183 |
+
blk_lst.append(
|
| 184 |
+
attn_cls(
|
| 185 |
+
hidden_dim // (2**i),
|
| 186 |
+
num_heads=num_heads // (2**i),
|
| 187 |
+
expansion=expansion,
|
| 188 |
+
dropout=dropout,
|
| 189 |
+
layer_scale=layer_scale,
|
| 190 |
+
)
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
self.out2 = nn.Conv2d(hidden_dim // 8, 1, 3, padding=1)
|
| 194 |
+
self.out4 = nn.Conv2d(hidden_dim // 4, 1, 3, padding=1)
|
| 195 |
+
self.out8 = nn.Conv2d(hidden_dim // 2, 1, 3, padding=1)
|
| 196 |
+
|
| 197 |
+
def set_original_shapes(self, shapes: Tuple[int, int]):
|
| 198 |
+
self.original_shapes = shapes
|
| 199 |
+
|
| 200 |
+
def set_shapes(self, shapes: Tuple[int, int]):
|
| 201 |
+
self.shapes = shapes
|
| 202 |
+
|
| 203 |
+
def forward(
|
| 204 |
+
self, features: torch.Tensor, rays_hr: torch.Tensor, pos_embed, level_embed
|
| 205 |
+
) -> torch.Tensor:
|
| 206 |
+
features = features.unbind(dim=-1)
|
| 207 |
+
shapes = self.shapes
|
| 208 |
+
|
| 209 |
+
# camera_embedding
|
| 210 |
+
# torch.cuda.synchronize()
|
| 211 |
+
# start = time()
|
| 212 |
+
rays_embedding_16 = F.normalize(
|
| 213 |
+
flat_interpolate(rays_hr, old=self.original_shapes, new=shapes), dim=-1
|
| 214 |
+
)
|
| 215 |
+
rays_embedding_8 = F.normalize(
|
| 216 |
+
flat_interpolate(
|
| 217 |
+
rays_hr, old=self.original_shapes, new=[x * 2 for x in shapes]
|
| 218 |
+
),
|
| 219 |
+
dim=-1,
|
| 220 |
+
)
|
| 221 |
+
rays_embedding_4 = F.normalize(
|
| 222 |
+
flat_interpolate(
|
| 223 |
+
rays_hr, old=self.original_shapes, new=[x * 4 for x in shapes]
|
| 224 |
+
),
|
| 225 |
+
dim=-1,
|
| 226 |
+
)
|
| 227 |
+
rays_embedding_16 = self.project_rays16(rsh_cart_8(rays_embedding_16))
|
| 228 |
+
rays_embedding_8 = self.project_rays8(rsh_cart_8(rays_embedding_8))
|
| 229 |
+
rays_embedding_4 = self.project_rays4(rsh_cart_8(rays_embedding_4))
|
| 230 |
+
# torch.cuda.synchronize()
|
| 231 |
+
# print(f"camera_embedding took {time() - start} seconds")
|
| 232 |
+
features_tokens = torch.cat(features, dim=1)
|
| 233 |
+
features_tokens_pos = pos_embed + level_embed
|
| 234 |
+
|
| 235 |
+
# Generate latents with init as pooled features
|
| 236 |
+
features_channels = torch.cat(features, dim=-1)
|
| 237 |
+
features_16 = self.features_channel_cat(features_channels)
|
| 238 |
+
latents_16 = self.to_latents(
|
| 239 |
+
flat_interpolate(features_16, old=self.shapes, new=shapes, antialias=False)
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
# Aggregate features: F -> D
|
| 243 |
+
latents_16 = self.aggregate_16(
|
| 244 |
+
latents_16, context=features_tokens, pos_embed_context=features_tokens_pos
|
| 245 |
+
)
|
| 246 |
+
|
| 247 |
+
# Aggregate camera: D- > D|E
|
| 248 |
+
latents_16 = self.prompt_camera(latents_16, context=rays_embedding_16)
|
| 249 |
+
|
| 250 |
+
# Block 16 - Out 8
|
| 251 |
+
for layer in self.layers_16:
|
| 252 |
+
latents_16 = layer(latents_16, pos_embed=rays_embedding_16)
|
| 253 |
+
latents_8 = self.up8(
|
| 254 |
+
rearrange(
|
| 255 |
+
latents_16 + rays_embedding_16,
|
| 256 |
+
"b (h w) c -> b c h w",
|
| 257 |
+
h=shapes[0],
|
| 258 |
+
w=shapes[1],
|
| 259 |
+
).contiguous()
|
| 260 |
+
)
|
| 261 |
+
out8 = self.out8(
|
| 262 |
+
rearrange(
|
| 263 |
+
latents_8, "b (h w) c -> b c h w", h=shapes[0] * 2, w=shapes[1] * 2
|
| 264 |
+
)
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
# Block 8 - Out 4
|
| 268 |
+
for layer in self.layers_8:
|
| 269 |
+
latents_8 = layer(latents_8, pos_embed=rays_embedding_8)
|
| 270 |
+
latents_4 = self.up4(
|
| 271 |
+
rearrange(
|
| 272 |
+
latents_8 + rays_embedding_8,
|
| 273 |
+
"b (h w) c -> b c h w",
|
| 274 |
+
h=shapes[0] * 2,
|
| 275 |
+
w=shapes[1] * 2,
|
| 276 |
+
).contiguous()
|
| 277 |
+
)
|
| 278 |
+
out4 = self.out4(
|
| 279 |
+
rearrange(
|
| 280 |
+
latents_4, "b (h w) c -> b c h w", h=shapes[0] * 4, w=shapes[1] * 4
|
| 281 |
+
)
|
| 282 |
+
)
|
| 283 |
+
|
| 284 |
+
# Block 4 - Out 2
|
| 285 |
+
for layer in self.layers_4:
|
| 286 |
+
latents_4 = layer(latents_4, pos_embed=rays_embedding_4)
|
| 287 |
+
latents_2 = self.up2(
|
| 288 |
+
rearrange(
|
| 289 |
+
latents_4 + rays_embedding_4,
|
| 290 |
+
"b (h w) c -> b c h w",
|
| 291 |
+
h=shapes[0] * 4,
|
| 292 |
+
w=shapes[1] * 4,
|
| 293 |
+
).contiguous()
|
| 294 |
+
)
|
| 295 |
+
out2 = self.out2(
|
| 296 |
+
rearrange(
|
| 297 |
+
latents_2, "b (h w) c -> b c h w", h=shapes[0] * 8, w=shapes[1] * 8
|
| 298 |
+
)
|
| 299 |
+
)
|
| 300 |
+
|
| 301 |
+
# Depth features
|
| 302 |
+
proj_latents_16 = rearrange(
|
| 303 |
+
latents_16, "b (h w) c -> b c h w", h=shapes[0], w=shapes[1]
|
| 304 |
+
).contiguous()
|
| 305 |
+
|
| 306 |
+
# MS Outputs
|
| 307 |
+
out2 = out2.clamp(-10.0, 10.0).exp()
|
| 308 |
+
out4 = out4.clamp(-10.0, 10.0).exp()
|
| 309 |
+
out8 = out8.clamp(-10.0, 10.0).exp()
|
| 310 |
+
|
| 311 |
+
return out8, out4, out2, proj_latents_16
|
| 312 |
+
|
| 313 |
+
|
| 314 |
+
class Decoder(nn.Module):
|
| 315 |
+
def __init__(
|
| 316 |
+
self,
|
| 317 |
+
config,
|
| 318 |
+
*args,
|
| 319 |
+
**kwargs,
|
| 320 |
+
):
|
| 321 |
+
super().__init__()
|
| 322 |
+
self.build(config)
|
| 323 |
+
self.apply(self._init_weights)
|
| 324 |
+
self.test_fixed_camera = False
|
| 325 |
+
self.skip_camera = False
|
| 326 |
+
|
| 327 |
+
def _init_weights(self, m):
|
| 328 |
+
if isinstance(m, nn.Linear):
|
| 329 |
+
trunc_normal_(m.weight, std=0.02)
|
| 330 |
+
if m.bias is not None:
|
| 331 |
+
nn.init.constant_(m.bias, 0)
|
| 332 |
+
elif isinstance(m, nn.Conv2d):
|
| 333 |
+
trunc_normal_(m.weight, std=0.02)
|
| 334 |
+
if m.bias is not None:
|
| 335 |
+
nn.init.constant_(m.bias, 0)
|
| 336 |
+
elif isinstance(m, nn.LayerNorm):
|
| 337 |
+
nn.init.constant_(m.bias, 0)
|
| 338 |
+
nn.init.constant_(m.weight, 1.0)
|
| 339 |
+
|
| 340 |
+
def get_adapted_features(self, features_flat, splits):
|
| 341 |
+
features_flat_cat = torch.cat(features_flat, dim=-1)
|
| 342 |
+
features_projected = self.input_adapter(
|
| 343 |
+
features_flat_cat, splits
|
| 344 |
+
) # list [b hw c] shapes
|
| 345 |
+
features = torch.chunk(features_projected, len(splits), dim=-1)
|
| 346 |
+
return features
|
| 347 |
+
|
| 348 |
+
def run_camera(self, cls_tokens, features, pos_embed, original_shapes, rays):
|
| 349 |
+
# get cls tokens projections
|
| 350 |
+
cls_tokens_splits = torch.tensor(
|
| 351 |
+
[x.shape[-1] for x in cls_tokens],
|
| 352 |
+
device=features.device,
|
| 353 |
+
requires_grad=False,
|
| 354 |
+
dtype=features.dtype,
|
| 355 |
+
)
|
| 356 |
+
cls_tokens = torch.cat(cls_tokens, dim=-1)
|
| 357 |
+
cls_tokens = self.token_adapter(cls_tokens, cls_tokens_splits)
|
| 358 |
+
cls_tokens = torch.cat(
|
| 359 |
+
torch.chunk(cls_tokens, len(cls_tokens_splits), dim=-1), dim=1
|
| 360 |
+
)
|
| 361 |
+
|
| 362 |
+
# camera layer
|
| 363 |
+
intrinsics = self.camera_layer(
|
| 364 |
+
features=features, cls_tokens=cls_tokens, pos_embed=pos_embed
|
| 365 |
+
)
|
| 366 |
+
intrinsics[:, 0, 0] = max(original_shapes) / 2 * intrinsics[:, 0, 0]
|
| 367 |
+
intrinsics[:, 1, 1] = max(original_shapes) / 2 * intrinsics[:, 1, 1]
|
| 368 |
+
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * original_shapes[1]
|
| 369 |
+
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * original_shapes[0]
|
| 370 |
+
if not self.test_fixed_camera:
|
| 371 |
+
rays, _ = generate_rays(intrinsics, original_shapes, noisy=False)
|
| 372 |
+
|
| 373 |
+
return intrinsics, rays
|
| 374 |
+
|
| 375 |
+
def forward(self, inputs, image_metas) -> torch.Tensor:
|
| 376 |
+
B, _, H, W = inputs["image"].shape
|
| 377 |
+
device = inputs["image"].device
|
| 378 |
+
|
| 379 |
+
# make stride happy?
|
| 380 |
+
original_encoder_outputs = [x.contiguous() for x in inputs["encoder_outputs"]]
|
| 381 |
+
cls_tokens = [x.contiguous() for x in inputs["cls_tokens"]]
|
| 382 |
+
|
| 383 |
+
# collect features and tokens
|
| 384 |
+
original_encoder_outputs = [
|
| 385 |
+
max_stack(original_encoder_outputs[i:j])
|
| 386 |
+
for i, j in self.slices_encoder_range
|
| 387 |
+
]
|
| 388 |
+
cls_tokens = [cls_tokens[-i - 1] for i in range(len(self.slices_encoder_range))]
|
| 389 |
+
|
| 390 |
+
# get features in b n d format
|
| 391 |
+
# level shapes, the shape per level, for swin like [[128, 128], [64, 64],...], for vit [[32,32]] -> mult times resolutions
|
| 392 |
+
resolutions = [
|
| 393 |
+
tuple(sorted([x.shape[1], x.shape[2]])) for x in original_encoder_outputs
|
| 394 |
+
]
|
| 395 |
+
level_shapes = sorted(list(set(resolutions)))[::-1]
|
| 396 |
+
|
| 397 |
+
if len(level_shapes) == 1:
|
| 398 |
+
level_shapes = level_shapes * self.num_resolutions
|
| 399 |
+
input_shapes = [
|
| 400 |
+
level_shapes[i]
|
| 401 |
+
for i, (start, end) in enumerate(self.slices_encoder)
|
| 402 |
+
for _ in range(end - start)
|
| 403 |
+
]
|
| 404 |
+
common_shape = level_shapes[-2]
|
| 405 |
+
|
| 406 |
+
# input shapes repeat shapes for each level, times the amount of the layers:
|
| 407 |
+
features_flat = [
|
| 408 |
+
flat_interpolate(
|
| 409 |
+
rearrange(x, "b h w c -> b (h w) c"), old=input_shape, new=common_shape
|
| 410 |
+
)
|
| 411 |
+
for x, input_shape in zip(original_encoder_outputs, input_shapes)
|
| 412 |
+
]
|
| 413 |
+
features_splits = torch.tensor(
|
| 414 |
+
[x.shape[-1] for x in features_flat],
|
| 415 |
+
device=device,
|
| 416 |
+
requires_grad=False,
|
| 417 |
+
dtype=torch.float32,
|
| 418 |
+
)
|
| 419 |
+
|
| 420 |
+
# input adapter, then do mean of features in same blocks
|
| 421 |
+
features = self.get_adapted_features(features_flat, features_splits)
|
| 422 |
+
features = torch.stack(features, dim=-1)
|
| 423 |
+
|
| 424 |
+
# positional embeddings, spatial and level
|
| 425 |
+
level_embed = torch.cat(
|
| 426 |
+
[
|
| 427 |
+
self.level_embed_layer(self.level_embeds)[i : i + 1]
|
| 428 |
+
.unsqueeze(0)
|
| 429 |
+
.repeat(B, common_shape[0] * common_shape[1], 1)
|
| 430 |
+
for i in range(self.num_resolutions)
|
| 431 |
+
],
|
| 432 |
+
dim=1,
|
| 433 |
+
)
|
| 434 |
+
pos_embed = self.pos_embed(
|
| 435 |
+
torch.zeros(
|
| 436 |
+
B,
|
| 437 |
+
1,
|
| 438 |
+
common_shape[0],
|
| 439 |
+
common_shape[1],
|
| 440 |
+
device=device,
|
| 441 |
+
requires_grad=False,
|
| 442 |
+
)
|
| 443 |
+
)
|
| 444 |
+
pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(
|
| 445 |
+
1, self.num_resolutions, 1
|
| 446 |
+
)
|
| 447 |
+
|
| 448 |
+
self.camera_layer.set_shapes(common_shape)
|
| 449 |
+
intrinsics, rays = (
|
| 450 |
+
self.run_camera(
|
| 451 |
+
cls_tokens,
|
| 452 |
+
features=features,
|
| 453 |
+
pos_embed=pos_embed + level_embed,
|
| 454 |
+
original_shapes=(H, W),
|
| 455 |
+
rays=inputs.get("rays", None),
|
| 456 |
+
)
|
| 457 |
+
if not self.skip_camera
|
| 458 |
+
else (inputs["K"], inputs["rays"])
|
| 459 |
+
)
|
| 460 |
+
|
| 461 |
+
# run bulk of the model
|
| 462 |
+
self.depth_layer.set_shapes(common_shape)
|
| 463 |
+
self.depth_layer.set_original_shapes((H, W))
|
| 464 |
+
out8, out4, out2, depth_features = self.depth_layer(
|
| 465 |
+
features=features,
|
| 466 |
+
rays_hr=rays,
|
| 467 |
+
pos_embed=pos_embed,
|
| 468 |
+
level_embed=level_embed,
|
| 469 |
+
)
|
| 470 |
+
|
| 471 |
+
return intrinsics, [out8, out4, out2], depth_features, rays
|
| 472 |
+
|
| 473 |
+
@torch.jit.ignore
|
| 474 |
+
def no_weight_decay_keywords(self):
|
| 475 |
+
return {"latents_pos", "level_embeds"}
|
| 476 |
+
|
| 477 |
+
def build(self, config):
|
| 478 |
+
depth = config["model"]["pixel_decoder"]["depths"]
|
| 479 |
+
input_dims = config["model"]["pixel_encoder"]["embed_dims"]
|
| 480 |
+
hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
|
| 481 |
+
num_heads = config["model"]["num_heads"]
|
| 482 |
+
expansion = config["model"]["expansion"]
|
| 483 |
+
dropout = config["model"]["pixel_decoder"]["dropout"]
|
| 484 |
+
depths_encoder = config["model"]["pixel_encoder"]["depths"]
|
| 485 |
+
num_steps = config["model"].get("num_steps", 100000)
|
| 486 |
+
layer_scale = 1.0
|
| 487 |
+
|
| 488 |
+
self.depth = depth
|
| 489 |
+
self.dim = hidden_dim
|
| 490 |
+
self.downsample = 4
|
| 491 |
+
self.num_heads = num_heads
|
| 492 |
+
self.num_resolutions = len(depths_encoder)
|
| 493 |
+
self.depths_encoder = depths_encoder
|
| 494 |
+
|
| 495 |
+
self.slices_encoder_single = list(
|
| 496 |
+
zip([d - 1 for d in self.depths_encoder], self.depths_encoder)
|
| 497 |
+
)
|
| 498 |
+
self.slices_encoder_range = list(
|
| 499 |
+
zip([0, *self.depths_encoder[:-1]], self.depths_encoder)
|
| 500 |
+
)
|
| 501 |
+
cls_token_input_dims = [input_dims[-i - 1] for i in range(len(depths_encoder))]
|
| 502 |
+
|
| 503 |
+
input_dims = [input_dims[d - 1] for d in depths_encoder]
|
| 504 |
+
self.slices_encoder = self.slices_encoder_single
|
| 505 |
+
|
| 506 |
+
# adapt from encoder features, just project
|
| 507 |
+
self.input_adapter = ListAdapter(input_dims, hidden_dim)
|
| 508 |
+
self.token_adapter = ListAdapter(cls_token_input_dims, hidden_dim)
|
| 509 |
+
|
| 510 |
+
# camera layer
|
| 511 |
+
self.camera_layer = CameraHead(
|
| 512 |
+
input_dim=hidden_dim,
|
| 513 |
+
hidden_dim=hidden_dim,
|
| 514 |
+
num_heads=num_heads,
|
| 515 |
+
expansion=expansion,
|
| 516 |
+
depth=2,
|
| 517 |
+
dropout=dropout,
|
| 518 |
+
layer_scale=layer_scale,
|
| 519 |
+
)
|
| 520 |
+
|
| 521 |
+
self.depth_layer = DepthHead(
|
| 522 |
+
hidden_dim=hidden_dim,
|
| 523 |
+
num_heads=num_heads,
|
| 524 |
+
expansion=expansion,
|
| 525 |
+
depths=depth,
|
| 526 |
+
dropout=dropout,
|
| 527 |
+
camera_dim=81,
|
| 528 |
+
num_resolutions=self.num_resolutions,
|
| 529 |
+
layer_scale=layer_scale,
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
# transformer part
|
| 533 |
+
self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
|
| 534 |
+
self.level_embeds = nn.Parameter(
|
| 535 |
+
torch.randn(len(input_dims), hidden_dim), requires_grad=True
|
| 536 |
+
)
|
| 537 |
+
self.level_embed_layer = nn.Sequential(
|
| 538 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 539 |
+
nn.GELU(),
|
| 540 |
+
nn.Linear(hidden_dim, hidden_dim),
|
| 541 |
+
nn.LayerNorm(hidden_dim),
|
| 542 |
+
)
|
flash3d/unidepth/models/unidepthv1/unidepthv1.py
ADDED
|
@@ -0,0 +1,329 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from copy import deepcopy
|
| 7 |
+
import importlib
|
| 8 |
+
from typing import Any, Dict, Tuple
|
| 9 |
+
from math import ceil
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
import torchvision.transforms.functional as TF
|
| 15 |
+
from einops import rearrange
|
| 16 |
+
|
| 17 |
+
from unidepth.utils.geometric import (
|
| 18 |
+
generate_rays,
|
| 19 |
+
spherical_zbuffer_to_euclidean,
|
| 20 |
+
)
|
| 21 |
+
from unidepth.utils.misc import get_params
|
| 22 |
+
from unidepth.utils.distributed import is_main_process
|
| 23 |
+
from unidepth.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
|
| 24 |
+
from unidepth.models.unidepthv1.decoder import Decoder
|
| 25 |
+
|
| 26 |
+
from huggingface_hub import PyTorchModelHubMixin
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
MAP_BACKBONES = {"ViTL14": "vitl14", "ConvNextL": "cnvnxtl"}
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
# inference helpers
|
| 33 |
+
def _paddings(image_shape, network_shape):
|
| 34 |
+
cur_h, cur_w = image_shape
|
| 35 |
+
h, w = network_shape
|
| 36 |
+
pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
|
| 37 |
+
pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
|
| 38 |
+
return pad_left, pad_right, pad_top, pad_bottom
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def _shapes(image_shape, network_shape):
|
| 42 |
+
h, w = image_shape
|
| 43 |
+
input_ratio = w / h
|
| 44 |
+
output_ratio = network_shape[1] / network_shape[0]
|
| 45 |
+
if output_ratio > input_ratio:
|
| 46 |
+
ratio = network_shape[0] / h
|
| 47 |
+
elif output_ratio <= input_ratio:
|
| 48 |
+
ratio = network_shape[1] / w
|
| 49 |
+
return (ceil(h * ratio - 0.5), ceil(w * ratio - 0.5)), ratio
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def _preprocess(rgbs, intrinsics, shapes, pads, ratio, output_shapes):
|
| 53 |
+
(pad_left, pad_right, pad_top, pad_bottom) = pads
|
| 54 |
+
rgbs = F.interpolate(
|
| 55 |
+
rgbs, size=shapes, mode="bilinear", align_corners=False, antialias=True
|
| 56 |
+
)
|
| 57 |
+
rgbs = F.pad(rgbs, (pad_left, pad_right, pad_top, pad_bottom), mode="constant")
|
| 58 |
+
if intrinsics is not None:
|
| 59 |
+
intrinsics = intrinsics.clone()
|
| 60 |
+
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] * ratio
|
| 61 |
+
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] * ratio
|
| 62 |
+
intrinsics[:, 0, 2] = intrinsics[:, 0, 2] * ratio + pad_left
|
| 63 |
+
intrinsics[:, 1, 2] = intrinsics[:, 1, 2] * ratio + pad_top
|
| 64 |
+
return rgbs, intrinsics
|
| 65 |
+
return rgbs, None
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def _postprocess(predictions, intrinsics, shapes, pads, ratio, original_shapes):
|
| 69 |
+
(pad_left, pad_right, pad_top, pad_bottom) = pads
|
| 70 |
+
# pred mean, trim paddings, and upsample to input dim
|
| 71 |
+
predictions = sum(
|
| 72 |
+
[
|
| 73 |
+
F.interpolate(
|
| 74 |
+
x.clone(),
|
| 75 |
+
size=shapes,
|
| 76 |
+
mode="bilinear",
|
| 77 |
+
align_corners=False,
|
| 78 |
+
antialias=True,
|
| 79 |
+
)
|
| 80 |
+
for x in predictions
|
| 81 |
+
]
|
| 82 |
+
) / len(predictions)
|
| 83 |
+
predictions = predictions[
|
| 84 |
+
..., pad_top : shapes[0] - pad_bottom, pad_left : shapes[1] - pad_right
|
| 85 |
+
]
|
| 86 |
+
predictions = F.interpolate(
|
| 87 |
+
predictions,
|
| 88 |
+
size=original_shapes,
|
| 89 |
+
mode="bilinear",
|
| 90 |
+
align_corners=False,
|
| 91 |
+
antialias=True,
|
| 92 |
+
)
|
| 93 |
+
intrinsics[:, 0, 0] = intrinsics[:, 0, 0] / ratio
|
| 94 |
+
intrinsics[:, 1, 1] = intrinsics[:, 1, 1] / ratio
|
| 95 |
+
intrinsics[:, 0, 2] = (intrinsics[:, 0, 2] - pad_left) / ratio
|
| 96 |
+
intrinsics[:, 1, 2] = (intrinsics[:, 1, 2] - pad_top) / ratio
|
| 97 |
+
return predictions, intrinsics
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class UniDepthV1(nn.Module,
|
| 101 |
+
PyTorchModelHubMixin,
|
| 102 |
+
library_name="UniDepth",
|
| 103 |
+
repo_url="https://github.com/lpiccinelli-eth/UniDepth",
|
| 104 |
+
tags=["monocular-metric-depth-estimation"]):
|
| 105 |
+
def __init__(
|
| 106 |
+
self,
|
| 107 |
+
config,
|
| 108 |
+
eps: float = 1e-6,
|
| 109 |
+
**kwargs,
|
| 110 |
+
):
|
| 111 |
+
super().__init__()
|
| 112 |
+
self.build(config)
|
| 113 |
+
self.eps = eps
|
| 114 |
+
|
| 115 |
+
def forward(self, inputs, image_metas):
|
| 116 |
+
rgbs = inputs["image"]
|
| 117 |
+
gt_intrinsics = inputs.get("K")
|
| 118 |
+
H, W = rgbs.shape[-2:]
|
| 119 |
+
|
| 120 |
+
# Encode
|
| 121 |
+
encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
|
| 122 |
+
if "dino" in self.pixel_encoder.__class__.__name__.lower():
|
| 123 |
+
encoder_outputs = [
|
| 124 |
+
(x + y.unsqueeze(1)).contiguous()
|
| 125 |
+
for x, y in zip(encoder_outputs, cls_tokens)
|
| 126 |
+
]
|
| 127 |
+
inputs["encoder_outputs"] = encoder_outputs
|
| 128 |
+
inputs["cls_tokens"] = cls_tokens
|
| 129 |
+
|
| 130 |
+
# Get camera infos, if any
|
| 131 |
+
if gt_intrinsics is not None:
|
| 132 |
+
rays, angles = generate_rays(
|
| 133 |
+
gt_intrinsics, self.image_shape, noisy=self.training
|
| 134 |
+
)
|
| 135 |
+
inputs["rays"] = rays
|
| 136 |
+
inputs["angles"] = angles
|
| 137 |
+
inputs["K"] = gt_intrinsics
|
| 138 |
+
self.pixel_decoder.test_fixed_camera = True # use GT camera in fwd
|
| 139 |
+
|
| 140 |
+
# Decode
|
| 141 |
+
pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {})
|
| 142 |
+
predictions = sum(
|
| 143 |
+
[
|
| 144 |
+
F.interpolate(
|
| 145 |
+
x.clone(),
|
| 146 |
+
size=self.image_shape,
|
| 147 |
+
mode="bilinear",
|
| 148 |
+
align_corners=False,
|
| 149 |
+
antialias=True,
|
| 150 |
+
)
|
| 151 |
+
for x in predictions
|
| 152 |
+
]
|
| 153 |
+
) / len(predictions)
|
| 154 |
+
|
| 155 |
+
# Final 3D points backprojection
|
| 156 |
+
pred_angles = generate_rays(pred_intrinsics, (H, W), noisy=False)[-1]
|
| 157 |
+
# You may want to use inputs["angles"] if available?
|
| 158 |
+
pred_angles = rearrange(pred_angles, "b (h w) c -> b c h w", h=H, w=W)
|
| 159 |
+
points_3d = torch.cat((pred_angles, predictions), dim=1)
|
| 160 |
+
points_3d = spherical_zbuffer_to_euclidean(
|
| 161 |
+
points_3d.permute(0, 2, 3, 1)
|
| 162 |
+
).permute(0, 3, 1, 2)
|
| 163 |
+
|
| 164 |
+
# Output data, use for loss computation
|
| 165 |
+
outputs = {
|
| 166 |
+
"angles": pred_angles,
|
| 167 |
+
"intrinsics": pred_intrinsics,
|
| 168 |
+
"points": points_3d,
|
| 169 |
+
"depth": predictions[:, -1:],
|
| 170 |
+
}
|
| 171 |
+
self.pixel_decoder.test_fixed_camera = False
|
| 172 |
+
return outputs
|
| 173 |
+
|
| 174 |
+
@torch.no_grad()
|
| 175 |
+
def infer(self, rgbs: torch.Tensor, intrinsics=None, skip_camera=False):
|
| 176 |
+
if rgbs.ndim == 3:
|
| 177 |
+
rgbs = rgbs.unsqueeze(0)
|
| 178 |
+
if intrinsics is not None and intrinsics.ndim == 2:
|
| 179 |
+
intrinsics = intrinsics.unsqueeze(0)
|
| 180 |
+
B, _, H, W = rgbs.shape
|
| 181 |
+
|
| 182 |
+
rgbs = rgbs.to(self.device)
|
| 183 |
+
if intrinsics is not None:
|
| 184 |
+
intrinsics = intrinsics.to(self.device)
|
| 185 |
+
|
| 186 |
+
# process image and intrinsiscs (if any) to match network input (slow?)
|
| 187 |
+
if rgbs.max() > 5 or rgbs.dtype == torch.uint8:
|
| 188 |
+
rgbs = TF.normalize(
|
| 189 |
+
rgbs.to(torch.float32).div(255),
|
| 190 |
+
mean=IMAGENET_DATASET_MEAN,
|
| 191 |
+
std=IMAGENET_DATASET_STD,
|
| 192 |
+
)
|
| 193 |
+
else:
|
| 194 |
+
pass
|
| 195 |
+
# print("Image not normalized, was it already normalized?")
|
| 196 |
+
(h, w), ratio = _shapes((H, W), self.image_shape)
|
| 197 |
+
pad_left, pad_right, pad_top, pad_bottom = _paddings((h, w), self.image_shape)
|
| 198 |
+
rgbs, gt_intrinsics = _preprocess(
|
| 199 |
+
rgbs,
|
| 200 |
+
intrinsics,
|
| 201 |
+
(h, w),
|
| 202 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 203 |
+
ratio,
|
| 204 |
+
self.image_shape,
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# run encoder
|
| 208 |
+
encoder_outputs, cls_tokens = self.pixel_encoder(rgbs)
|
| 209 |
+
if "dino" in self.pixel_encoder.__class__.__name__.lower():
|
| 210 |
+
encoder_outputs = [
|
| 211 |
+
(x + y.unsqueeze(1)).contiguous()
|
| 212 |
+
for x, y in zip(encoder_outputs, cls_tokens)
|
| 213 |
+
]
|
| 214 |
+
|
| 215 |
+
# get data for decoder and adapt to given camera
|
| 216 |
+
inputs = {}
|
| 217 |
+
inputs["encoder_outputs"] = encoder_outputs
|
| 218 |
+
inputs["cls_tokens"] = cls_tokens
|
| 219 |
+
inputs["image"] = rgbs
|
| 220 |
+
if gt_intrinsics is not None:
|
| 221 |
+
rays, angles = generate_rays(
|
| 222 |
+
gt_intrinsics, self.image_shape, noisy=self.training
|
| 223 |
+
)
|
| 224 |
+
inputs["rays"] = rays
|
| 225 |
+
inputs["angles"] = angles
|
| 226 |
+
inputs["K"] = gt_intrinsics
|
| 227 |
+
self.pixel_decoder.test_fixed_camera = True
|
| 228 |
+
self.pixel_decoder.skip_camera = skip_camera
|
| 229 |
+
|
| 230 |
+
# decode all
|
| 231 |
+
pred_intrinsics, predictions, _, _ = self.pixel_decoder(inputs, {})
|
| 232 |
+
|
| 233 |
+
# undo the reshaping and get original image size (slow)
|
| 234 |
+
predictions, pred_intrinsics = _postprocess(
|
| 235 |
+
predictions,
|
| 236 |
+
pred_intrinsics,
|
| 237 |
+
self.image_shape,
|
| 238 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 239 |
+
ratio,
|
| 240 |
+
(H, W),
|
| 241 |
+
)
|
| 242 |
+
|
| 243 |
+
# final 3D points backprojection
|
| 244 |
+
intrinsics = gt_intrinsics if gt_intrinsics is not None else pred_intrinsics
|
| 245 |
+
angles = generate_rays(intrinsics, (H, W), noisy=False)[-1]
|
| 246 |
+
angles = rearrange(angles, "b (h w) c -> b c h w", h=H, w=W)
|
| 247 |
+
points_3d = torch.cat((angles, predictions), dim=1)
|
| 248 |
+
points_3d = spherical_zbuffer_to_euclidean(
|
| 249 |
+
points_3d.permute(0, 2, 3, 1)
|
| 250 |
+
).permute(0, 3, 1, 2)
|
| 251 |
+
|
| 252 |
+
# output data
|
| 253 |
+
outputs = {
|
| 254 |
+
"intrinsics": pred_intrinsics,
|
| 255 |
+
"points": points_3d,
|
| 256 |
+
"depth": predictions[:, -1:],
|
| 257 |
+
}
|
| 258 |
+
self.pixel_decoder.test_fixed_camera = False
|
| 259 |
+
self.pixel_decoder.skip_camera = False
|
| 260 |
+
return outputs
|
| 261 |
+
|
| 262 |
+
def load_pretrained(self, model_file):
|
| 263 |
+
device = (
|
| 264 |
+
torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
| 265 |
+
)
|
| 266 |
+
dict_model = torch.load(model_file, map_location=device)
|
| 267 |
+
|
| 268 |
+
if "model" in dict_model:
|
| 269 |
+
dict_model = dict_model["model"]
|
| 270 |
+
new_state_dict = deepcopy(
|
| 271 |
+
{k.replace("module.", ""): v for k, v in dict_model.items()}
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
info = self.load_state_dict(new_state_dict, strict=False)
|
| 275 |
+
if is_main_process():
|
| 276 |
+
print(
|
| 277 |
+
f"Loaded from {model_file} for {self.__class__.__name__} results in:",
|
| 278 |
+
info,
|
| 279 |
+
)
|
| 280 |
+
|
| 281 |
+
def get_params(self, config):
|
| 282 |
+
if hasattr(self.pixel_encoder, "get_params"):
|
| 283 |
+
encoder_p, encoder_lr = self.pixel_encoder.get_params(
|
| 284 |
+
config["model"]["pixel_encoder"]["lr"],
|
| 285 |
+
config["training"]["wd"],
|
| 286 |
+
config["training"]["ld"],
|
| 287 |
+
)
|
| 288 |
+
else:
|
| 289 |
+
encoder_p, encoder_lr = get_params(
|
| 290 |
+
self.pixel_encoder,
|
| 291 |
+
config["model"]["pixel_encoder"]["lr"],
|
| 292 |
+
config["training"]["wd"],
|
| 293 |
+
)
|
| 294 |
+
decoder_p, decoder_lr = get_params(
|
| 295 |
+
self.pixel_decoder, config["training"]["lr"], config["training"]["wd"]
|
| 296 |
+
)
|
| 297 |
+
return [*encoder_p, *decoder_p], [*encoder_lr, *decoder_lr]
|
| 298 |
+
|
| 299 |
+
@property
|
| 300 |
+
def device(self):
|
| 301 |
+
return next(self.parameters()).device
|
| 302 |
+
|
| 303 |
+
def build(self, config: Dict[str, Dict[str, Any]]):
|
| 304 |
+
mod = importlib.import_module("unidepth.models.encoder")
|
| 305 |
+
pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
|
| 306 |
+
pixel_encoder_config = {
|
| 307 |
+
**config["training"],
|
| 308 |
+
**config["data"],
|
| 309 |
+
**config["model"]["pixel_encoder"],
|
| 310 |
+
}
|
| 311 |
+
pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
|
| 312 |
+
|
| 313 |
+
config["model"]["pixel_encoder"]["patch_size"] = (
|
| 314 |
+
14 if "dino" in config["model"]["pixel_encoder"]["name"] else 16
|
| 315 |
+
)
|
| 316 |
+
pixel_encoder_embed_dims = (
|
| 317 |
+
pixel_encoder.embed_dims
|
| 318 |
+
if hasattr(pixel_encoder, "embed_dims")
|
| 319 |
+
else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
|
| 320 |
+
)
|
| 321 |
+
config["model"]["pixel_encoder"]["embed_dim"] = getattr(
|
| 322 |
+
pixel_encoder, "embed_dim"
|
| 323 |
+
)
|
| 324 |
+
config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
|
| 325 |
+
config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
|
| 326 |
+
|
| 327 |
+
self.pixel_encoder = pixel_encoder
|
| 328 |
+
self.pixel_decoder = Decoder(config)
|
| 329 |
+
self.image_shape = config["data"]["image_shape"]
|
flash3d/unidepth/ops/__init__.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .losses import SILog, MSE, SelfCons
|
| 2 |
+
from .scheduler import CosineScheduler
|
| 3 |
+
|
| 4 |
+
__all__ = [
|
| 5 |
+
"SILog",
|
| 6 |
+
"MSE",
|
| 7 |
+
"SelfCons",
|
| 8 |
+
"CosineScheduler",
|
| 9 |
+
]
|
flash3d/unidepth/ops/losses.py
ADDED
|
@@ -0,0 +1,429 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Any, Optional, Dict, Tuple, List
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
FNS = {
|
| 14 |
+
"sqrt": torch.sqrt,
|
| 15 |
+
"log": torch.log,
|
| 16 |
+
"log1": lambda x: torch.log(x + 1),
|
| 17 |
+
"linear": lambda x: x,
|
| 18 |
+
"square": torch.square,
|
| 19 |
+
"disp": lambda x: 1 / x,
|
| 20 |
+
}
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
FNS_INV = {
|
| 24 |
+
"sqrt": torch.square,
|
| 25 |
+
"log": torch.exp,
|
| 26 |
+
"log1": lambda x: torch.exp(x) - 1,
|
| 27 |
+
"linear": lambda x: x,
|
| 28 |
+
"square": torch.sqrt,
|
| 29 |
+
"disp": lambda x: 1 / x,
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def masked_mean_var(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
|
| 34 |
+
if mask is None:
|
| 35 |
+
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
|
| 36 |
+
mask = mask.float()
|
| 37 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
| 38 |
+
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
|
| 39 |
+
mask_sum, min=1.0
|
| 40 |
+
)
|
| 41 |
+
mask_var = torch.sum(
|
| 42 |
+
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
|
| 43 |
+
) / torch.clamp(mask_sum, min=1.0)
|
| 44 |
+
return mask_mean.squeeze(dim), mask_var.squeeze(dim)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
|
| 48 |
+
if mask is None:
|
| 49 |
+
return data.mean(dim=dim, keepdim=True)
|
| 50 |
+
mask = mask.float()
|
| 51 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
| 52 |
+
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
|
| 53 |
+
mask_sum, min=1.0
|
| 54 |
+
)
|
| 55 |
+
return mask_mean
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def masked_mae(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]):
|
| 59 |
+
if mask is None:
|
| 60 |
+
return data.abs().mean(dim=dim, keepdim=True)
|
| 61 |
+
mask = mask.float()
|
| 62 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
| 63 |
+
mask_mean = torch.sum(data.abs() * mask, dim=dim, keepdim=True) / torch.clamp(
|
| 64 |
+
mask_sum, min=1.0
|
| 65 |
+
)
|
| 66 |
+
return mask_mean
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
def masked_mse(data: torch.Tensor, mask: torch.Tensor, dim: Tuple[int, ...]):
|
| 70 |
+
if mask is None:
|
| 71 |
+
return (data**2).mean(dim=dim, keepdim=True)
|
| 72 |
+
mask = mask.float()
|
| 73 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
| 74 |
+
mask_mean = torch.sum((data**2) * mask, dim=dim, keepdim=True) / torch.clamp(
|
| 75 |
+
mask_sum, min=1.0
|
| 76 |
+
)
|
| 77 |
+
return mask_mean
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
|
| 81 |
+
ndim = data.ndim
|
| 82 |
+
data = data.flatten(ndim - len(dim))
|
| 83 |
+
mask = mask.flatten(ndim - len(dim))
|
| 84 |
+
mask_median = torch.median(data[mask], dim=-1).values
|
| 85 |
+
return mask_median
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def masked_median_mad(data: torch.Tensor, mask: torch.Tensor):
|
| 89 |
+
data = data.flatten()
|
| 90 |
+
mask = mask.flatten()
|
| 91 |
+
mask_median = torch.median(data[mask])
|
| 92 |
+
n_samples = torch.clamp(torch.sum(mask.float()), min=1.0)
|
| 93 |
+
mask_mad = torch.sum((data[mask] - mask_median).abs()) / n_samples
|
| 94 |
+
return mask_median, mask_mad
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def masked_weighted_mean_var(
|
| 98 |
+
data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...]
|
| 99 |
+
):
|
| 100 |
+
if mask is None:
|
| 101 |
+
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
|
| 102 |
+
mask = mask.float()
|
| 103 |
+
mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum(
|
| 104 |
+
mask * weights, dim=dim, keepdim=True
|
| 105 |
+
).clamp(min=1.0)
|
| 106 |
+
# V1**2 - V2, V1: sum w_i, V2: sum w_i**2
|
| 107 |
+
denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum(
|
| 108 |
+
(mask * weights).square(), dim=dim, keepdim=True
|
| 109 |
+
)
|
| 110 |
+
# correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd)
|
| 111 |
+
correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp(
|
| 112 |
+
min=1.0
|
| 113 |
+
)
|
| 114 |
+
mask_var = correction_factor * torch.sum(
|
| 115 |
+
weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
|
| 116 |
+
)
|
| 117 |
+
return mask_mean, mask_var
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
def masked_mean_var_q(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
|
| 121 |
+
if mask is None:
|
| 122 |
+
return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
|
| 123 |
+
mask = mask.float()
|
| 124 |
+
mask_sum = torch.sum(mask, dim=dim, keepdim=True)
|
| 125 |
+
mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
|
| 126 |
+
mask_sum, min=1.0
|
| 127 |
+
)
|
| 128 |
+
mask_var = torch.sum(
|
| 129 |
+
mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
|
| 130 |
+
) / torch.clamp(mask_sum, min=1.0)
|
| 131 |
+
return mask_mean, mask_var
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class SILog(nn.Module):
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
weight: float,
|
| 138 |
+
scale_pred_weight: float = 0.15,
|
| 139 |
+
output_fn: str = "sqrt",
|
| 140 |
+
input_fn: str = "log",
|
| 141 |
+
legacy: bool = False,
|
| 142 |
+
abs_rel: bool = False,
|
| 143 |
+
norm: bool = False,
|
| 144 |
+
eps: float = 1e-5,
|
| 145 |
+
):
|
| 146 |
+
super().__init__()
|
| 147 |
+
assert output_fn in FNS
|
| 148 |
+
self.name: str = self.__class__.__name__
|
| 149 |
+
self.weight: float = weight
|
| 150 |
+
|
| 151 |
+
self.scale_pred_weight: float = scale_pred_weight
|
| 152 |
+
self.dims = (-4, -3, -2, -1) if legacy else (-2, -1)
|
| 153 |
+
self.output_fn = FNS[output_fn]
|
| 154 |
+
self.input_fn = FNS[input_fn]
|
| 155 |
+
self.abs_rel = abs_rel
|
| 156 |
+
self.norm = norm
|
| 157 |
+
self.eps: float = eps
|
| 158 |
+
|
| 159 |
+
@torch.cuda.amp.autocast(enabled=False)
|
| 160 |
+
def forward(
|
| 161 |
+
self,
|
| 162 |
+
input: torch.Tensor,
|
| 163 |
+
target: torch.Tensor,
|
| 164 |
+
mask: Optional[torch.Tensor] = None,
|
| 165 |
+
interpolate: bool = True,
|
| 166 |
+
scale_inv: torch.Tensor | None = None,
|
| 167 |
+
ss_inv: torch.Tensor | None = None,
|
| 168 |
+
**kwargs
|
| 169 |
+
) -> torch.Tensor:
|
| 170 |
+
if interpolate:
|
| 171 |
+
input = F.interpolate(
|
| 172 |
+
input, target.shape[-2:], mode="bilinear", align_corners=False
|
| 173 |
+
)
|
| 174 |
+
if mask is not None:
|
| 175 |
+
mask = mask.to(torch.bool)
|
| 176 |
+
if ss_inv is not None:
|
| 177 |
+
ss_inv = ~ss_inv
|
| 178 |
+
|
| 179 |
+
if input.shape[1] > 1:
|
| 180 |
+
input_ = torch.cat(
|
| 181 |
+
[input[:, :-1], self.input_fn(input[:, -1:].clamp(min=self.eps))], dim=1
|
| 182 |
+
)
|
| 183 |
+
target_ = torch.cat(
|
| 184 |
+
[target[:, :-1], self.input_fn(target[:, -1:].clamp(min=self.eps))],
|
| 185 |
+
dim=1,
|
| 186 |
+
)
|
| 187 |
+
error = torch.norm(input_ - target_, dim=1, keepdim=True)
|
| 188 |
+
else:
|
| 189 |
+
input_ = self.input_fn(input.clamp(min=self.eps))
|
| 190 |
+
target_ = self.input_fn(target.clamp(min=self.eps))
|
| 191 |
+
error = input_ - target_
|
| 192 |
+
|
| 193 |
+
mean_error, var_error = masked_mean_var(data=error, mask=mask, dim=self.dims)
|
| 194 |
+
|
| 195 |
+
# prevoiusly was inverted!!
|
| 196 |
+
if self.abs_rel:
|
| 197 |
+
scale_error = (input - target).abs()[:, -1:] / target[:, -1:].clip(
|
| 198 |
+
min=self.eps
|
| 199 |
+
)
|
| 200 |
+
scale_error = masked_mean(data=scale_error, mask=mask, dim=self.dims)
|
| 201 |
+
else:
|
| 202 |
+
scale_error = mean_error**2
|
| 203 |
+
|
| 204 |
+
if var_error.ndim > 1:
|
| 205 |
+
var_error = var_error.sum(dim=1)
|
| 206 |
+
scale_error = scale_error.sum(dim=1)
|
| 207 |
+
|
| 208 |
+
# if scale inv -> mask scale error, if scale/shift, mask the full loss
|
| 209 |
+
if scale_inv is not None:
|
| 210 |
+
scale_error = (1 - scale_inv.int()) * scale_error
|
| 211 |
+
scale_error = self.scale_pred_weight * scale_error
|
| 212 |
+
loss = var_error + scale_error
|
| 213 |
+
out_loss = self.output_fn(loss.clamp(min=self.eps))
|
| 214 |
+
out_loss = masked_mean(data=out_loss, mask=ss_inv, dim=(0,))
|
| 215 |
+
return out_loss.mean()
|
| 216 |
+
|
| 217 |
+
@classmethod
|
| 218 |
+
def build(cls, config: Dict[str, Any]):
|
| 219 |
+
obj = cls(
|
| 220 |
+
weight=config["weight"],
|
| 221 |
+
legacy=config["legacy"],
|
| 222 |
+
output_fn=config["output_fn"],
|
| 223 |
+
input_fn=config["input_fn"],
|
| 224 |
+
norm=config.get("norm", False),
|
| 225 |
+
scale_pred_weight=config.get("gamma", 0.15),
|
| 226 |
+
abs_rel=config.get("abs_rel", False),
|
| 227 |
+
)
|
| 228 |
+
return obj
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
class MSE(nn.Module):
|
| 232 |
+
def __init__(
|
| 233 |
+
self,
|
| 234 |
+
weight: float = 1.0,
|
| 235 |
+
input_fn: str = "linear",
|
| 236 |
+
output_fn: str = "linear",
|
| 237 |
+
):
|
| 238 |
+
super().__init__()
|
| 239 |
+
self.name: str = self.__class__.__name__
|
| 240 |
+
self.output_fn = FNS[output_fn]
|
| 241 |
+
self.input_fn = FNS[input_fn]
|
| 242 |
+
self.weight: float = weight
|
| 243 |
+
self.eps = 1e-6
|
| 244 |
+
|
| 245 |
+
@torch.cuda.amp.autocast(enabled=False)
|
| 246 |
+
def forward(
|
| 247 |
+
self,
|
| 248 |
+
input: torch.Tensor,
|
| 249 |
+
target: torch.Tensor,
|
| 250 |
+
mask: torch.Tensor | None = None,
|
| 251 |
+
batch_mask: torch.Tensor | None = None,
|
| 252 |
+
**kwargs
|
| 253 |
+
) -> torch.Tensor:
|
| 254 |
+
input = input[..., : target.shape[-1]] # B N C or B H W C
|
| 255 |
+
error = self.input_fn(input + self.eps) - self.input_fn(target + self.eps)
|
| 256 |
+
abs_error = torch.square(error).sum(dim=-1)
|
| 257 |
+
mean_error = masked_mean(data=abs_error, mask=mask, dim=(-1,)).mean(dim=-1)
|
| 258 |
+
batched_error = masked_mean(
|
| 259 |
+
self.output_fn(mean_error.clamp(self.eps)), batch_mask, dim=(0,)
|
| 260 |
+
)
|
| 261 |
+
return batched_error.mean(), mean_error.detach()
|
| 262 |
+
|
| 263 |
+
@classmethod
|
| 264 |
+
def build(cls, config: Dict[str, Any]):
|
| 265 |
+
obj = cls(
|
| 266 |
+
weight=config["weight"],
|
| 267 |
+
output_fn=config["output_fn"],
|
| 268 |
+
input_fn=config["input_fn"],
|
| 269 |
+
)
|
| 270 |
+
return obj
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class SelfCons(nn.Module):
|
| 274 |
+
def __init__(
|
| 275 |
+
self,
|
| 276 |
+
weight: float,
|
| 277 |
+
scale_pred_weight: float = 0.15,
|
| 278 |
+
output_fn: str = "sqrt",
|
| 279 |
+
input_fn: str = "log",
|
| 280 |
+
abs_rel: bool = False,
|
| 281 |
+
norm: bool = False,
|
| 282 |
+
eps: float = 1e-5,
|
| 283 |
+
):
|
| 284 |
+
super().__init__()
|
| 285 |
+
assert output_fn in FNS
|
| 286 |
+
self.name: str = self.__class__.__name__
|
| 287 |
+
self.weight: float = weight
|
| 288 |
+
|
| 289 |
+
self.scale_pred_weight: float = scale_pred_weight
|
| 290 |
+
self.dims = (-2, -1)
|
| 291 |
+
self.output_fn = FNS[output_fn]
|
| 292 |
+
self.input_fn = FNS[input_fn]
|
| 293 |
+
self.abs_rel = abs_rel
|
| 294 |
+
self.norm = norm
|
| 295 |
+
self.eps: float = eps
|
| 296 |
+
|
| 297 |
+
@torch.cuda.amp.autocast(enabled=False)
|
| 298 |
+
def forward(
|
| 299 |
+
self,
|
| 300 |
+
input: torch.Tensor,
|
| 301 |
+
mask: torch.Tensor,
|
| 302 |
+
metas: List[Dict[str, torch.Tensor]],
|
| 303 |
+
) -> torch.Tensor:
|
| 304 |
+
chunks = input.shape[0] // 2
|
| 305 |
+
device = input.device
|
| 306 |
+
mask = F.interpolate(mask.float(), size=input.shape[-2:], mode="nearest")
|
| 307 |
+
|
| 308 |
+
rescales = input.shape[-2] / torch.tensor(
|
| 309 |
+
[x["resized_shape"][0] for x in metas], device=device
|
| 310 |
+
)
|
| 311 |
+
cams = torch.cat([x["K_target"] for x in metas], dim=0).to(device)
|
| 312 |
+
flips = torch.tensor([x["flip"] for x in metas], device=device)
|
| 313 |
+
|
| 314 |
+
iters = zip(
|
| 315 |
+
input.chunk(chunks),
|
| 316 |
+
mask.chunk(chunks),
|
| 317 |
+
cams.chunk(chunks),
|
| 318 |
+
rescales.chunk(chunks),
|
| 319 |
+
flips.chunk(chunks),
|
| 320 |
+
)
|
| 321 |
+
inputs0, inputs1, masks = [], [], []
|
| 322 |
+
for i, (pair_input, pair_mask, pair_cam, pair_rescale, pair_flip) in enumerate(
|
| 323 |
+
iters
|
| 324 |
+
):
|
| 325 |
+
mask0, mask1 = pair_mask
|
| 326 |
+
input0, input1 = pair_input
|
| 327 |
+
cam0, cam1 = pair_cam
|
| 328 |
+
rescale0, rescale1 = pair_rescale
|
| 329 |
+
flip0, flip1 = pair_flip
|
| 330 |
+
|
| 331 |
+
fx_0 = cam0[0, 0] * rescale0
|
| 332 |
+
fx_1 = cam1[0, 0] * rescale1
|
| 333 |
+
cx_0 = (cam0[0, 2] - 0.5) * rescale0 + 0.5
|
| 334 |
+
cx_1 = (cam1[0, 2] - 0.5) * rescale1 + 0.5
|
| 335 |
+
cy_0 = (cam0[1, 2] - 0.5) * rescale0 + 0.5
|
| 336 |
+
cy_1 = (cam1[1, 2] - 0.5) * rescale1 + 0.5
|
| 337 |
+
|
| 338 |
+
# flip image
|
| 339 |
+
if flip0 ^ flip1:
|
| 340 |
+
input0 = torch.flip(input0, dims=(2,))
|
| 341 |
+
mask0 = torch.flip(mask0, dims=(2,))
|
| 342 |
+
cx_0 = input0.shape[-1] - cx_0
|
| 343 |
+
|
| 344 |
+
# calc zoom
|
| 345 |
+
zoom_x = float(fx_1 / fx_0)
|
| 346 |
+
|
| 347 |
+
# apply zoom
|
| 348 |
+
input0 = F.interpolate(
|
| 349 |
+
input0.unsqueeze(0),
|
| 350 |
+
scale_factor=zoom_x,
|
| 351 |
+
mode="bilinear",
|
| 352 |
+
align_corners=True,
|
| 353 |
+
).squeeze(0)
|
| 354 |
+
mask0 = F.interpolate(
|
| 355 |
+
mask0.unsqueeze(0), scale_factor=zoom_x, mode="nearest"
|
| 356 |
+
).squeeze(0)
|
| 357 |
+
|
| 358 |
+
# calc translation
|
| 359 |
+
change_left = int(cx_1 - (cx_0 - 0.5) * zoom_x - 0.5)
|
| 360 |
+
change_top = int(cy_1 - (cy_0 - 0.5) * zoom_x - 0.5)
|
| 361 |
+
change_right = input1.shape[-1] - change_left - input0.shape[-1]
|
| 362 |
+
change_bottom = input1.shape[-2] - change_top - input0.shape[-2]
|
| 363 |
+
|
| 364 |
+
# apply translation
|
| 365 |
+
pad_left = max(0, change_left)
|
| 366 |
+
pad_right = max(0, change_right)
|
| 367 |
+
pad_top = max(0, change_top)
|
| 368 |
+
pad_bottom = max(0, change_bottom)
|
| 369 |
+
|
| 370 |
+
crop_left = max(0, -change_left)
|
| 371 |
+
crop_right = max(0, -change_right)
|
| 372 |
+
crop_top = max(0, -change_top)
|
| 373 |
+
crop_bottom = max(0, -change_bottom)
|
| 374 |
+
|
| 375 |
+
input0 = F.pad(
|
| 376 |
+
input0,
|
| 377 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 378 |
+
mode="constant",
|
| 379 |
+
value=0,
|
| 380 |
+
)
|
| 381 |
+
mask0 = F.pad(
|
| 382 |
+
mask0,
|
| 383 |
+
(pad_left, pad_right, pad_top, pad_bottom),
|
| 384 |
+
mode="constant",
|
| 385 |
+
value=0,
|
| 386 |
+
)
|
| 387 |
+
input0 = input0[
|
| 388 |
+
:,
|
| 389 |
+
crop_top : input0.shape[-2] - crop_bottom,
|
| 390 |
+
crop_left : input0.shape[-1] - crop_right,
|
| 391 |
+
]
|
| 392 |
+
mask0 = mask0[
|
| 393 |
+
:,
|
| 394 |
+
crop_top : mask0.shape[-2] - crop_bottom,
|
| 395 |
+
crop_left : mask0.shape[-1] - crop_right,
|
| 396 |
+
]
|
| 397 |
+
|
| 398 |
+
mask = torch.logical_and(mask0, mask1)
|
| 399 |
+
|
| 400 |
+
inputs0.append(input0)
|
| 401 |
+
inputs1.append(input1)
|
| 402 |
+
masks.append(mask)
|
| 403 |
+
|
| 404 |
+
inputs0 = torch.stack(inputs0, dim=0)
|
| 405 |
+
inputs1 = torch.stack(inputs1, dim=0)
|
| 406 |
+
masks = torch.stack(masks, dim=0)
|
| 407 |
+
loss1 = self.loss(inputs0, inputs1.detach(), masks)
|
| 408 |
+
loss2 = self.loss(inputs1, inputs0.detach(), masks)
|
| 409 |
+
return torch.cat([loss1, loss2], dim=0).mean()
|
| 410 |
+
|
| 411 |
+
def loss(
|
| 412 |
+
self,
|
| 413 |
+
input: torch.Tensor,
|
| 414 |
+
target: torch.Tensor,
|
| 415 |
+
mask: torch.Tensor,
|
| 416 |
+
) -> torch.Tensor:
|
| 417 |
+
loss = masked_mean(
|
| 418 |
+
(input - target).square().mean(dim=1), mask=mask, dim=(-2, -1)
|
| 419 |
+
)
|
| 420 |
+
return self.output_fn(loss + self.eps)
|
| 421 |
+
|
| 422 |
+
@classmethod
|
| 423 |
+
def build(cls, config: Dict[str, Any]):
|
| 424 |
+
obj = cls(
|
| 425 |
+
weight=config["weight"],
|
| 426 |
+
output_fn=config["output_fn"],
|
| 427 |
+
input_fn=config["input_fn"],
|
| 428 |
+
)
|
| 429 |
+
return obj
|
flash3d/unidepth/ops/scheduler.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CosineScheduler(object):
|
| 10 |
+
def __init__(
|
| 11 |
+
self,
|
| 12 |
+
optimizer,
|
| 13 |
+
warmup_iters,
|
| 14 |
+
total_iters,
|
| 15 |
+
key,
|
| 16 |
+
overwrite=False,
|
| 17 |
+
init_value=None,
|
| 18 |
+
base_value=None,
|
| 19 |
+
final_value=None,
|
| 20 |
+
step_init=-1,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.iter = step_init
|
| 24 |
+
self.overwrite = overwrite
|
| 25 |
+
self.optimizer = optimizer
|
| 26 |
+
self.base_value = base_value
|
| 27 |
+
self.init_value = init_value
|
| 28 |
+
self.final_value = final_value
|
| 29 |
+
self.total_iters = total_iters
|
| 30 |
+
self.warmup_iters = warmup_iters
|
| 31 |
+
self.key = key
|
| 32 |
+
self.schedulers = [
|
| 33 |
+
self.get_schedulers(group) for group in optimizer.param_groups
|
| 34 |
+
]
|
| 35 |
+
|
| 36 |
+
def get_schedulers(self, group):
|
| 37 |
+
init_value = group.get(self.key + "_init", self.init_value)
|
| 38 |
+
base_value = group.get(self.key + "_base", self.base_value)
|
| 39 |
+
final_value = group.get(self.key + "_final", self.final_value)
|
| 40 |
+
warmup_iters = self.warmup_iters
|
| 41 |
+
total_iters = self.total_iters
|
| 42 |
+
if self.overwrite:
|
| 43 |
+
final_value = self.final_value
|
| 44 |
+
|
| 45 |
+
# normalize in 0,1, then apply function (power) and denormalize
|
| 46 |
+
normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
|
| 47 |
+
normalized_schedule = np.power(normalized_schedule, 2)
|
| 48 |
+
warmup_schedule = (base_value - init_value) * normalized_schedule + init_value
|
| 49 |
+
|
| 50 |
+
# main scheduling
|
| 51 |
+
iters = np.arange(total_iters - warmup_iters)
|
| 52 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (
|
| 53 |
+
1 + np.cos(np.pi * iters / len(iters))
|
| 54 |
+
)
|
| 55 |
+
return np.concatenate((warmup_schedule, schedule))
|
| 56 |
+
|
| 57 |
+
def step(self):
|
| 58 |
+
self.iter = self.iter + 1
|
| 59 |
+
vals = self[self.iter]
|
| 60 |
+
for group, val in zip(self.optimizer.param_groups, vals):
|
| 61 |
+
if isinstance(group[self.key], (tuple, list)):
|
| 62 |
+
val = (val, *group[self.key][1:])
|
| 63 |
+
group[self.key] = val
|
| 64 |
+
|
| 65 |
+
def __getitem__(self, it):
|
| 66 |
+
it = min(it, self.total_iters - 1)
|
| 67 |
+
return [scheduler[it] for scheduler in self.schedulers]
|
| 68 |
+
|
| 69 |
+
def get(self):
|
| 70 |
+
return [group[self.key] for group in self.optimizer.param_groups]
|
flash3d/unidepth/utils/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .evaluation_depth import eval_depth, DICT_METRICS
|
| 2 |
+
from .visualization import colorize, image_grid, log_train_artifacts
|
| 3 |
+
from .misc import format_seconds, remove_padding, get_params, identity
|
| 4 |
+
from .distributed import (
|
| 5 |
+
is_main_process,
|
| 6 |
+
setup_multi_processes,
|
| 7 |
+
setup_slurm,
|
| 8 |
+
sync_tensor_across_gpus,
|
| 9 |
+
barrier,
|
| 10 |
+
get_rank,
|
| 11 |
+
get_dist_info,
|
| 12 |
+
)
|
| 13 |
+
from .geometric import unproject_points, spherical_zbuffer_to_euclidean
|
| 14 |
+
|
| 15 |
+
__all__ = [
|
| 16 |
+
"eval_depth",
|
| 17 |
+
"DICT_METRICS",
|
| 18 |
+
"colorize",
|
| 19 |
+
"image_grid",
|
| 20 |
+
"log_train_artifacts",
|
| 21 |
+
"format_seconds",
|
| 22 |
+
"remove_padding",
|
| 23 |
+
"get_params",
|
| 24 |
+
"identity",
|
| 25 |
+
"is_main_process",
|
| 26 |
+
"setup_multi_processes",
|
| 27 |
+
"setup_slurm",
|
| 28 |
+
"sync_tensor_across_gpus",
|
| 29 |
+
"barrier",
|
| 30 |
+
"get_rank",
|
| 31 |
+
"unproject_points",
|
| 32 |
+
"spherical_zbuffer_to_euclidean",
|
| 33 |
+
"validate",
|
| 34 |
+
"get_dist_info",
|
| 35 |
+
]
|
flash3d/unidepth/utils/constants.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
|
| 10 |
+
OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
|
| 11 |
+
IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
|
| 12 |
+
IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
|
| 13 |
+
DEPTH_BINS = torch.cat(
|
| 14 |
+
(
|
| 15 |
+
torch.logspace(math.log10(0.1), math.log10(180.0), steps=512),
|
| 16 |
+
torch.tensor([260.0]),
|
| 17 |
+
),
|
| 18 |
+
dim=0,
|
| 19 |
+
)
|
| 20 |
+
LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1)
|
| 21 |
+
LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1)
|
flash3d/unidepth/utils/distributed.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
import platform
|
| 8 |
+
import warnings
|
| 9 |
+
import subprocess
|
| 10 |
+
|
| 11 |
+
import cv2
|
| 12 |
+
|
| 13 |
+
import torch
|
| 14 |
+
import torch.utils.data.distributed
|
| 15 |
+
from torch import multiprocessing as mp
|
| 16 |
+
from torch import distributed as dist
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def is_dist_avail_and_initialized():
|
| 20 |
+
if not dist.is_available():
|
| 21 |
+
return False
|
| 22 |
+
if not dist.is_initialized():
|
| 23 |
+
return False
|
| 24 |
+
return True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_rank():
|
| 28 |
+
if not is_dist_avail_and_initialized():
|
| 29 |
+
return 0
|
| 30 |
+
return dist.get_rank()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def barrier():
|
| 34 |
+
if not is_dist_avail_and_initialized():
|
| 35 |
+
return
|
| 36 |
+
dist.barrier()
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def is_main_process():
|
| 40 |
+
return get_rank() == 0
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def is_rank_zero(args):
|
| 44 |
+
return args.rank == 0
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def get_dist_info():
|
| 48 |
+
if dist.is_available() and dist.is_initialized():
|
| 49 |
+
rank = dist.get_rank()
|
| 50 |
+
world_size = dist.get_world_size()
|
| 51 |
+
else:
|
| 52 |
+
rank = 0
|
| 53 |
+
world_size = 1
|
| 54 |
+
return rank, world_size
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def setup_multi_processes(cfg):
|
| 58 |
+
"""Setup multi-processing environment variables."""
|
| 59 |
+
# set multi-process start method as `fork` to speed up the training
|
| 60 |
+
if platform.system() != "Windows":
|
| 61 |
+
mp_start_method = cfg.get("mp_start_method", "fork")
|
| 62 |
+
current_method = mp.get_start_method(allow_none=True)
|
| 63 |
+
if current_method is not None and current_method != mp_start_method:
|
| 64 |
+
warnings.warn(
|
| 65 |
+
f"Multi-processing start method `{mp_start_method}` is "
|
| 66 |
+
f"different from the previous setting `{current_method}`."
|
| 67 |
+
f"It will be force set to `{mp_start_method}`. You can change "
|
| 68 |
+
f"this behavior by changing `mp_start_method` in your config."
|
| 69 |
+
)
|
| 70 |
+
mp.set_start_method(mp_start_method, force=True)
|
| 71 |
+
|
| 72 |
+
# disable opencv multithreading to avoid system being overloaded
|
| 73 |
+
opencv_num_threads = cfg.get("opencv_num_threads", 0)
|
| 74 |
+
cv2.setNumThreads(opencv_num_threads)
|
| 75 |
+
|
| 76 |
+
# setup OMP threads
|
| 77 |
+
# This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
|
| 78 |
+
workers_per_gpu = cfg.get("workers_per_gpu", 4)
|
| 79 |
+
|
| 80 |
+
if "OMP_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
|
| 81 |
+
omp_num_threads = 1
|
| 82 |
+
warnings.warn(
|
| 83 |
+
f"Setting OMP_NUM_THREADS environment variable for each process "
|
| 84 |
+
f"to be {omp_num_threads} in default, to avoid your system being "
|
| 85 |
+
f"overloaded, please further tune the variable for optimal "
|
| 86 |
+
f"performance in your application as needed."
|
| 87 |
+
)
|
| 88 |
+
os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)
|
| 89 |
+
|
| 90 |
+
# setup MKL threads
|
| 91 |
+
if "MKL_NUM_THREADS" not in os.environ and workers_per_gpu > 1:
|
| 92 |
+
mkl_num_threads = os.environ.get("OMP_NUM_THREADS", 1)
|
| 93 |
+
warnings.warn(
|
| 94 |
+
f"Setting MKL_NUM_THREADS environment variable for each process "
|
| 95 |
+
f"to be {mkl_num_threads} in default, to avoid your system being "
|
| 96 |
+
f"overloaded, please further tune the variable for optimal "
|
| 97 |
+
f"performance in your application as needed."
|
| 98 |
+
)
|
| 99 |
+
os.environ["MKL_NUM_THREADS"] = str(mkl_num_threads)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def setup_slurm(backend: str, port: str) -> None:
|
| 103 |
+
"""Initialize slurm distributed training environment.
|
| 104 |
+
If argument ``port`` is not specified, then the master port will be system
|
| 105 |
+
environment variable ``MASTER_PORT``. If ``MASTER_PORT`` is not in system
|
| 106 |
+
environment variable, then a default port ``29500`` will be used.
|
| 107 |
+
Args:
|
| 108 |
+
backend (str): Backend of torch.distributed.
|
| 109 |
+
port (int, optional): Master port. Defaults to None.
|
| 110 |
+
"""
|
| 111 |
+
proc_id = int(os.environ["SLURM_PROCID"])
|
| 112 |
+
ntasks = int(os.environ["SLURM_NTASKS"])
|
| 113 |
+
node_list = os.environ["SLURM_NODELIST"]
|
| 114 |
+
|
| 115 |
+
num_gpus = torch.cuda.device_count()
|
| 116 |
+
|
| 117 |
+
torch.cuda.set_device(proc_id % num_gpus)
|
| 118 |
+
addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
|
| 119 |
+
os.environ["MASTER_PORT"] = str(port)
|
| 120 |
+
os.environ["MASTER_ADDR"] = addr
|
| 121 |
+
os.environ["WORLD_SIZE"] = str(ntasks)
|
| 122 |
+
os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
|
| 123 |
+
os.environ["RANK"] = str(proc_id)
|
| 124 |
+
print(
|
| 125 |
+
proc_id,
|
| 126 |
+
ntasks,
|
| 127 |
+
num_gpus,
|
| 128 |
+
proc_id % num_gpus,
|
| 129 |
+
node_list,
|
| 130 |
+
addr,
|
| 131 |
+
os.environ["MASTER_PORT"],
|
| 132 |
+
os.system("nvidia-smi -L"),
|
| 133 |
+
)
|
| 134 |
+
dist.init_process_group(backend, rank=proc_id, world_size=ntasks)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
def sync_tensor_across_gpus(t, dim=0, cat=True):
|
| 138 |
+
if t is None or not (dist.is_available() and dist.is_initialized()):
|
| 139 |
+
return t
|
| 140 |
+
t = torch.atleast_1d(t)
|
| 141 |
+
group = dist.group.WORLD
|
| 142 |
+
group_size = torch.distributed.get_world_size(group)
|
| 143 |
+
|
| 144 |
+
local_size = torch.tensor(t.size(dim), device=t.device)
|
| 145 |
+
all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)]
|
| 146 |
+
dist.all_gather(all_sizes, local_size)
|
| 147 |
+
max_size = max(all_sizes)
|
| 148 |
+
size_diff = max_size.item() - local_size.item()
|
| 149 |
+
if size_diff:
|
| 150 |
+
padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype)
|
| 151 |
+
t = torch.cat((t, padding))
|
| 152 |
+
|
| 153 |
+
gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)]
|
| 154 |
+
dist.all_gather(gather_t_tensor, t)
|
| 155 |
+
all_ts = []
|
| 156 |
+
for t, size in zip(gather_t_tensor, all_sizes):
|
| 157 |
+
all_ts.append(t[:size])
|
| 158 |
+
if cat:
|
| 159 |
+
return torch.cat(all_ts, dim=0)
|
| 160 |
+
return all_ts
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
import pickle
|
| 164 |
+
|
| 165 |
+
|
| 166 |
+
def sync_string_across_gpus(keys: list[str], device, dim=0):
|
| 167 |
+
keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL)
|
| 168 |
+
keys_serialized_tensor = torch.frombuffer(keys_serialized, dtype=torch.uint8).to(
|
| 169 |
+
device
|
| 170 |
+
)
|
| 171 |
+
keys_serialized_tensor = sync_tensor_across_gpus(
|
| 172 |
+
keys_serialized_tensor, dim=0, cat=False
|
| 173 |
+
)
|
| 174 |
+
keys = [
|
| 175 |
+
key
|
| 176 |
+
for keys in keys_serialized_tensor
|
| 177 |
+
for key in pickle.loads(bytes(keys.cpu().tolist()))
|
| 178 |
+
]
|
| 179 |
+
return keys
|
flash3d/unidepth/utils/ema_torch.py
ADDED
|
@@ -0,0 +1,342 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import division
|
| 7 |
+
from __future__ import unicode_literals
|
| 8 |
+
|
| 9 |
+
from typing import Iterable, Optional
|
| 10 |
+
import weakref
|
| 11 |
+
import copy
|
| 12 |
+
import contextlib
|
| 13 |
+
from math import tanh
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class DummyExponentialMovingAverage:
|
| 19 |
+
def __init__(self, *args, **kwargs):
|
| 20 |
+
pass
|
| 21 |
+
|
| 22 |
+
def _get_parameters(self, *args, **kwargs):
|
| 23 |
+
pass
|
| 24 |
+
|
| 25 |
+
def get_current_decay(self, *args, **kwargs):
|
| 26 |
+
pass
|
| 27 |
+
|
| 28 |
+
def update(self, *args, **kwargs):
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
def copy_to(self, *args, **kwargs):
|
| 32 |
+
pass
|
| 33 |
+
|
| 34 |
+
def store(self, *args, **kwargs):
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
def restore(self, *args, **kwargs):
|
| 38 |
+
return
|
| 39 |
+
|
| 40 |
+
@contextlib.contextmanager
|
| 41 |
+
def average_parameters(self, *args, **kwargs):
|
| 42 |
+
try:
|
| 43 |
+
yield
|
| 44 |
+
finally:
|
| 45 |
+
pass
|
| 46 |
+
|
| 47 |
+
def to(self, *args, **kwargs):
|
| 48 |
+
pass
|
| 49 |
+
|
| 50 |
+
def state_dict(self, *args, **kwargs):
|
| 51 |
+
pass
|
| 52 |
+
|
| 53 |
+
def load_state_dict(self, *args, **kwargs):
|
| 54 |
+
pass
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ExponentialMovingAverage:
|
| 58 |
+
"""
|
| 59 |
+
Maintains (exponential) moving average of a set of parameters.
|
| 60 |
+
|
| 61 |
+
Args:
|
| 62 |
+
parameters: Iterable of `torch.nn.Parameter` (typically from
|
| 63 |
+
`model.parameters()`).
|
| 64 |
+
Note that EMA is computed on *all* provided parameters,
|
| 65 |
+
regardless of whether or not they have `requires_grad = True`;
|
| 66 |
+
this allows a single EMA object to be consistantly used even
|
| 67 |
+
if which parameters are trainable changes step to step.
|
| 68 |
+
|
| 69 |
+
If you want to some parameters in the EMA, do not pass them
|
| 70 |
+
to the object in the first place. For example:
|
| 71 |
+
|
| 72 |
+
ExponentialMovingAverage(
|
| 73 |
+
parameters=[p for p in model.parameters() if p.requires_grad],
|
| 74 |
+
decay=0.9
|
| 75 |
+
)
|
| 76 |
+
|
| 77 |
+
will ignore parameters that do not require grad.
|
| 78 |
+
|
| 79 |
+
decay: The exponential decay.
|
| 80 |
+
|
| 81 |
+
use_num_updates: Whether to use number of updates when computing
|
| 82 |
+
averages.
|
| 83 |
+
"""
|
| 84 |
+
|
| 85 |
+
def __init__(
|
| 86 |
+
self,
|
| 87 |
+
parameters: Iterable[torch.nn.Parameter],
|
| 88 |
+
decay: float,
|
| 89 |
+
use_num_updates: bool = True,
|
| 90 |
+
update_after_step: int = 10000,
|
| 91 |
+
tau: int = 20000,
|
| 92 |
+
switch: bool = False,
|
| 93 |
+
):
|
| 94 |
+
if decay < 0.0 or decay > 1.0:
|
| 95 |
+
raise ValueError("Decay must be between 0 and 1")
|
| 96 |
+
self.decay = decay
|
| 97 |
+
self.switch = switch # fi keeping EMA params in model after epochs
|
| 98 |
+
self.num_updates = 0 if use_num_updates else None
|
| 99 |
+
parameters = list(parameters)
|
| 100 |
+
self.shadow_params = [p.clone().detach() for p in parameters]
|
| 101 |
+
self.collected_params = None
|
| 102 |
+
# By maintaining only a weakref to each parameter,
|
| 103 |
+
# we maintain the old GC behaviour of ExponentialMovingAverage:
|
| 104 |
+
# if the model goes out of scope but the ExponentialMovingAverage
|
| 105 |
+
# is kept, no references to the model or its parameters will be
|
| 106 |
+
# maintained, and the model will be cleaned up.
|
| 107 |
+
self._params_refs = [weakref.ref(p) for p in parameters]
|
| 108 |
+
self.update_after_step = update_after_step
|
| 109 |
+
self.tau = tau
|
| 110 |
+
|
| 111 |
+
def _get_parameters(
|
| 112 |
+
self, parameters: Optional[Iterable[torch.nn.Parameter]]
|
| 113 |
+
) -> Iterable[torch.nn.Parameter]:
|
| 114 |
+
if parameters is None:
|
| 115 |
+
parameters = [p() for p in self._params_refs]
|
| 116 |
+
if any(p is None for p in parameters):
|
| 117 |
+
raise ValueError(
|
| 118 |
+
"(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);"
|
| 119 |
+
" please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected."
|
| 120 |
+
)
|
| 121 |
+
return parameters
|
| 122 |
+
else:
|
| 123 |
+
parameters = list(parameters)
|
| 124 |
+
if len(parameters) != len(self.shadow_params):
|
| 125 |
+
raise ValueError(
|
| 126 |
+
"Number of parameters passed as argument is different "
|
| 127 |
+
"from number of shadow parameters maintained by this "
|
| 128 |
+
"ExponentialMovingAverage"
|
| 129 |
+
)
|
| 130 |
+
return parameters
|
| 131 |
+
|
| 132 |
+
def get_current_decay(self):
|
| 133 |
+
epoch = max(self.num_updates - self.update_after_step - 1, 0.0)
|
| 134 |
+
if epoch <= 0:
|
| 135 |
+
return 0.0
|
| 136 |
+
value = tanh(epoch / self.tau) * self.decay
|
| 137 |
+
return value
|
| 138 |
+
|
| 139 |
+
def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
|
| 140 |
+
"""
|
| 141 |
+
Update currently maintained parameters.
|
| 142 |
+
|
| 143 |
+
Call this every time the parameters are updated, such as the result of
|
| 144 |
+
the `optimizer.step()` call.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
parameters: Iterable of `torch.nn.Parameter`; usually the same set of
|
| 148 |
+
parameters used to initialize this object. If `None`, the
|
| 149 |
+
parameters with which this `ExponentialMovingAverage` was
|
| 150 |
+
initialized will be used.
|
| 151 |
+
"""
|
| 152 |
+
parameters = self._get_parameters(parameters)
|
| 153 |
+
decay = self.get_current_decay()
|
| 154 |
+
if self.num_updates is not None:
|
| 155 |
+
self.num_updates += 1
|
| 156 |
+
|
| 157 |
+
one_minus_decay = 1.0 - decay
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 160 |
+
tmp = s_param - param
|
| 161 |
+
# tmp will be a new tensor so we can do in-place
|
| 162 |
+
tmp.mul_(one_minus_decay)
|
| 163 |
+
s_param.sub_(tmp)
|
| 164 |
+
|
| 165 |
+
def copy_to(
|
| 166 |
+
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
|
| 167 |
+
) -> None:
|
| 168 |
+
"""
|
| 169 |
+
Copy current averaged parameters into given collection of parameters.
|
| 170 |
+
|
| 171 |
+
Args:
|
| 172 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 173 |
+
updated with the stored moving averages. If `None`, the
|
| 174 |
+
parameters with which this `ExponentialMovingAverage` was
|
| 175 |
+
initialized will be used.
|
| 176 |
+
"""
|
| 177 |
+
parameters = self._get_parameters(parameters)
|
| 178 |
+
for s_param, param in zip(self.shadow_params, parameters):
|
| 179 |
+
param.data.copy_(s_param.data)
|
| 180 |
+
|
| 181 |
+
def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
|
| 182 |
+
"""
|
| 183 |
+
Save the current parameters for restoring later.
|
| 184 |
+
|
| 185 |
+
Args:
|
| 186 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 187 |
+
temporarily stored. If `None`, the parameters of with which this
|
| 188 |
+
`ExponentialMovingAverage` was initialized will be used.
|
| 189 |
+
"""
|
| 190 |
+
parameters = self._get_parameters(parameters)
|
| 191 |
+
self.collected_params = [param.detach().clone() for param in parameters]
|
| 192 |
+
|
| 193 |
+
def restore(
|
| 194 |
+
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
|
| 195 |
+
) -> None:
|
| 196 |
+
"""
|
| 197 |
+
Restore the parameters stored with the `store` method.
|
| 198 |
+
Useful to validate the model with EMA parameters without affecting the
|
| 199 |
+
original optimization process. Store the parameters before the
|
| 200 |
+
`copy_to` method. After validation (or model saving), use this to
|
| 201 |
+
restore the former parameters.
|
| 202 |
+
|
| 203 |
+
Args:
|
| 204 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 205 |
+
updated with the stored parameters. If `None`, the
|
| 206 |
+
parameters with which this `ExponentialMovingAverage` was
|
| 207 |
+
initialized will be used.
|
| 208 |
+
"""
|
| 209 |
+
if self.collected_params is None:
|
| 210 |
+
raise RuntimeError(
|
| 211 |
+
"This ExponentialMovingAverage has no `store()`ed weights "
|
| 212 |
+
"to `restore()`"
|
| 213 |
+
)
|
| 214 |
+
parameters = self._get_parameters(parameters)
|
| 215 |
+
for c_param, param in zip(self.collected_params, parameters):
|
| 216 |
+
param.data.copy_(c_param.data)
|
| 217 |
+
|
| 218 |
+
@contextlib.contextmanager
|
| 219 |
+
def average_parameters(
|
| 220 |
+
self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
|
| 221 |
+
):
|
| 222 |
+
r"""
|
| 223 |
+
Context manager for validation/inference with averaged parameters.
|
| 224 |
+
|
| 225 |
+
Equivalent to:
|
| 226 |
+
|
| 227 |
+
ema.store()
|
| 228 |
+
ema.copy_to()
|
| 229 |
+
try:
|
| 230 |
+
...
|
| 231 |
+
finally:
|
| 232 |
+
ema.restore()
|
| 233 |
+
|
| 234 |
+
Args:
|
| 235 |
+
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
| 236 |
+
updated with the stored parameters. If `None`, the
|
| 237 |
+
parameters with which this `ExponentialMovingAverage` was
|
| 238 |
+
initialized will be used.
|
| 239 |
+
"""
|
| 240 |
+
parameters = self._get_parameters(parameters)
|
| 241 |
+
self.store(parameters)
|
| 242 |
+
self.copy_to(parameters)
|
| 243 |
+
try:
|
| 244 |
+
yield
|
| 245 |
+
finally:
|
| 246 |
+
if not self.switch:
|
| 247 |
+
self.restore(parameters)
|
| 248 |
+
|
| 249 |
+
def to(self, device=None, dtype=None) -> None:
|
| 250 |
+
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
| 251 |
+
|
| 252 |
+
Args:
|
| 253 |
+
device: like `device` argument to `torch.Tensor.to`
|
| 254 |
+
"""
|
| 255 |
+
# .to() on the tensors handles None correctly
|
| 256 |
+
self.shadow_params = [
|
| 257 |
+
(
|
| 258 |
+
p.to(device=device, dtype=dtype)
|
| 259 |
+
if p.is_floating_point()
|
| 260 |
+
else p.to(device=device)
|
| 261 |
+
)
|
| 262 |
+
for p in self.shadow_params
|
| 263 |
+
]
|
| 264 |
+
if self.collected_params is not None:
|
| 265 |
+
self.collected_params = [
|
| 266 |
+
(
|
| 267 |
+
p.to(device=device, dtype=dtype)
|
| 268 |
+
if p.is_floating_point()
|
| 269 |
+
else p.to(device=device)
|
| 270 |
+
)
|
| 271 |
+
for p in self.collected_params
|
| 272 |
+
]
|
| 273 |
+
return
|
| 274 |
+
|
| 275 |
+
def state_dict(self) -> dict:
|
| 276 |
+
r"""Returns the state of the ExponentialMovingAverage as a dict."""
|
| 277 |
+
# Following PyTorch conventions, references to tensors are returned:
|
| 278 |
+
# "returns a reference to the state and not its copy!" -
|
| 279 |
+
# https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
|
| 280 |
+
return {
|
| 281 |
+
"decay": self.decay,
|
| 282 |
+
"num_updates": self.num_updates,
|
| 283 |
+
"shadow_params": self.shadow_params,
|
| 284 |
+
"collected_params": self.collected_params,
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
def load_state_dict(self, state_dict: dict) -> None:
|
| 288 |
+
r"""Loads the ExponentialMovingAverage state.
|
| 289 |
+
|
| 290 |
+
Args:
|
| 291 |
+
state_dict (dict): EMA state. Should be an object returned
|
| 292 |
+
from a call to :meth:`state_dict`.
|
| 293 |
+
"""
|
| 294 |
+
# deepcopy, to be consistent with module API
|
| 295 |
+
state_dict = copy.deepcopy(state_dict)
|
| 296 |
+
self.decay = state_dict["decay"]
|
| 297 |
+
if self.decay < 0.0 or self.decay > 1.0:
|
| 298 |
+
raise ValueError("Decay must be between 0 and 1")
|
| 299 |
+
self.num_updates = state_dict["num_updates"]
|
| 300 |
+
assert self.num_updates is None or isinstance(
|
| 301 |
+
self.num_updates, int
|
| 302 |
+
), "Invalid num_updates"
|
| 303 |
+
|
| 304 |
+
self.shadow_params = state_dict["shadow_params"]
|
| 305 |
+
assert isinstance(self.shadow_params, list), "shadow_params must be a list"
|
| 306 |
+
assert all(
|
| 307 |
+
isinstance(p, torch.Tensor) for p in self.shadow_params
|
| 308 |
+
), "shadow_params must all be Tensors"
|
| 309 |
+
|
| 310 |
+
self.collected_params = state_dict["collected_params"]
|
| 311 |
+
if self.collected_params is not None:
|
| 312 |
+
assert isinstance(
|
| 313 |
+
self.collected_params, list
|
| 314 |
+
), "collected_params must be a list"
|
| 315 |
+
assert all(
|
| 316 |
+
isinstance(p, torch.Tensor) for p in self.collected_params
|
| 317 |
+
), "collected_params must all be Tensors"
|
| 318 |
+
assert len(self.collected_params) == len(
|
| 319 |
+
self.shadow_params
|
| 320 |
+
), "collected_params and shadow_params had different lengths"
|
| 321 |
+
|
| 322 |
+
if len(self.shadow_params) == len(self._params_refs):
|
| 323 |
+
# Consistant with torch.optim.Optimizer, cast things to consistant
|
| 324 |
+
# device and dtype with the parameters
|
| 325 |
+
params = [p() for p in self._params_refs]
|
| 326 |
+
# If parameters have been garbage collected, just load the state
|
| 327 |
+
# we were given without change.
|
| 328 |
+
if not any(p is None for p in params):
|
| 329 |
+
# ^ parameter references are still good
|
| 330 |
+
for i, p in enumerate(params):
|
| 331 |
+
self.shadow_params[i] = self.shadow_params[i].to(
|
| 332 |
+
device=p.device, dtype=p.dtype
|
| 333 |
+
)
|
| 334 |
+
if self.collected_params is not None:
|
| 335 |
+
self.collected_params[i] = self.collected_params[i].to(
|
| 336 |
+
device=p.device, dtype=p.dtype
|
| 337 |
+
)
|
| 338 |
+
else:
|
| 339 |
+
raise ValueError(
|
| 340 |
+
"Tried to `load_state_dict()` with the wrong number of "
|
| 341 |
+
"parameters in the saved state."
|
| 342 |
+
)
|
flash3d/unidepth/utils/evaluation_depth.py
ADDED
|
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
# We prefer not to install PyTorch3D in the package
|
| 6 |
+
# Code commented is how 3D metrics are computed
|
| 7 |
+
|
| 8 |
+
from collections import defaultdict
|
| 9 |
+
from functools import partial
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn.functional as F
|
| 13 |
+
|
| 14 |
+
# from chamfer_distance import ChamferDistance
|
| 15 |
+
|
| 16 |
+
from unidepth.utils.constants import DEPTH_BINS
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# chamfer_cls = ChamferDistance()
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
# def chamfer_dist(tensor1, tensor2):
|
| 23 |
+
# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
|
| 24 |
+
# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
|
| 25 |
+
# dist1, dist2, idx1, idx2 = chamfer_cls(
|
| 26 |
+
# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
|
| 27 |
+
# )
|
| 28 |
+
# return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
# def auc(tensor1, tensor2, thresholds):
|
| 32 |
+
# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
|
| 33 |
+
# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
|
| 34 |
+
# dist1, dist2, idx1, idx2 = chamfer_cls(
|
| 35 |
+
# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
|
| 36 |
+
# )
|
| 37 |
+
# # compute precision recall
|
| 38 |
+
# precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
|
| 39 |
+
# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
|
| 40 |
+
# auc_value = torch.trapz(
|
| 41 |
+
# torch.tensor(precisions, device=tensor1.device),
|
| 42 |
+
# torch.tensor(recalls, device=tensor1.device),
|
| 43 |
+
# )
|
| 44 |
+
# return auc_value
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def delta(tensor1, tensor2, exponent):
|
| 48 |
+
inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1))
|
| 49 |
+
return (inlier < 1.25**exponent).to(torch.float32).mean()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def ssi(tensor1, tensor2, qtl=0.05):
|
| 53 |
+
stability_mat = 1e-9 * torch.eye(2, device=tensor1.device)
|
| 54 |
+
error = (tensor1 - tensor2).abs()
|
| 55 |
+
mask = error < torch.quantile(error, 1 - qtl)
|
| 56 |
+
tensor1_mask = tensor1[mask]
|
| 57 |
+
tensor2_mask = tensor2[mask]
|
| 58 |
+
tensor2_one = torch.stack(
|
| 59 |
+
[tensor2_mask.detach(), torch.ones_like(tensor2_mask).detach()], dim=1
|
| 60 |
+
)
|
| 61 |
+
scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
|
| 62 |
+
tensor2_one.T @ tensor1_mask.unsqueeze(1)
|
| 63 |
+
)
|
| 64 |
+
scale, shift = scale_shift.squeeze().chunk(2, dim=0)
|
| 65 |
+
return tensor2 * scale + shift
|
| 66 |
+
# tensor2_one = torch.stack([tensor2.detach(), torch.ones_like(tensor2).detach()], dim=1)
|
| 67 |
+
# scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (tensor2_one.T @ tensor1.unsqueeze(1))
|
| 68 |
+
# scale, shift = scale_shift.squeeze().chunk(2, dim=0)
|
| 69 |
+
# return tensor2 * scale + shift
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def d1_ssi(tensor1, tensor2):
|
| 73 |
+
delta_ = delta(tensor1, ssi(tensor1, tensor2), 1.0)
|
| 74 |
+
return delta_
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def d_auc(tensor1, tensor2):
|
| 78 |
+
exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device)
|
| 79 |
+
deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents]
|
| 80 |
+
return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# def f1_score(tensor1, tensor2, thresholds):
|
| 84 |
+
# x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
|
| 85 |
+
# y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
|
| 86 |
+
# dist1, dist2, idx1, idx2 = chamfer_cls(
|
| 87 |
+
# tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
|
| 88 |
+
# )
|
| 89 |
+
# # compute precision recall
|
| 90 |
+
# precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
|
| 91 |
+
# recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
|
| 92 |
+
# precisions = torch.tensor(precisions, device=tensor1.device)
|
| 93 |
+
# recalls = torch.tensor(recalls, device=tensor1.device)
|
| 94 |
+
# f1_thresholds = 2 * precisions * recalls / (precisions + recalls)
|
| 95 |
+
# f1_thresholds = torch.where(
|
| 96 |
+
# torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds
|
| 97 |
+
# )
|
| 98 |
+
# f1_value = torch.trapz(f1_thresholds) / len(thresholds)
|
| 99 |
+
# return f1_value
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
DICT_METRICS = {
|
| 103 |
+
"d1": partial(delta, exponent=1.0),
|
| 104 |
+
"d2": partial(delta, exponent=2.0),
|
| 105 |
+
"d3": partial(delta, exponent=3.0),
|
| 106 |
+
"rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()),
|
| 107 |
+
"rmselog": lambda gt, pred: torch.sqrt(
|
| 108 |
+
((torch.log(gt) - torch.log(pred)) ** 2).mean()
|
| 109 |
+
),
|
| 110 |
+
"arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(),
|
| 111 |
+
"sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(),
|
| 112 |
+
"log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(),
|
| 113 |
+
"silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(),
|
| 114 |
+
"medianlog": lambda gt, pred: 100
|
| 115 |
+
* (torch.log(pred) - torch.log(gt)).median().abs(),
|
| 116 |
+
"d_auc": d_auc,
|
| 117 |
+
"d1_ssi": d1_ssi,
|
| 118 |
+
}
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
# DICT_METRICS_3D = {
|
| 122 |
+
# "chamfer": lambda gt, pred, thresholds: chamfer_dist(
|
| 123 |
+
# gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1)
|
| 124 |
+
# ),
|
| 125 |
+
# "F1": lambda gt, pred, thresholds: f1_score(
|
| 126 |
+
# gt.unsqueeze(0).permute(0, 2, 1),
|
| 127 |
+
# pred.unsqueeze(0).permute(0, 2, 1),
|
| 128 |
+
# thresholds=thresholds,
|
| 129 |
+
# ),
|
| 130 |
+
# }
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
DICT_METRICS_D = {
|
| 134 |
+
"a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to(
|
| 135 |
+
torch.float32
|
| 136 |
+
),
|
| 137 |
+
"abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt),
|
| 138 |
+
}
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def eval_depth(
|
| 142 |
+
gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None
|
| 143 |
+
):
|
| 144 |
+
summary_metrics = defaultdict(list)
|
| 145 |
+
preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear")
|
| 146 |
+
for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
|
| 147 |
+
if max_depth is not None:
|
| 148 |
+
mask = torch.logical_and(mask, gt <= max_depth)
|
| 149 |
+
for name, fn in DICT_METRICS.items():
|
| 150 |
+
summary_metrics[name].append(fn(gt[mask], pred[mask]).mean())
|
| 151 |
+
return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# def eval_3d(
|
| 155 |
+
# gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None
|
| 156 |
+
# ):
|
| 157 |
+
# summary_metrics = defaultdict(list)
|
| 158 |
+
# w_max = min(gts.shape[-1] // 4, 400)
|
| 159 |
+
# gts = F.interpolate(
|
| 160 |
+
# gts, (int(w_max * gts.shape[-2] / gts.shape[-1]), w_max), mode="nearest"
|
| 161 |
+
# )
|
| 162 |
+
# preds = F.interpolate(preds, gts.shape[-2:], mode="nearest")
|
| 163 |
+
# masks = F.interpolate(
|
| 164 |
+
# masks.to(torch.float32), gts.shape[-2:], mode="nearest"
|
| 165 |
+
# ).bool()
|
| 166 |
+
# for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
|
| 167 |
+
# if not torch.any(mask):
|
| 168 |
+
# continue
|
| 169 |
+
# for name, fn in DICT_METRICS_3D.items():
|
| 170 |
+
# summary_metrics[name].append(
|
| 171 |
+
# fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean()
|
| 172 |
+
# )
|
| 173 |
+
# return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
|
flash3d/unidepth/utils/geometric.py
ADDED
|
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from typing import Tuple
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def generate_rays(
|
| 13 |
+
camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False
|
| 14 |
+
):
|
| 15 |
+
batch_size, device, dtype = (
|
| 16 |
+
camera_intrinsics.shape[0],
|
| 17 |
+
camera_intrinsics.device,
|
| 18 |
+
camera_intrinsics.dtype,
|
| 19 |
+
)
|
| 20 |
+
height, width = image_shape
|
| 21 |
+
# Generate grid of pixel coordinates
|
| 22 |
+
pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
|
| 23 |
+
pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
|
| 24 |
+
if noisy:
|
| 25 |
+
pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
|
| 26 |
+
pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
|
| 27 |
+
pixel_coords = torch.stack(
|
| 28 |
+
[pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2
|
| 29 |
+
) # (H, W, 2)
|
| 30 |
+
pixel_coords = pixel_coords + 0.5
|
| 31 |
+
|
| 32 |
+
# Calculate ray directions
|
| 33 |
+
intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype) # (B, 3, 3)
|
| 34 |
+
homogeneous_coords = torch.cat(
|
| 35 |
+
[pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2
|
| 36 |
+
) # (H, W, 3)
|
| 37 |
+
ray_directions = torch.matmul(
|
| 38 |
+
intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
|
| 39 |
+
) # (3, H*W)
|
| 40 |
+
ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W)
|
| 41 |
+
ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3)
|
| 42 |
+
|
| 43 |
+
theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1])
|
| 44 |
+
phi = torch.acos(ray_directions[..., 1])
|
| 45 |
+
# pitch = torch.asin(ray_directions[..., 1])
|
| 46 |
+
# roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1])
|
| 47 |
+
angles = torch.stack([theta, phi], dim=-1)
|
| 48 |
+
return ray_directions, angles
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
@torch.jit.script
|
| 52 |
+
def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
|
| 53 |
+
theta = spherical_tensor[..., 0] # Extract polar angle
|
| 54 |
+
phi = spherical_tensor[..., 1] # Extract azimuthal angle
|
| 55 |
+
z = spherical_tensor[..., 2] # Extract zbuffer depth
|
| 56 |
+
|
| 57 |
+
# y = r * cos(phi)
|
| 58 |
+
# x = r * sin(phi) * sin(theta)
|
| 59 |
+
# z = r * sin(phi) * cos(theta)
|
| 60 |
+
# =>
|
| 61 |
+
# r = z / sin(phi) / cos(theta)
|
| 62 |
+
# y = z / (sin(phi) / cos(phi)) / cos(theta)
|
| 63 |
+
# x = z * sin(theta) / cos(theta)
|
| 64 |
+
x = z * torch.tan(theta)
|
| 65 |
+
y = z / torch.tan(phi) / torch.cos(theta)
|
| 66 |
+
|
| 67 |
+
euclidean_tensor = torch.stack((x, y, z), dim=-1)
|
| 68 |
+
return euclidean_tensor
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@torch.jit.script
|
| 72 |
+
def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
|
| 73 |
+
theta = spherical_tensor[..., 0] # Extract polar angle
|
| 74 |
+
phi = spherical_tensor[..., 1] # Extract azimuthal angle
|
| 75 |
+
r = spherical_tensor[..., 2] # Extract radius
|
| 76 |
+
# y = r * cos(phi)
|
| 77 |
+
# x = r * sin(phi) * sin(theta)
|
| 78 |
+
# z = r * sin(phi) * cos(theta)
|
| 79 |
+
x = r * torch.sin(phi) * torch.sin(theta)
|
| 80 |
+
y = r * torch.cos(phi)
|
| 81 |
+
z = r * torch.cos(theta) * torch.sin(phi)
|
| 82 |
+
|
| 83 |
+
euclidean_tensor = torch.stack((x, y, z), dim=-1)
|
| 84 |
+
return euclidean_tensor
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
@torch.jit.script
|
| 88 |
+
def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
|
| 89 |
+
x = spherical_tensor[..., 0] # Extract polar angle
|
| 90 |
+
y = spherical_tensor[..., 1] # Extract azimuthal angle
|
| 91 |
+
z = spherical_tensor[..., 2] # Extract radius
|
| 92 |
+
# y = r * cos(phi)
|
| 93 |
+
# x = r * sin(phi) * sin(theta)
|
| 94 |
+
# z = r * sin(phi) * cos(theta)
|
| 95 |
+
r = torch.sqrt(x**2 + y**2 + z**2)
|
| 96 |
+
theta = torch.atan2(x / r, z / r)
|
| 97 |
+
phi = torch.acos(y / r)
|
| 98 |
+
|
| 99 |
+
euclidean_tensor = torch.stack((theta, phi, r), dim=-1)
|
| 100 |
+
return euclidean_tensor
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@torch.jit.script
|
| 104 |
+
def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
|
| 105 |
+
pitch = torch.asin(euclidean_tensor[..., 1])
|
| 106 |
+
yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
|
| 107 |
+
z = euclidean_tensor[..., 2] # Extract zbuffer depth
|
| 108 |
+
euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1)
|
| 109 |
+
return euclidean_tensor
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
@torch.jit.script
|
| 113 |
+
def unproject_points(
|
| 114 |
+
depth: torch.Tensor, camera_intrinsics: torch.Tensor
|
| 115 |
+
) -> torch.Tensor:
|
| 116 |
+
"""
|
| 117 |
+
Unprojects a batch of depth maps to 3D point clouds using camera intrinsics.
|
| 118 |
+
|
| 119 |
+
Args:
|
| 120 |
+
depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W).
|
| 121 |
+
camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3).
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W).
|
| 125 |
+
"""
|
| 126 |
+
batch_size, _, height, width = depth.shape
|
| 127 |
+
device = depth.device
|
| 128 |
+
|
| 129 |
+
# Create pixel grid
|
| 130 |
+
y_coords, x_coords = torch.meshgrid(
|
| 131 |
+
torch.arange(height, device=device),
|
| 132 |
+
torch.arange(width, device=device),
|
| 133 |
+
indexing="ij",
|
| 134 |
+
)
|
| 135 |
+
pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2)
|
| 136 |
+
|
| 137 |
+
# Get homogeneous coords (u v 1)
|
| 138 |
+
pixel_coords_homogeneous = torch.cat(
|
| 139 |
+
(pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1
|
| 140 |
+
)
|
| 141 |
+
pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten(
|
| 142 |
+
1
|
| 143 |
+
) # (3, H*W)
|
| 144 |
+
# Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W]
|
| 145 |
+
unprojected_points = torch.matmul(
|
| 146 |
+
torch.inverse(camera_intrinsics), pixel_coords_homogeneous
|
| 147 |
+
) # (B, 3, H*W)
|
| 148 |
+
unprojected_points = unprojected_points.view(
|
| 149 |
+
batch_size, 3, height, width
|
| 150 |
+
) # (B, 3, H, W)
|
| 151 |
+
unprojected_points = unprojected_points * depth # (B, 3, H, W)
|
| 152 |
+
return unprojected_points
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
@torch.jit.script
|
| 156 |
+
def project_points(
|
| 157 |
+
points_3d: torch.Tensor,
|
| 158 |
+
intrinsic_matrix: torch.Tensor,
|
| 159 |
+
image_shape: Tuple[int, int],
|
| 160 |
+
) -> torch.Tensor:
|
| 161 |
+
# Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T
|
| 162 |
+
points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2))
|
| 163 |
+
|
| 164 |
+
# Normalize projected points: (u v w) -> (u / w, v / w, 1)
|
| 165 |
+
points_2d = points_2d[..., :2] / points_2d[..., 2:]
|
| 166 |
+
|
| 167 |
+
# To pixels (rounding!!!), no int as it breaks gradient
|
| 168 |
+
points_2d = points_2d.round()
|
| 169 |
+
|
| 170 |
+
# pointa need to be inside the image (can it diverge onto all points out???)
|
| 171 |
+
valid_mask = (
|
| 172 |
+
(points_2d[..., 0] >= 0)
|
| 173 |
+
& (points_2d[..., 0] < image_shape[1])
|
| 174 |
+
& (points_2d[..., 1] >= 0)
|
| 175 |
+
& (points_2d[..., 1] < image_shape[0])
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
# Calculate the flat indices of the valid pixels
|
| 179 |
+
flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1]
|
| 180 |
+
flat_indices = flat_points_2d.long()
|
| 181 |
+
|
| 182 |
+
# Create depth maps and counts using scatter_add, (B, H, W)
|
| 183 |
+
depth_maps = torch.zeros(
|
| 184 |
+
[points_3d.shape[0], *image_shape], device=points_3d.device
|
| 185 |
+
)
|
| 186 |
+
counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device)
|
| 187 |
+
|
| 188 |
+
# Loop over batches to apply masks and accumulate depth/count values
|
| 189 |
+
for i in range(points_3d.shape[0]):
|
| 190 |
+
valid_indices = flat_indices[i, valid_mask[i]]
|
| 191 |
+
depth_maps[i].view(-1).scatter_add_(
|
| 192 |
+
0, valid_indices, points_3d[i, valid_mask[i], 2]
|
| 193 |
+
)
|
| 194 |
+
counts[i].view(-1).scatter_add_(
|
| 195 |
+
0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2])
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
# Calculate mean depth for each pixel in each batch
|
| 199 |
+
mean_depth_maps = depth_maps / counts.clamp(min=1.0)
|
| 200 |
+
return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W)
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
@torch.jit.script
|
| 204 |
+
def downsample(data: torch.Tensor, downsample_factor: int = 2):
|
| 205 |
+
N, _, H, W = data.shape
|
| 206 |
+
data = data.view(
|
| 207 |
+
N,
|
| 208 |
+
H // downsample_factor,
|
| 209 |
+
downsample_factor,
|
| 210 |
+
W // downsample_factor,
|
| 211 |
+
downsample_factor,
|
| 212 |
+
1,
|
| 213 |
+
)
|
| 214 |
+
data = data.permute(0, 1, 3, 5, 2, 4).contiguous()
|
| 215 |
+
data = data.view(-1, downsample_factor * downsample_factor)
|
| 216 |
+
data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data)
|
| 217 |
+
data = torch.min(data_tmp, dim=-1).values
|
| 218 |
+
data = data.view(N, 1, H // downsample_factor, W // downsample_factor)
|
| 219 |
+
data = torch.where(data > 1000, torch.zeros_like(data), data)
|
| 220 |
+
return data
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
@torch.jit.script
|
| 224 |
+
def flat_interpolate(
|
| 225 |
+
flat_tensor: torch.Tensor,
|
| 226 |
+
old: Tuple[int, int],
|
| 227 |
+
new: Tuple[int, int],
|
| 228 |
+
antialias: bool = True,
|
| 229 |
+
mode: str = "bilinear",
|
| 230 |
+
) -> torch.Tensor:
|
| 231 |
+
if old[0] == new[0] and old[1] == new[1]:
|
| 232 |
+
return flat_tensor
|
| 233 |
+
tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
|
| 234 |
+
0, 3, 1, 2
|
| 235 |
+
) # b c h w
|
| 236 |
+
tensor_interp = F.interpolate(
|
| 237 |
+
tensor,
|
| 238 |
+
size=(new[0], new[1]),
|
| 239 |
+
mode=mode,
|
| 240 |
+
align_corners=False,
|
| 241 |
+
antialias=antialias,
|
| 242 |
+
)
|
| 243 |
+
flat_tensor_interp = tensor_interp.view(
|
| 244 |
+
flat_tensor.shape[0], -1, new[0] * new[1]
|
| 245 |
+
).permute(
|
| 246 |
+
0, 2, 1
|
| 247 |
+
) # b (h w) c
|
| 248 |
+
return flat_tensor_interp.contiguous()
|
flash3d/unidepth/utils/misc.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from functools import wraps
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from scipy import interpolate
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
|
| 15 |
+
from einops import rearrange, repeat, reduce
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def max_stack(tensors):
|
| 19 |
+
return torch.stack(tensors, dim=-1).max(dim=-1)[0]
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def softmax_stack(tensors, temperature=1.0):
|
| 23 |
+
return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def mean_stack(tensors):
|
| 27 |
+
if len(tensors) == 1:
|
| 28 |
+
return tensors[0]
|
| 29 |
+
return torch.stack(tensors, dim=-1).mean(dim=-1)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def sum_stack(tensors):
|
| 33 |
+
return torch.stack(tensors, dim=-1).sum(dim=-1)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def convert_module_to_f16(l):
|
| 37 |
+
"""
|
| 38 |
+
Convert primitive modules to float16.
|
| 39 |
+
"""
|
| 40 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
| 41 |
+
l.weight.data = l.weight.data.half()
|
| 42 |
+
if l.bias is not None:
|
| 43 |
+
l.bias.data = l.bias.data.half()
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def convert_module_to_f32(l):
|
| 47 |
+
"""
|
| 48 |
+
Convert primitive modules to float32, undoing convert_module_to_f16().
|
| 49 |
+
"""
|
| 50 |
+
if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
|
| 51 |
+
l.weight.data = l.weight.data.float()
|
| 52 |
+
if l.bias is not None:
|
| 53 |
+
l.bias.data = l.bias.data.float()
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def format_seconds(seconds):
|
| 57 |
+
minutes, seconds = divmod(seconds, 60)
|
| 58 |
+
hours, minutes = divmod(minutes, 60)
|
| 59 |
+
return f"{hours:d}:{minutes:02d}:{seconds:02d}"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def get_params(module, lr, wd):
|
| 63 |
+
skip_list = {}
|
| 64 |
+
skip_keywords = {}
|
| 65 |
+
if hasattr(module, "no_weight_decay"):
|
| 66 |
+
skip_list = module.no_weight_decay()
|
| 67 |
+
if hasattr(module, "no_weight_decay_keywords"):
|
| 68 |
+
skip_keywords = module.no_weight_decay_keywords()
|
| 69 |
+
has_decay = []
|
| 70 |
+
no_decay = []
|
| 71 |
+
for name, param in module.named_parameters():
|
| 72 |
+
if not param.requires_grad:
|
| 73 |
+
continue # frozen weights
|
| 74 |
+
if (
|
| 75 |
+
(name in skip_list)
|
| 76 |
+
or any((kw in name for kw in skip_keywords))
|
| 77 |
+
or len(param.shape) == 1
|
| 78 |
+
):
|
| 79 |
+
# if (name in skip_list) or any((kw in name for kw in skip_keywords)):
|
| 80 |
+
# print(name, skip_keywords)
|
| 81 |
+
no_decay.append(param)
|
| 82 |
+
else:
|
| 83 |
+
has_decay.append(param)
|
| 84 |
+
|
| 85 |
+
group1 = {
|
| 86 |
+
"params": has_decay,
|
| 87 |
+
"weight_decay": wd,
|
| 88 |
+
"lr": lr,
|
| 89 |
+
"weight_decay_init": wd,
|
| 90 |
+
"weight_decay_base": wd,
|
| 91 |
+
"lr_init": lr,
|
| 92 |
+
"lr_base": lr,
|
| 93 |
+
}
|
| 94 |
+
group2 = {
|
| 95 |
+
"params": no_decay,
|
| 96 |
+
"weight_decay": 0.0,
|
| 97 |
+
"lr": lr,
|
| 98 |
+
"weight_decay_init": 0.0,
|
| 99 |
+
"weight_decay_base": 0.0,
|
| 100 |
+
"weight_decay_final": 0.0,
|
| 101 |
+
"lr_init": lr,
|
| 102 |
+
"lr_base": lr,
|
| 103 |
+
}
|
| 104 |
+
return [group1, group2], [lr, lr]
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
|
| 108 |
+
if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
|
| 109 |
+
return 0
|
| 110 |
+
elif var_name.startswith("patch_embed"):
|
| 111 |
+
return 0
|
| 112 |
+
elif var_name.startswith("layers"):
|
| 113 |
+
if var_name.split(".")[2] == "blocks":
|
| 114 |
+
stage_id = int(var_name.split(".")[1])
|
| 115 |
+
layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id])
|
| 116 |
+
return layer_id + 1
|
| 117 |
+
elif var_name.split(".")[2] == "downsample":
|
| 118 |
+
stage_id = int(var_name.split(".")[1])
|
| 119 |
+
layer_id = sum(layers_per_stage[: stage_id + 1])
|
| 120 |
+
return layer_id
|
| 121 |
+
else:
|
| 122 |
+
return num_max_layer - 1
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def get_params_layerdecayswin(module, lr, wd, ld):
|
| 126 |
+
skip_list = {}
|
| 127 |
+
skip_keywords = {}
|
| 128 |
+
if hasattr(module, "no_weight_decay"):
|
| 129 |
+
skip_list = module.no_weight_decay()
|
| 130 |
+
if hasattr(module, "no_weight_decay_keywords"):
|
| 131 |
+
skip_keywords = module.no_weight_decay_keywords()
|
| 132 |
+
layers_per_stage = module.depths
|
| 133 |
+
num_layers = sum(layers_per_stage) + 1
|
| 134 |
+
lrs = []
|
| 135 |
+
params = []
|
| 136 |
+
for name, param in module.named_parameters():
|
| 137 |
+
if not param.requires_grad:
|
| 138 |
+
print(f"{name} frozen")
|
| 139 |
+
continue # frozen weights
|
| 140 |
+
layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage)
|
| 141 |
+
lr_cur = lr * ld ** (num_layers - layer_id - 1)
|
| 142 |
+
# if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"):
|
| 143 |
+
if (name in skip_list) or any((kw in name for kw in skip_keywords)):
|
| 144 |
+
wd_cur = 0.0
|
| 145 |
+
else:
|
| 146 |
+
wd_cur = wd
|
| 147 |
+
params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur})
|
| 148 |
+
lrs.append(lr_cur)
|
| 149 |
+
return params, lrs
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def log(t, eps: float = 1e-5):
|
| 153 |
+
return torch.log(t.clamp(min=eps))
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def l2norm(t):
|
| 157 |
+
return F.normalize(t, dim=-1)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def exists(val):
|
| 161 |
+
return val is not None
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
def identity(t, *args, **kwargs):
|
| 165 |
+
return t
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
def divisible_by(numer, denom):
|
| 169 |
+
return (numer % denom) == 0
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def first(arr, d=None):
|
| 173 |
+
if len(arr) == 0:
|
| 174 |
+
return d
|
| 175 |
+
return arr[0]
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def default(val, d):
|
| 179 |
+
if exists(val):
|
| 180 |
+
return val
|
| 181 |
+
return d() if callable(d) else d
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
def maybe(fn):
|
| 185 |
+
@wraps(fn)
|
| 186 |
+
def inner(x):
|
| 187 |
+
if not exists(x):
|
| 188 |
+
return x
|
| 189 |
+
return fn(x)
|
| 190 |
+
|
| 191 |
+
return inner
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
def once(fn):
|
| 195 |
+
called = False
|
| 196 |
+
|
| 197 |
+
@wraps(fn)
|
| 198 |
+
def inner(x):
|
| 199 |
+
nonlocal called
|
| 200 |
+
if called:
|
| 201 |
+
return
|
| 202 |
+
called = True
|
| 203 |
+
return fn(x)
|
| 204 |
+
|
| 205 |
+
return inner
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def _many(fn):
|
| 209 |
+
@wraps(fn)
|
| 210 |
+
def inner(tensors, pattern, **kwargs):
|
| 211 |
+
return (fn(tensor, pattern, **kwargs) for tensor in tensors)
|
| 212 |
+
|
| 213 |
+
return inner
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
rearrange_many = _many(rearrange)
|
| 217 |
+
repeat_many = _many(repeat)
|
| 218 |
+
reduce_many = _many(reduce)
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def load_pretrained(state_dict, checkpoint):
|
| 222 |
+
checkpoint_model = checkpoint["model"]
|
| 223 |
+
if any([True if "encoder." in k else False for k in checkpoint_model.keys()]):
|
| 224 |
+
checkpoint_model = {
|
| 225 |
+
k.replace("encoder.", ""): v
|
| 226 |
+
for k, v in checkpoint_model.items()
|
| 227 |
+
if k.startswith("encoder.")
|
| 228 |
+
}
|
| 229 |
+
print("Detect pre-trained model, remove [encoder.] prefix.")
|
| 230 |
+
else:
|
| 231 |
+
print("Detect non-pre-trained model, pass without doing anything.")
|
| 232 |
+
print(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
|
| 233 |
+
checkpoint = load_checkpoint_swin(state_dict, checkpoint_model)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def load_checkpoint_swin(model, checkpoint_model):
|
| 237 |
+
state_dict = model.state_dict()
|
| 238 |
+
# Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size
|
| 239 |
+
all_keys = list(checkpoint_model.keys())
|
| 240 |
+
for key in all_keys:
|
| 241 |
+
if "relative_position_bias_table" in key:
|
| 242 |
+
relative_position_bias_table_pretrained = checkpoint_model[key]
|
| 243 |
+
relative_position_bias_table_current = state_dict[key]
|
| 244 |
+
L1, nH1 = relative_position_bias_table_pretrained.size()
|
| 245 |
+
L2, nH2 = relative_position_bias_table_current.size()
|
| 246 |
+
if nH1 != nH2:
|
| 247 |
+
print(f"Error in loading {key}, passing......")
|
| 248 |
+
else:
|
| 249 |
+
if L1 != L2:
|
| 250 |
+
print(f"{key}: Interpolate relative_position_bias_table using geo.")
|
| 251 |
+
src_size = int(L1**0.5)
|
| 252 |
+
dst_size = int(L2**0.5)
|
| 253 |
+
|
| 254 |
+
def geometric_progression(a, r, n):
|
| 255 |
+
return a * (1.0 - r**n) / (1.0 - r)
|
| 256 |
+
|
| 257 |
+
left, right = 1.01, 1.5
|
| 258 |
+
while right - left > 1e-6:
|
| 259 |
+
q = (left + right) / 2.0
|
| 260 |
+
gp = geometric_progression(1, q, src_size // 2)
|
| 261 |
+
if gp > dst_size // 2:
|
| 262 |
+
right = q
|
| 263 |
+
else:
|
| 264 |
+
left = q
|
| 265 |
+
|
| 266 |
+
# if q > 1.090307:
|
| 267 |
+
# q = 1.090307
|
| 268 |
+
|
| 269 |
+
dis = []
|
| 270 |
+
cur = 1
|
| 271 |
+
for i in range(src_size // 2):
|
| 272 |
+
dis.append(cur)
|
| 273 |
+
cur += q ** (i + 1)
|
| 274 |
+
|
| 275 |
+
r_ids = [-_ for _ in reversed(dis)]
|
| 276 |
+
|
| 277 |
+
x = r_ids + [0] + dis
|
| 278 |
+
y = r_ids + [0] + dis
|
| 279 |
+
|
| 280 |
+
t = dst_size // 2.0
|
| 281 |
+
dx = np.arange(-t, t + 0.1, 1.0)
|
| 282 |
+
dy = np.arange(-t, t + 0.1, 1.0)
|
| 283 |
+
|
| 284 |
+
print("Original positions = %s" % str(x))
|
| 285 |
+
print("Target positions = %s" % str(dx))
|
| 286 |
+
|
| 287 |
+
all_rel_pos_bias = []
|
| 288 |
+
|
| 289 |
+
for i in range(nH1):
|
| 290 |
+
z = (
|
| 291 |
+
relative_position_bias_table_pretrained[:, i]
|
| 292 |
+
.view(src_size, src_size)
|
| 293 |
+
.float()
|
| 294 |
+
.numpy()
|
| 295 |
+
)
|
| 296 |
+
f_cubic = interpolate.interp2d(x, y, z, kind="cubic")
|
| 297 |
+
all_rel_pos_bias.append(
|
| 298 |
+
torch.Tensor(f_cubic(dx, dy))
|
| 299 |
+
.contiguous()
|
| 300 |
+
.view(-1, 1)
|
| 301 |
+
.to(relative_position_bias_table_pretrained.device)
|
| 302 |
+
)
|
| 303 |
+
|
| 304 |
+
new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
|
| 305 |
+
checkpoint_model[key] = new_rel_pos_bias
|
| 306 |
+
|
| 307 |
+
# delete relative_position_index since we always re-init it
|
| 308 |
+
relative_position_index_keys = [
|
| 309 |
+
k for k in checkpoint_model.keys() if "relative_position_index" in k
|
| 310 |
+
]
|
| 311 |
+
for k in relative_position_index_keys:
|
| 312 |
+
del checkpoint_model[k]
|
| 313 |
+
|
| 314 |
+
# delete relative_coords_table since we always re-init it
|
| 315 |
+
relative_coords_table_keys = [
|
| 316 |
+
k for k in checkpoint_model.keys() if "relative_coords_table" in k
|
| 317 |
+
]
|
| 318 |
+
for k in relative_coords_table_keys:
|
| 319 |
+
del checkpoint_model[k]
|
| 320 |
+
|
| 321 |
+
# # re-map keys due to name change
|
| 322 |
+
rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k]
|
| 323 |
+
for k in rpe_mlp_keys:
|
| 324 |
+
checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)
|
| 325 |
+
|
| 326 |
+
# delete attn_mask since we always re-init it
|
| 327 |
+
attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
|
| 328 |
+
for k in attn_mask_keys:
|
| 329 |
+
del checkpoint_model[k]
|
| 330 |
+
|
| 331 |
+
encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")]
|
| 332 |
+
for k in encoder_keys:
|
| 333 |
+
checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)
|
| 334 |
+
|
| 335 |
+
return checkpoint_model
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def add_padding_metas(out, image_metas):
|
| 339 |
+
device = out.device
|
| 340 |
+
# left, right, top, bottom
|
| 341 |
+
paddings = [img_meta.get("padding_size", [0] * 4) for img_meta in image_metas]
|
| 342 |
+
paddings = torch.stack(paddings).to(device)
|
| 343 |
+
outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)]
|
| 344 |
+
return torch.stack(outs)
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def remove_padding(out, paddings):
|
| 348 |
+
B, C, H, W = out.shape
|
| 349 |
+
device = out.device
|
| 350 |
+
# left, right, top, bottom
|
| 351 |
+
paddings = torch.stack(paddings).to(device)
|
| 352 |
+
outs = [
|
| 353 |
+
o[:, padding[1] : H - padding[3], padding[0] : W - padding[2]]
|
| 354 |
+
for padding, o in zip(paddings, out)
|
| 355 |
+
]
|
| 356 |
+
return torch.stack(outs)
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def remove_padding_metas(out, image_metas):
|
| 360 |
+
B, C, H, W = out.shape
|
| 361 |
+
device = out.device
|
| 362 |
+
# left, right, top, bottom
|
| 363 |
+
paddings = [
|
| 364 |
+
torch.tensor(img_meta.get("padding_size", [0] * 4)) for img_meta in image_metas
|
| 365 |
+
]
|
| 366 |
+
return remove_padding(out, paddings)
|
| 367 |
+
|
| 368 |
+
|
| 369 |
+
def ssi_helper(tensor1, tensor2):
|
| 370 |
+
stability_mat = 1e-4 * torch.eye(2, device=tensor1.device)
|
| 371 |
+
tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
|
| 372 |
+
scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
|
| 373 |
+
tensor2_one.T @ tensor1.unsqueeze(1)
|
| 374 |
+
)
|
| 375 |
+
scale, shift = scale_shift.squeeze().chunk(2, dim=0)
|
| 376 |
+
return scale, shift
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
def calculate_mean_values(names, values):
|
| 380 |
+
# Create a defaultdict to store sum and count for each name
|
| 381 |
+
name_values = {name: {} for name in names}
|
| 382 |
+
|
| 383 |
+
# Iterate through the lists and accumulate values for each name
|
| 384 |
+
for name, value in zip(names, values):
|
| 385 |
+
name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value
|
| 386 |
+
name_values[name]["count"] = name_values[name].get("count", 0.0) + 1
|
| 387 |
+
|
| 388 |
+
# Calculate mean values and create the output dictionary
|
| 389 |
+
output_dict = {
|
| 390 |
+
name: name_values[name]["sum"] / name_values[name]["count"]
|
| 391 |
+
for name in name_values
|
| 392 |
+
}
|
| 393 |
+
|
| 394 |
+
return output_dict
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
def remove_leading_dim(infos):
|
| 398 |
+
if isinstance(infos, dict):
|
| 399 |
+
return {k: remove_leading_dim(v) for k, v in infos.items()}
|
| 400 |
+
elif isinstance(infos, torch.Tensor):
|
| 401 |
+
return infos.squeeze(0)
|
| 402 |
+
else:
|
| 403 |
+
return infos
|
flash3d/unidepth/utils/positional_embedding.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from math import pi
|
| 7 |
+
from typing import Optional
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
|
| 12 |
+
from einops import rearrange, repeat
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PositionEmbeddingSine(nn.Module):
|
| 16 |
+
def __init__(
|
| 17 |
+
self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
|
| 18 |
+
):
|
| 19 |
+
super().__init__()
|
| 20 |
+
self.num_pos_feats = num_pos_feats
|
| 21 |
+
self.temperature = temperature
|
| 22 |
+
self.normalize = normalize
|
| 23 |
+
if scale is not None and normalize is False:
|
| 24 |
+
raise ValueError("normalize should be True if scale is passed")
|
| 25 |
+
if scale is None:
|
| 26 |
+
scale = 2 * pi
|
| 27 |
+
self.scale = scale
|
| 28 |
+
|
| 29 |
+
def forward(
|
| 30 |
+
self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
|
| 31 |
+
) -> torch.Tensor:
|
| 32 |
+
if mask is None:
|
| 33 |
+
mask = torch.zeros(
|
| 34 |
+
(x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
|
| 35 |
+
)
|
| 36 |
+
not_mask = ~mask
|
| 37 |
+
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
| 38 |
+
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
| 39 |
+
if self.normalize:
|
| 40 |
+
eps = 1e-6
|
| 41 |
+
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
| 42 |
+
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
| 43 |
+
|
| 44 |
+
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
| 45 |
+
dim_t = self.temperature ** (
|
| 46 |
+
2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
|
| 47 |
+
)
|
| 48 |
+
|
| 49 |
+
pos_x = x_embed[:, :, :, None] / dim_t
|
| 50 |
+
pos_y = y_embed[:, :, :, None] / dim_t
|
| 51 |
+
pos_x = torch.stack(
|
| 52 |
+
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
| 53 |
+
).flatten(3)
|
| 54 |
+
pos_y = torch.stack(
|
| 55 |
+
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
| 56 |
+
).flatten(3)
|
| 57 |
+
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
| 58 |
+
return pos
|
| 59 |
+
|
| 60 |
+
def __repr__(self, _repr_indent=4):
|
| 61 |
+
head = "Positional encoding " + self.__class__.__name__
|
| 62 |
+
body = [
|
| 63 |
+
"num_pos_feats: {}".format(self.num_pos_feats),
|
| 64 |
+
"temperature: {}".format(self.temperature),
|
| 65 |
+
"normalize: {}".format(self.normalize),
|
| 66 |
+
"scale: {}".format(self.scale),
|
| 67 |
+
]
|
| 68 |
+
# _repr_indent = 4
|
| 69 |
+
lines = [head] + [" " * _repr_indent + line for line in body]
|
| 70 |
+
return "\n".join(lines)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class LearnedSinusoidalPosEmb(nn.Module):
|
| 74 |
+
def __init__(self, dim):
|
| 75 |
+
super().__init__()
|
| 76 |
+
assert (dim % 2) == 0
|
| 77 |
+
half_dim = dim // 2
|
| 78 |
+
self.weights = nn.Parameter(torch.randn(half_dim))
|
| 79 |
+
|
| 80 |
+
def forward(self, x):
|
| 81 |
+
x = rearrange(x, "b -> b 1")
|
| 82 |
+
freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
|
| 83 |
+
fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
|
| 84 |
+
fouriered = torch.cat((x, fouriered), dim=-1)
|
| 85 |
+
return fouriered
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def broadcat(tensors, dim=-1):
|
| 89 |
+
num_tensors = len(tensors)
|
| 90 |
+
shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
|
| 91 |
+
assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
|
| 92 |
+
shape_len = list(shape_lens)[0]
|
| 93 |
+
dim = (dim + shape_len) if dim < 0 else dim
|
| 94 |
+
dims = list(zip(*map(lambda t: list(t.shape), tensors)))
|
| 95 |
+
expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
|
| 96 |
+
assert all(
|
| 97 |
+
[*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
|
| 98 |
+
), "invalid dimensions for broadcastable concatentation"
|
| 99 |
+
max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
|
| 100 |
+
expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
|
| 101 |
+
expanded_dims.insert(dim, (dim, dims[dim]))
|
| 102 |
+
expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
|
| 103 |
+
tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
|
| 104 |
+
return torch.cat(tensors, dim=dim)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def rotate_half(x):
|
| 108 |
+
x = rearrange(x, "... (d r) -> ... d r", r=2)
|
| 109 |
+
x1, x2 = x.unbind(dim=-1)
|
| 110 |
+
x = torch.stack((-x2, x1), dim=-1)
|
| 111 |
+
return rearrange(x, "... d r -> ... (d r)")
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
class VisionRotaryEmbedding(nn.Module):
|
| 115 |
+
def __init__(
|
| 116 |
+
self,
|
| 117 |
+
dim,
|
| 118 |
+
pt_seq_len,
|
| 119 |
+
ft_seq_len=None,
|
| 120 |
+
custom_freqs=None,
|
| 121 |
+
freqs_for="lang",
|
| 122 |
+
theta=10000,
|
| 123 |
+
max_freq=10,
|
| 124 |
+
num_freqs=1,
|
| 125 |
+
):
|
| 126 |
+
super().__init__()
|
| 127 |
+
if custom_freqs:
|
| 128 |
+
freqs = custom_freqs
|
| 129 |
+
elif freqs_for == "lang":
|
| 130 |
+
freqs = 1.0 / (
|
| 131 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
| 132 |
+
)
|
| 133 |
+
elif freqs_for == "pixel":
|
| 134 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 135 |
+
elif freqs_for == "constant":
|
| 136 |
+
freqs = torch.ones(num_freqs).float()
|
| 137 |
+
else:
|
| 138 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
| 139 |
+
|
| 140 |
+
if ft_seq_len is None:
|
| 141 |
+
ft_seq_len = pt_seq_len
|
| 142 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 143 |
+
|
| 144 |
+
freqs_h = torch.einsum("..., f -> ... f", t, freqs)
|
| 145 |
+
freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
|
| 146 |
+
|
| 147 |
+
freqs_w = torch.einsum("..., f -> ... f", t, freqs)
|
| 148 |
+
freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
|
| 149 |
+
|
| 150 |
+
freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
|
| 151 |
+
|
| 152 |
+
self.register_buffer("freqs_cos", freqs.cos())
|
| 153 |
+
self.register_buffer("freqs_sin", freqs.sin())
|
| 154 |
+
|
| 155 |
+
print("======== shape of rope freq", self.freqs_cos.shape, "========")
|
| 156 |
+
|
| 157 |
+
def forward(self, t, start_index=0):
|
| 158 |
+
rot_dim = self.freqs_cos.shape[-1]
|
| 159 |
+
end_index = start_index + rot_dim
|
| 160 |
+
assert (
|
| 161 |
+
rot_dim <= t.shape[-1]
|
| 162 |
+
), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
|
| 163 |
+
t_left, t, t_right = (
|
| 164 |
+
t[..., :start_index],
|
| 165 |
+
t[..., start_index:end_index],
|
| 166 |
+
t[..., end_index:],
|
| 167 |
+
)
|
| 168 |
+
t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
|
| 169 |
+
return torch.cat((t_left, t, t_right), dim=-1)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
class VisionRotaryEmbeddingFast(nn.Module):
|
| 173 |
+
def __init__(
|
| 174 |
+
self,
|
| 175 |
+
dim,
|
| 176 |
+
pt_seq_len,
|
| 177 |
+
ft_seq_len=None,
|
| 178 |
+
custom_freqs=None,
|
| 179 |
+
freqs_for="lang",
|
| 180 |
+
theta=10000,
|
| 181 |
+
max_freq=10,
|
| 182 |
+
num_freqs=1,
|
| 183 |
+
):
|
| 184 |
+
super().__init__()
|
| 185 |
+
if custom_freqs:
|
| 186 |
+
freqs = custom_freqs
|
| 187 |
+
elif freqs_for == "lang":
|
| 188 |
+
freqs = 1.0 / (
|
| 189 |
+
theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
|
| 190 |
+
)
|
| 191 |
+
elif freqs_for == "pixel":
|
| 192 |
+
freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
|
| 193 |
+
elif freqs_for == "constant":
|
| 194 |
+
freqs = torch.ones(num_freqs).float()
|
| 195 |
+
else:
|
| 196 |
+
raise ValueError(f"unknown modality {freqs_for}")
|
| 197 |
+
|
| 198 |
+
if ft_seq_len is None:
|
| 199 |
+
ft_seq_len = pt_seq_len
|
| 200 |
+
t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
|
| 201 |
+
|
| 202 |
+
freqs = torch.einsum("..., f -> ... f", t, freqs)
|
| 203 |
+
freqs = repeat(freqs, "... n -> ... (n r)", r=2)
|
| 204 |
+
freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
|
| 205 |
+
|
| 206 |
+
freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
|
| 207 |
+
freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
|
| 208 |
+
|
| 209 |
+
self.register_buffer("freqs_cos", freqs_cos)
|
| 210 |
+
self.register_buffer("freqs_sin", freqs_sin)
|
| 211 |
+
|
| 212 |
+
def forward(self, t):
|
| 213 |
+
return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
|
| 214 |
+
|
| 215 |
+
|
| 216 |
+
from math import log2
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def generate_fourier_features(
|
| 220 |
+
x: torch.Tensor,
|
| 221 |
+
dim: int = 512,
|
| 222 |
+
max_freq: int = 64,
|
| 223 |
+
use_cos: bool = False,
|
| 224 |
+
use_log: bool = False,
|
| 225 |
+
cat_orig: bool = False,
|
| 226 |
+
):
|
| 227 |
+
x_orig = x
|
| 228 |
+
device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
|
| 229 |
+
num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
|
| 230 |
+
|
| 231 |
+
if use_log:
|
| 232 |
+
scales = 2.0 ** torch.linspace(
|
| 233 |
+
0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
|
| 234 |
+
)
|
| 235 |
+
else:
|
| 236 |
+
scales = torch.linspace(
|
| 237 |
+
1.0, max_freq / 2, num_bands, device=device, dtype=dtype
|
| 238 |
+
)
|
| 239 |
+
|
| 240 |
+
x = x.unsqueeze(-1)
|
| 241 |
+
scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
|
| 242 |
+
|
| 243 |
+
x = x * scales * pi
|
| 244 |
+
x = torch.cat(
|
| 245 |
+
(
|
| 246 |
+
[x.sin(), x.cos()]
|
| 247 |
+
if use_cos
|
| 248 |
+
else [
|
| 249 |
+
x.sin(),
|
| 250 |
+
]
|
| 251 |
+
),
|
| 252 |
+
dim=-1,
|
| 253 |
+
)
|
| 254 |
+
x = x.flatten(-2)
|
| 255 |
+
if cat_orig:
|
| 256 |
+
return torch.cat((x, x_orig), dim=-1)
|
| 257 |
+
return x
|
| 258 |
+
|
| 259 |
+
|
| 260 |
+
# from PIL import Image
|
| 261 |
+
# from unidepth.utils import image_grid, colorize
|
| 262 |
+
# if __name__ == "__main__":
|
| 263 |
+
# H, W = 512, 512
|
| 264 |
+
# resolution = 128
|
| 265 |
+
# mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W))
|
| 266 |
+
# mesh = torch.stack(mesh, dim=0).unsqueeze(0)
|
| 267 |
+
# mesh = mesh.view(1, 2, -1).permute(0, 2, 1)
|
| 268 |
+
|
| 269 |
+
# features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True)
|
| 270 |
+
# channels = features.shape[-1]
|
| 271 |
+
# print(features.shape)
|
| 272 |
+
|
| 273 |
+
# features = features[0].view(H, W, channels).permute(2, 0, 1).numpy()
|
| 274 |
+
# Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png")
|
flash3d/unidepth/utils/sht.py
ADDED
|
@@ -0,0 +1,1637 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Real spherical harmonics in Cartesian form for PyTorch.
|
| 2 |
+
|
| 3 |
+
This is an autogenerated file. See
|
| 4 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 5 |
+
for more information.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def rsh_cart_0(xyz: torch.Tensor):
|
| 12 |
+
"""Computes all real spherical harmonics up to degree 0.
|
| 13 |
+
|
| 14 |
+
This is an autogenerated method. See
|
| 15 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 16 |
+
for more information.
|
| 17 |
+
|
| 18 |
+
Params:
|
| 19 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 20 |
+
|
| 21 |
+
Returns:
|
| 22 |
+
rsh: (N,...,1) real spherical harmonics
|
| 23 |
+
projections of input. Ynm is found at index
|
| 24 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 25 |
+
`-n <= m <= n`.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
return torch.stack(
|
| 29 |
+
[
|
| 30 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 31 |
+
],
|
| 32 |
+
-1,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def rsh_cart_1(xyz: torch.Tensor):
|
| 37 |
+
"""Computes all real spherical harmonics up to degree 1.
|
| 38 |
+
|
| 39 |
+
This is an autogenerated method. See
|
| 40 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 41 |
+
for more information.
|
| 42 |
+
|
| 43 |
+
Params:
|
| 44 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 45 |
+
|
| 46 |
+
Returns:
|
| 47 |
+
rsh: (N,...,4) real spherical harmonics
|
| 48 |
+
projections of input. Ynm is found at index
|
| 49 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 50 |
+
`-n <= m <= n`.
|
| 51 |
+
"""
|
| 52 |
+
x = xyz[..., 0]
|
| 53 |
+
y = xyz[..., 1]
|
| 54 |
+
z = xyz[..., 2]
|
| 55 |
+
|
| 56 |
+
return torch.stack(
|
| 57 |
+
[
|
| 58 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 59 |
+
-0.48860251190292 * y,
|
| 60 |
+
0.48860251190292 * z,
|
| 61 |
+
-0.48860251190292 * x,
|
| 62 |
+
],
|
| 63 |
+
-1,
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def rsh_cart_2(xyz: torch.Tensor):
|
| 68 |
+
"""Computes all real spherical harmonics up to degree 2.
|
| 69 |
+
|
| 70 |
+
This is an autogenerated method. See
|
| 71 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 72 |
+
for more information.
|
| 73 |
+
|
| 74 |
+
Params:
|
| 75 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 76 |
+
|
| 77 |
+
Returns:
|
| 78 |
+
rsh: (N,...,9) real spherical harmonics
|
| 79 |
+
projections of input. Ynm is found at index
|
| 80 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 81 |
+
`-n <= m <= n`.
|
| 82 |
+
"""
|
| 83 |
+
x = xyz[..., 0]
|
| 84 |
+
y = xyz[..., 1]
|
| 85 |
+
z = xyz[..., 2]
|
| 86 |
+
|
| 87 |
+
x2 = x**2
|
| 88 |
+
y2 = y**2
|
| 89 |
+
z2 = z**2
|
| 90 |
+
xy = x * y
|
| 91 |
+
xz = x * z
|
| 92 |
+
yz = y * z
|
| 93 |
+
|
| 94 |
+
return torch.stack(
|
| 95 |
+
[
|
| 96 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 97 |
+
-0.48860251190292 * y,
|
| 98 |
+
0.48860251190292 * z,
|
| 99 |
+
-0.48860251190292 * x,
|
| 100 |
+
1.09254843059208 * xy,
|
| 101 |
+
-1.09254843059208 * yz,
|
| 102 |
+
0.94617469575756 * z2 - 0.31539156525252,
|
| 103 |
+
-1.09254843059208 * xz,
|
| 104 |
+
0.54627421529604 * x2 - 0.54627421529604 * y2,
|
| 105 |
+
],
|
| 106 |
+
-1,
|
| 107 |
+
)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def rsh_cart_3(xyz: torch.Tensor):
|
| 111 |
+
"""Computes all real spherical harmonics up to degree 3.
|
| 112 |
+
|
| 113 |
+
This is an autogenerated method. See
|
| 114 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 115 |
+
for more information.
|
| 116 |
+
|
| 117 |
+
Params:
|
| 118 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
rsh: (N,...,16) real spherical harmonics
|
| 122 |
+
projections of input. Ynm is found at index
|
| 123 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 124 |
+
`-n <= m <= n`.
|
| 125 |
+
"""
|
| 126 |
+
x = xyz[..., 0]
|
| 127 |
+
y = xyz[..., 1]
|
| 128 |
+
z = xyz[..., 2]
|
| 129 |
+
|
| 130 |
+
x2 = x**2
|
| 131 |
+
y2 = y**2
|
| 132 |
+
z2 = z**2
|
| 133 |
+
xy = x * y
|
| 134 |
+
xz = x * z
|
| 135 |
+
yz = y * z
|
| 136 |
+
|
| 137 |
+
return torch.stack(
|
| 138 |
+
[
|
| 139 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 140 |
+
-0.48860251190292 * y,
|
| 141 |
+
0.48860251190292 * z,
|
| 142 |
+
-0.48860251190292 * x,
|
| 143 |
+
1.09254843059208 * xy,
|
| 144 |
+
-1.09254843059208 * yz,
|
| 145 |
+
0.94617469575756 * z2 - 0.31539156525252,
|
| 146 |
+
-1.09254843059208 * xz,
|
| 147 |
+
0.54627421529604 * x2 - 0.54627421529604 * y2,
|
| 148 |
+
-0.590043589926644 * y * (3.0 * x2 - y2),
|
| 149 |
+
2.89061144264055 * xy * z,
|
| 150 |
+
0.304697199642977 * y * (1.5 - 7.5 * z2),
|
| 151 |
+
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
|
| 152 |
+
0.304697199642977 * x * (1.5 - 7.5 * z2),
|
| 153 |
+
1.44530572132028 * z * (x2 - y2),
|
| 154 |
+
-0.590043589926644 * x * (x2 - 3.0 * y2),
|
| 155 |
+
],
|
| 156 |
+
-1,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
def rsh_cart_4(xyz: torch.Tensor):
|
| 161 |
+
"""Computes all real spherical harmonics up to degree 4.
|
| 162 |
+
|
| 163 |
+
This is an autogenerated method. See
|
| 164 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 165 |
+
for more information.
|
| 166 |
+
|
| 167 |
+
Params:
|
| 168 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 169 |
+
|
| 170 |
+
Returns:
|
| 171 |
+
rsh: (N,...,25) real spherical harmonics
|
| 172 |
+
projections of input. Ynm is found at index
|
| 173 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 174 |
+
`-n <= m <= n`.
|
| 175 |
+
"""
|
| 176 |
+
x = xyz[..., 0]
|
| 177 |
+
y = xyz[..., 1]
|
| 178 |
+
z = xyz[..., 2]
|
| 179 |
+
|
| 180 |
+
x2 = x**2
|
| 181 |
+
y2 = y**2
|
| 182 |
+
z2 = z**2
|
| 183 |
+
xy = x * y
|
| 184 |
+
xz = x * z
|
| 185 |
+
yz = y * z
|
| 186 |
+
x4 = x2**2
|
| 187 |
+
y4 = y2**2
|
| 188 |
+
z4 = z2**2
|
| 189 |
+
|
| 190 |
+
return torch.stack(
|
| 191 |
+
[
|
| 192 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 193 |
+
-0.48860251190292 * y,
|
| 194 |
+
0.48860251190292 * z,
|
| 195 |
+
-0.48860251190292 * x,
|
| 196 |
+
1.09254843059208 * xy,
|
| 197 |
+
-1.09254843059208 * yz,
|
| 198 |
+
0.94617469575756 * z2 - 0.31539156525252,
|
| 199 |
+
-1.09254843059208 * xz,
|
| 200 |
+
0.54627421529604 * x2 - 0.54627421529604 * y2,
|
| 201 |
+
-0.590043589926644 * y * (3.0 * x2 - y2),
|
| 202 |
+
2.89061144264055 * xy * z,
|
| 203 |
+
0.304697199642977 * y * (1.5 - 7.5 * z2),
|
| 204 |
+
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
|
| 205 |
+
0.304697199642977 * x * (1.5 - 7.5 * z2),
|
| 206 |
+
1.44530572132028 * z * (x2 - y2),
|
| 207 |
+
-0.590043589926644 * x * (x2 - 3.0 * y2),
|
| 208 |
+
2.5033429417967 * xy * (x2 - y2),
|
| 209 |
+
-1.77013076977993 * yz * (3.0 * x2 - y2),
|
| 210 |
+
0.126156626101008 * xy * (52.5 * z2 - 7.5),
|
| 211 |
+
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 212 |
+
1.48099765681286
|
| 213 |
+
* z
|
| 214 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 215 |
+
- 0.952069922236839 * z2
|
| 216 |
+
+ 0.317356640745613,
|
| 217 |
+
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 218 |
+
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
|
| 219 |
+
-1.77013076977993 * xz * (x2 - 3.0 * y2),
|
| 220 |
+
-3.75501441269506 * x2 * y2
|
| 221 |
+
+ 0.625835735449176 * x4
|
| 222 |
+
+ 0.625835735449176 * y4,
|
| 223 |
+
],
|
| 224 |
+
-1,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def rsh_cart_5(xyz: torch.Tensor):
|
| 229 |
+
"""Computes all real spherical harmonics up to degree 5.
|
| 230 |
+
|
| 231 |
+
This is an autogenerated method. See
|
| 232 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 233 |
+
for more information.
|
| 234 |
+
|
| 235 |
+
Params:
|
| 236 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 237 |
+
|
| 238 |
+
Returns:
|
| 239 |
+
rsh: (N,...,36) real spherical harmonics
|
| 240 |
+
projections of input. Ynm is found at index
|
| 241 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 242 |
+
`-n <= m <= n`.
|
| 243 |
+
"""
|
| 244 |
+
x = xyz[..., 0]
|
| 245 |
+
y = xyz[..., 1]
|
| 246 |
+
z = xyz[..., 2]
|
| 247 |
+
|
| 248 |
+
x2 = x**2
|
| 249 |
+
y2 = y**2
|
| 250 |
+
z2 = z**2
|
| 251 |
+
xy = x * y
|
| 252 |
+
xz = x * z
|
| 253 |
+
yz = y * z
|
| 254 |
+
x4 = x2**2
|
| 255 |
+
y4 = y2**2
|
| 256 |
+
z4 = z2**2
|
| 257 |
+
|
| 258 |
+
return torch.stack(
|
| 259 |
+
[
|
| 260 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 261 |
+
-0.48860251190292 * y,
|
| 262 |
+
0.48860251190292 * z,
|
| 263 |
+
-0.48860251190292 * x,
|
| 264 |
+
1.09254843059208 * xy,
|
| 265 |
+
-1.09254843059208 * yz,
|
| 266 |
+
0.94617469575756 * z2 - 0.31539156525252,
|
| 267 |
+
-1.09254843059208 * xz,
|
| 268 |
+
0.54627421529604 * x2 - 0.54627421529604 * y2,
|
| 269 |
+
-0.590043589926644 * y * (3.0 * x2 - y2),
|
| 270 |
+
2.89061144264055 * xy * z,
|
| 271 |
+
0.304697199642977 * y * (1.5 - 7.5 * z2),
|
| 272 |
+
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
|
| 273 |
+
0.304697199642977 * x * (1.5 - 7.5 * z2),
|
| 274 |
+
1.44530572132028 * z * (x2 - y2),
|
| 275 |
+
-0.590043589926644 * x * (x2 - 3.0 * y2),
|
| 276 |
+
2.5033429417967 * xy * (x2 - y2),
|
| 277 |
+
-1.77013076977993 * yz * (3.0 * x2 - y2),
|
| 278 |
+
0.126156626101008 * xy * (52.5 * z2 - 7.5),
|
| 279 |
+
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 280 |
+
1.48099765681286
|
| 281 |
+
* z
|
| 282 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 283 |
+
- 0.952069922236839 * z2
|
| 284 |
+
+ 0.317356640745613,
|
| 285 |
+
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 286 |
+
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
|
| 287 |
+
-1.77013076977993 * xz * (x2 - 3.0 * y2),
|
| 288 |
+
-3.75501441269506 * x2 * y2
|
| 289 |
+
+ 0.625835735449176 * x4
|
| 290 |
+
+ 0.625835735449176 * y4,
|
| 291 |
+
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 292 |
+
8.30264925952416 * xy * z * (x2 - y2),
|
| 293 |
+
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
|
| 294 |
+
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 295 |
+
0.241571547304372
|
| 296 |
+
* y
|
| 297 |
+
* (
|
| 298 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 299 |
+
+ 9.375 * z2
|
| 300 |
+
- 1.875
|
| 301 |
+
),
|
| 302 |
+
-1.24747010616985 * z * (1.5 * z2 - 0.5)
|
| 303 |
+
+ 1.6840846433293
|
| 304 |
+
* z
|
| 305 |
+
* (
|
| 306 |
+
1.75
|
| 307 |
+
* z
|
| 308 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 309 |
+
- 1.125 * z2
|
| 310 |
+
+ 0.375
|
| 311 |
+
)
|
| 312 |
+
+ 0.498988042467941 * z,
|
| 313 |
+
0.241571547304372
|
| 314 |
+
* x
|
| 315 |
+
* (
|
| 316 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 317 |
+
+ 9.375 * z2
|
| 318 |
+
- 1.875
|
| 319 |
+
),
|
| 320 |
+
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 321 |
+
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
|
| 322 |
+
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
|
| 323 |
+
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 324 |
+
],
|
| 325 |
+
-1,
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
|
| 329 |
+
def rsh_cart_6(xyz: torch.Tensor):
|
| 330 |
+
"""Computes all real spherical harmonics up to degree 6.
|
| 331 |
+
|
| 332 |
+
This is an autogenerated method. See
|
| 333 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 334 |
+
for more information.
|
| 335 |
+
|
| 336 |
+
Params:
|
| 337 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 338 |
+
|
| 339 |
+
Returns:
|
| 340 |
+
rsh: (N,...,49) real spherical harmonics
|
| 341 |
+
projections of input. Ynm is found at index
|
| 342 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 343 |
+
`-n <= m <= n`.
|
| 344 |
+
"""
|
| 345 |
+
x = xyz[..., 0]
|
| 346 |
+
y = xyz[..., 1]
|
| 347 |
+
z = xyz[..., 2]
|
| 348 |
+
|
| 349 |
+
x2 = x**2
|
| 350 |
+
y2 = y**2
|
| 351 |
+
z2 = z**2
|
| 352 |
+
xy = x * y
|
| 353 |
+
xz = x * z
|
| 354 |
+
yz = y * z
|
| 355 |
+
x4 = x2**2
|
| 356 |
+
y4 = y2**2
|
| 357 |
+
z4 = z2**2
|
| 358 |
+
|
| 359 |
+
return torch.stack(
|
| 360 |
+
[
|
| 361 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 362 |
+
-0.48860251190292 * y,
|
| 363 |
+
0.48860251190292 * z,
|
| 364 |
+
-0.48860251190292 * x,
|
| 365 |
+
1.09254843059208 * xy,
|
| 366 |
+
-1.09254843059208 * yz,
|
| 367 |
+
0.94617469575756 * z2 - 0.31539156525252,
|
| 368 |
+
-1.09254843059208 * xz,
|
| 369 |
+
0.54627421529604 * x2 - 0.54627421529604 * y2,
|
| 370 |
+
-0.590043589926644 * y * (3.0 * x2 - y2),
|
| 371 |
+
2.89061144264055 * xy * z,
|
| 372 |
+
0.304697199642977 * y * (1.5 - 7.5 * z2),
|
| 373 |
+
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
|
| 374 |
+
0.304697199642977 * x * (1.5 - 7.5 * z2),
|
| 375 |
+
1.44530572132028 * z * (x2 - y2),
|
| 376 |
+
-0.590043589926644 * x * (x2 - 3.0 * y2),
|
| 377 |
+
2.5033429417967 * xy * (x2 - y2),
|
| 378 |
+
-1.77013076977993 * yz * (3.0 * x2 - y2),
|
| 379 |
+
0.126156626101008 * xy * (52.5 * z2 - 7.5),
|
| 380 |
+
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 381 |
+
1.48099765681286
|
| 382 |
+
* z
|
| 383 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 384 |
+
- 0.952069922236839 * z2
|
| 385 |
+
+ 0.317356640745613,
|
| 386 |
+
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 387 |
+
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
|
| 388 |
+
-1.77013076977993 * xz * (x2 - 3.0 * y2),
|
| 389 |
+
-3.75501441269506 * x2 * y2
|
| 390 |
+
+ 0.625835735449176 * x4
|
| 391 |
+
+ 0.625835735449176 * y4,
|
| 392 |
+
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 393 |
+
8.30264925952416 * xy * z * (x2 - y2),
|
| 394 |
+
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
|
| 395 |
+
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 396 |
+
0.241571547304372
|
| 397 |
+
* y
|
| 398 |
+
* (
|
| 399 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 400 |
+
+ 9.375 * z2
|
| 401 |
+
- 1.875
|
| 402 |
+
),
|
| 403 |
+
-1.24747010616985 * z * (1.5 * z2 - 0.5)
|
| 404 |
+
+ 1.6840846433293
|
| 405 |
+
* z
|
| 406 |
+
* (
|
| 407 |
+
1.75
|
| 408 |
+
* z
|
| 409 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 410 |
+
- 1.125 * z2
|
| 411 |
+
+ 0.375
|
| 412 |
+
)
|
| 413 |
+
+ 0.498988042467941 * z,
|
| 414 |
+
0.241571547304372
|
| 415 |
+
* x
|
| 416 |
+
* (
|
| 417 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 418 |
+
+ 9.375 * z2
|
| 419 |
+
- 1.875
|
| 420 |
+
),
|
| 421 |
+
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 422 |
+
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
|
| 423 |
+
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
|
| 424 |
+
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 425 |
+
4.09910463115149 * x**4 * xy
|
| 426 |
+
- 13.6636821038383 * xy**3
|
| 427 |
+
+ 4.09910463115149 * xy * y**4,
|
| 428 |
+
-2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 429 |
+
0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
|
| 430 |
+
0.00584892228263444
|
| 431 |
+
* y
|
| 432 |
+
* (3.0 * x2 - y2)
|
| 433 |
+
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
|
| 434 |
+
0.0701870673916132
|
| 435 |
+
* xy
|
| 436 |
+
* (
|
| 437 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 438 |
+
- 91.875 * z2
|
| 439 |
+
+ 13.125
|
| 440 |
+
),
|
| 441 |
+
0.221950995245231
|
| 442 |
+
* y
|
| 443 |
+
* (
|
| 444 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 445 |
+
+ 2.2
|
| 446 |
+
* z
|
| 447 |
+
* (
|
| 448 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 449 |
+
+ 9.375 * z2
|
| 450 |
+
- 1.875
|
| 451 |
+
)
|
| 452 |
+
- 4.8 * z
|
| 453 |
+
),
|
| 454 |
+
-1.48328138624466
|
| 455 |
+
* z
|
| 456 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 457 |
+
+ 1.86469659985043
|
| 458 |
+
* z
|
| 459 |
+
* (
|
| 460 |
+
-1.33333333333333 * z * (1.5 * z2 - 0.5)
|
| 461 |
+
+ 1.8
|
| 462 |
+
* z
|
| 463 |
+
* (
|
| 464 |
+
1.75
|
| 465 |
+
* z
|
| 466 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 467 |
+
- 1.125 * z2
|
| 468 |
+
+ 0.375
|
| 469 |
+
)
|
| 470 |
+
+ 0.533333333333333 * z
|
| 471 |
+
)
|
| 472 |
+
+ 0.953538034014426 * z2
|
| 473 |
+
- 0.317846011338142,
|
| 474 |
+
0.221950995245231
|
| 475 |
+
* x
|
| 476 |
+
* (
|
| 477 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 478 |
+
+ 2.2
|
| 479 |
+
* z
|
| 480 |
+
* (
|
| 481 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 482 |
+
+ 9.375 * z2
|
| 483 |
+
- 1.875
|
| 484 |
+
)
|
| 485 |
+
- 4.8 * z
|
| 486 |
+
),
|
| 487 |
+
0.0350935336958066
|
| 488 |
+
* (x2 - y2)
|
| 489 |
+
* (
|
| 490 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 491 |
+
- 91.875 * z2
|
| 492 |
+
+ 13.125
|
| 493 |
+
),
|
| 494 |
+
0.00584892228263444
|
| 495 |
+
* x
|
| 496 |
+
* (x2 - 3.0 * y2)
|
| 497 |
+
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
|
| 498 |
+
0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
|
| 499 |
+
-2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 500 |
+
0.683184105191914 * x2**3
|
| 501 |
+
+ 10.2477615778787 * x2 * y4
|
| 502 |
+
- 10.2477615778787 * x4 * y2
|
| 503 |
+
- 0.683184105191914 * y2**3,
|
| 504 |
+
],
|
| 505 |
+
-1,
|
| 506 |
+
)
|
| 507 |
+
|
| 508 |
+
|
| 509 |
+
def rsh_cart_7(xyz: torch.Tensor):
|
| 510 |
+
"""Computes all real spherical harmonics up to degree 7.
|
| 511 |
+
|
| 512 |
+
This is an autogenerated method. See
|
| 513 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 514 |
+
for more information.
|
| 515 |
+
|
| 516 |
+
Params:
|
| 517 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 518 |
+
|
| 519 |
+
Returns:
|
| 520 |
+
rsh: (N,...,64) real spherical harmonics
|
| 521 |
+
projections of input. Ynm is found at index
|
| 522 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 523 |
+
`-n <= m <= n`.
|
| 524 |
+
"""
|
| 525 |
+
x = xyz[..., 0]
|
| 526 |
+
y = xyz[..., 1]
|
| 527 |
+
z = xyz[..., 2]
|
| 528 |
+
|
| 529 |
+
x2 = x**2
|
| 530 |
+
y2 = y**2
|
| 531 |
+
z2 = z**2
|
| 532 |
+
xy = x * y
|
| 533 |
+
xz = x * z
|
| 534 |
+
yz = y * z
|
| 535 |
+
x4 = x2**2
|
| 536 |
+
y4 = y2**2
|
| 537 |
+
z4 = z2**2
|
| 538 |
+
|
| 539 |
+
return torch.stack(
|
| 540 |
+
[
|
| 541 |
+
xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
|
| 542 |
+
-0.48860251190292 * y,
|
| 543 |
+
0.48860251190292 * z,
|
| 544 |
+
-0.48860251190292 * x,
|
| 545 |
+
1.09254843059208 * xy,
|
| 546 |
+
-1.09254843059208 * yz,
|
| 547 |
+
0.94617469575756 * z2 - 0.31539156525252,
|
| 548 |
+
-1.09254843059208 * xz,
|
| 549 |
+
0.54627421529604 * x2 - 0.54627421529604 * y2,
|
| 550 |
+
-0.590043589926644 * y * (3.0 * x2 - y2),
|
| 551 |
+
2.89061144264055 * xy * z,
|
| 552 |
+
0.304697199642977 * y * (1.5 - 7.5 * z2),
|
| 553 |
+
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
|
| 554 |
+
0.304697199642977 * x * (1.5 - 7.5 * z2),
|
| 555 |
+
1.44530572132028 * z * (x2 - y2),
|
| 556 |
+
-0.590043589926644 * x * (x2 - 3.0 * y2),
|
| 557 |
+
2.5033429417967 * xy * (x2 - y2),
|
| 558 |
+
-1.77013076977993 * yz * (3.0 * x2 - y2),
|
| 559 |
+
0.126156626101008 * xy * (52.5 * z2 - 7.5),
|
| 560 |
+
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 561 |
+
1.48099765681286
|
| 562 |
+
* z
|
| 563 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 564 |
+
- 0.952069922236839 * z2
|
| 565 |
+
+ 0.317356640745613,
|
| 566 |
+
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 567 |
+
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
|
| 568 |
+
-1.77013076977993 * xz * (x2 - 3.0 * y2),
|
| 569 |
+
-3.75501441269506 * x2 * y2
|
| 570 |
+
+ 0.625835735449176 * x4
|
| 571 |
+
+ 0.625835735449176 * y4,
|
| 572 |
+
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 573 |
+
8.30264925952416 * xy * z * (x2 - y2),
|
| 574 |
+
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
|
| 575 |
+
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 576 |
+
0.241571547304372
|
| 577 |
+
* y
|
| 578 |
+
* (
|
| 579 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 580 |
+
+ 9.375 * z2
|
| 581 |
+
- 1.875
|
| 582 |
+
),
|
| 583 |
+
-1.24747010616985 * z * (1.5 * z2 - 0.5)
|
| 584 |
+
+ 1.6840846433293
|
| 585 |
+
* z
|
| 586 |
+
* (
|
| 587 |
+
1.75
|
| 588 |
+
* z
|
| 589 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 590 |
+
- 1.125 * z2
|
| 591 |
+
+ 0.375
|
| 592 |
+
)
|
| 593 |
+
+ 0.498988042467941 * z,
|
| 594 |
+
0.241571547304372
|
| 595 |
+
* x
|
| 596 |
+
* (
|
| 597 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 598 |
+
+ 9.375 * z2
|
| 599 |
+
- 1.875
|
| 600 |
+
),
|
| 601 |
+
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 602 |
+
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
|
| 603 |
+
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
|
| 604 |
+
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 605 |
+
4.09910463115149 * x**4 * xy
|
| 606 |
+
- 13.6636821038383 * xy**3
|
| 607 |
+
+ 4.09910463115149 * xy * y**4,
|
| 608 |
+
-2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 609 |
+
0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
|
| 610 |
+
0.00584892228263444
|
| 611 |
+
* y
|
| 612 |
+
* (3.0 * x2 - y2)
|
| 613 |
+
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
|
| 614 |
+
0.0701870673916132
|
| 615 |
+
* xy
|
| 616 |
+
* (
|
| 617 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 618 |
+
- 91.875 * z2
|
| 619 |
+
+ 13.125
|
| 620 |
+
),
|
| 621 |
+
0.221950995245231
|
| 622 |
+
* y
|
| 623 |
+
* (
|
| 624 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 625 |
+
+ 2.2
|
| 626 |
+
* z
|
| 627 |
+
* (
|
| 628 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 629 |
+
+ 9.375 * z2
|
| 630 |
+
- 1.875
|
| 631 |
+
)
|
| 632 |
+
- 4.8 * z
|
| 633 |
+
),
|
| 634 |
+
-1.48328138624466
|
| 635 |
+
* z
|
| 636 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 637 |
+
+ 1.86469659985043
|
| 638 |
+
* z
|
| 639 |
+
* (
|
| 640 |
+
-1.33333333333333 * z * (1.5 * z2 - 0.5)
|
| 641 |
+
+ 1.8
|
| 642 |
+
* z
|
| 643 |
+
* (
|
| 644 |
+
1.75
|
| 645 |
+
* z
|
| 646 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 647 |
+
- 1.125 * z2
|
| 648 |
+
+ 0.375
|
| 649 |
+
)
|
| 650 |
+
+ 0.533333333333333 * z
|
| 651 |
+
)
|
| 652 |
+
+ 0.953538034014426 * z2
|
| 653 |
+
- 0.317846011338142,
|
| 654 |
+
0.221950995245231
|
| 655 |
+
* x
|
| 656 |
+
* (
|
| 657 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 658 |
+
+ 2.2
|
| 659 |
+
* z
|
| 660 |
+
* (
|
| 661 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 662 |
+
+ 9.375 * z2
|
| 663 |
+
- 1.875
|
| 664 |
+
)
|
| 665 |
+
- 4.8 * z
|
| 666 |
+
),
|
| 667 |
+
0.0350935336958066
|
| 668 |
+
* (x2 - y2)
|
| 669 |
+
* (
|
| 670 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 671 |
+
- 91.875 * z2
|
| 672 |
+
+ 13.125
|
| 673 |
+
),
|
| 674 |
+
0.00584892228263444
|
| 675 |
+
* x
|
| 676 |
+
* (x2 - 3.0 * y2)
|
| 677 |
+
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
|
| 678 |
+
0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
|
| 679 |
+
-2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 680 |
+
0.683184105191914 * x2**3
|
| 681 |
+
+ 10.2477615778787 * x2 * y4
|
| 682 |
+
- 10.2477615778787 * x4 * y2
|
| 683 |
+
- 0.683184105191914 * y2**3,
|
| 684 |
+
-0.707162732524596
|
| 685 |
+
* y
|
| 686 |
+
* (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
|
| 687 |
+
2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
|
| 688 |
+
9.98394571852353e-5
|
| 689 |
+
* y
|
| 690 |
+
* (5197.5 - 67567.5 * z2)
|
| 691 |
+
* (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 692 |
+
0.00239614697244565
|
| 693 |
+
* xy
|
| 694 |
+
* (x2 - y2)
|
| 695 |
+
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
|
| 696 |
+
0.00397356022507413
|
| 697 |
+
* y
|
| 698 |
+
* (3.0 * x2 - y2)
|
| 699 |
+
* (
|
| 700 |
+
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
|
| 701 |
+
+ 1063.125 * z2
|
| 702 |
+
- 118.125
|
| 703 |
+
),
|
| 704 |
+
0.0561946276120613
|
| 705 |
+
* xy
|
| 706 |
+
* (
|
| 707 |
+
-4.8 * z * (52.5 * z2 - 7.5)
|
| 708 |
+
+ 2.6
|
| 709 |
+
* z
|
| 710 |
+
* (
|
| 711 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 712 |
+
- 91.875 * z2
|
| 713 |
+
+ 13.125
|
| 714 |
+
)
|
| 715 |
+
+ 48.0 * z
|
| 716 |
+
),
|
| 717 |
+
0.206472245902897
|
| 718 |
+
* y
|
| 719 |
+
* (
|
| 720 |
+
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 721 |
+
+ 2.16666666666667
|
| 722 |
+
* z
|
| 723 |
+
* (
|
| 724 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 725 |
+
+ 2.2
|
| 726 |
+
* z
|
| 727 |
+
* (
|
| 728 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 729 |
+
+ 9.375 * z2
|
| 730 |
+
- 1.875
|
| 731 |
+
)
|
| 732 |
+
- 4.8 * z
|
| 733 |
+
)
|
| 734 |
+
- 10.9375 * z2
|
| 735 |
+
+ 2.1875
|
| 736 |
+
),
|
| 737 |
+
1.24862677781952 * z * (1.5 * z2 - 0.5)
|
| 738 |
+
- 1.68564615005635
|
| 739 |
+
* z
|
| 740 |
+
* (
|
| 741 |
+
1.75
|
| 742 |
+
* z
|
| 743 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 744 |
+
- 1.125 * z2
|
| 745 |
+
+ 0.375
|
| 746 |
+
)
|
| 747 |
+
+ 2.02901851395672
|
| 748 |
+
* z
|
| 749 |
+
* (
|
| 750 |
+
-1.45833333333333
|
| 751 |
+
* z
|
| 752 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 753 |
+
+ 1.83333333333333
|
| 754 |
+
* z
|
| 755 |
+
* (
|
| 756 |
+
-1.33333333333333 * z * (1.5 * z2 - 0.5)
|
| 757 |
+
+ 1.8
|
| 758 |
+
* z
|
| 759 |
+
* (
|
| 760 |
+
1.75
|
| 761 |
+
* z
|
| 762 |
+
* (
|
| 763 |
+
1.66666666666667 * z * (1.5 * z2 - 0.5)
|
| 764 |
+
- 0.666666666666667 * z
|
| 765 |
+
)
|
| 766 |
+
- 1.125 * z2
|
| 767 |
+
+ 0.375
|
| 768 |
+
)
|
| 769 |
+
+ 0.533333333333333 * z
|
| 770 |
+
)
|
| 771 |
+
+ 0.9375 * z2
|
| 772 |
+
- 0.3125
|
| 773 |
+
)
|
| 774 |
+
- 0.499450711127808 * z,
|
| 775 |
+
0.206472245902897
|
| 776 |
+
* x
|
| 777 |
+
* (
|
| 778 |
+
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 779 |
+
+ 2.16666666666667
|
| 780 |
+
* z
|
| 781 |
+
* (
|
| 782 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 783 |
+
+ 2.2
|
| 784 |
+
* z
|
| 785 |
+
* (
|
| 786 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 787 |
+
+ 9.375 * z2
|
| 788 |
+
- 1.875
|
| 789 |
+
)
|
| 790 |
+
- 4.8 * z
|
| 791 |
+
)
|
| 792 |
+
- 10.9375 * z2
|
| 793 |
+
+ 2.1875
|
| 794 |
+
),
|
| 795 |
+
0.0280973138060306
|
| 796 |
+
* (x2 - y2)
|
| 797 |
+
* (
|
| 798 |
+
-4.8 * z * (52.5 * z2 - 7.5)
|
| 799 |
+
+ 2.6
|
| 800 |
+
* z
|
| 801 |
+
* (
|
| 802 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 803 |
+
- 91.875 * z2
|
| 804 |
+
+ 13.125
|
| 805 |
+
)
|
| 806 |
+
+ 48.0 * z
|
| 807 |
+
),
|
| 808 |
+
0.00397356022507413
|
| 809 |
+
* x
|
| 810 |
+
* (x2 - 3.0 * y2)
|
| 811 |
+
* (
|
| 812 |
+
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
|
| 813 |
+
+ 1063.125 * z2
|
| 814 |
+
- 118.125
|
| 815 |
+
),
|
| 816 |
+
0.000599036743111412
|
| 817 |
+
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
|
| 818 |
+
* (-6.0 * x2 * y2 + x4 + y4),
|
| 819 |
+
9.98394571852353e-5
|
| 820 |
+
* x
|
| 821 |
+
* (5197.5 - 67567.5 * z2)
|
| 822 |
+
* (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 823 |
+
2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
|
| 824 |
+
-0.707162732524596
|
| 825 |
+
* x
|
| 826 |
+
* (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
|
| 827 |
+
],
|
| 828 |
+
-1,
|
| 829 |
+
)
|
| 830 |
+
|
| 831 |
+
|
| 832 |
+
# @torch.jit.script
|
| 833 |
+
def rsh_cart_8(xyz: torch.Tensor):
|
| 834 |
+
"""Computes all real spherical harmonics up to degree 8.
|
| 835 |
+
|
| 836 |
+
This is an autogenerated method. See
|
| 837 |
+
https://github.com/cheind/torch-spherical-harmonics
|
| 838 |
+
for more information.
|
| 839 |
+
|
| 840 |
+
Params:
|
| 841 |
+
xyz: (N,...,3) tensor of points on the unit sphere
|
| 842 |
+
|
| 843 |
+
Returns:
|
| 844 |
+
rsh: (N,...,81) real spherical harmonics
|
| 845 |
+
projections of input. Ynm is found at index
|
| 846 |
+
`n*(n+1) + m`, with `0 <= n <= degree` and
|
| 847 |
+
`-n <= m <= n`.
|
| 848 |
+
"""
|
| 849 |
+
x = xyz[..., 0]
|
| 850 |
+
y = xyz[..., 1]
|
| 851 |
+
z = xyz[..., 2]
|
| 852 |
+
|
| 853 |
+
x2 = x**2
|
| 854 |
+
y2 = y**2
|
| 855 |
+
z2 = z**2
|
| 856 |
+
xy = x * y
|
| 857 |
+
xz = x * z
|
| 858 |
+
yz = y * z
|
| 859 |
+
x4 = x2**2
|
| 860 |
+
y4 = y2**2
|
| 861 |
+
# z4 = z2**2
|
| 862 |
+
return torch.stack(
|
| 863 |
+
[
|
| 864 |
+
0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]),
|
| 865 |
+
-0.48860251190292 * y,
|
| 866 |
+
0.48860251190292 * z,
|
| 867 |
+
-0.48860251190292 * x,
|
| 868 |
+
1.09254843059208 * xy,
|
| 869 |
+
-1.09254843059208 * yz,
|
| 870 |
+
0.94617469575756 * z2 - 0.31539156525252,
|
| 871 |
+
-1.09254843059208 * xz,
|
| 872 |
+
0.54627421529604 * x2 - 0.54627421529604 * y2,
|
| 873 |
+
-0.590043589926644 * y * (3.0 * x2 - y2),
|
| 874 |
+
2.89061144264055 * xy * z,
|
| 875 |
+
0.304697199642977 * y * (1.5 - 7.5 * z2),
|
| 876 |
+
1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
|
| 877 |
+
0.304697199642977 * x * (1.5 - 7.5 * z2),
|
| 878 |
+
1.44530572132028 * z * (x2 - y2),
|
| 879 |
+
-0.590043589926644 * x * (x2 - 3.0 * y2),
|
| 880 |
+
2.5033429417967 * xy * (x2 - y2),
|
| 881 |
+
-1.77013076977993 * yz * (3.0 * x2 - y2),
|
| 882 |
+
0.126156626101008 * xy * (52.5 * z2 - 7.5),
|
| 883 |
+
0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 884 |
+
1.48099765681286
|
| 885 |
+
* z
|
| 886 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 887 |
+
- 0.952069922236839 * z2
|
| 888 |
+
+ 0.317356640745613,
|
| 889 |
+
0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
|
| 890 |
+
0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
|
| 891 |
+
-1.77013076977993 * xz * (x2 - 3.0 * y2),
|
| 892 |
+
-3.75501441269506 * x2 * y2
|
| 893 |
+
+ 0.625835735449176 * x4
|
| 894 |
+
+ 0.625835735449176 * y4,
|
| 895 |
+
-0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 896 |
+
8.30264925952416 * xy * z * (x2 - y2),
|
| 897 |
+
0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
|
| 898 |
+
0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 899 |
+
0.241571547304372
|
| 900 |
+
* y
|
| 901 |
+
* (
|
| 902 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 903 |
+
+ 9.375 * z2
|
| 904 |
+
- 1.875
|
| 905 |
+
),
|
| 906 |
+
-1.24747010616985 * z * (1.5 * z2 - 0.5)
|
| 907 |
+
+ 1.6840846433293
|
| 908 |
+
* z
|
| 909 |
+
* (
|
| 910 |
+
1.75
|
| 911 |
+
* z
|
| 912 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 913 |
+
- 1.125 * z2
|
| 914 |
+
+ 0.375
|
| 915 |
+
)
|
| 916 |
+
+ 0.498988042467941 * z,
|
| 917 |
+
0.241571547304372
|
| 918 |
+
* x
|
| 919 |
+
* (
|
| 920 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 921 |
+
+ 9.375 * z2
|
| 922 |
+
- 1.875
|
| 923 |
+
),
|
| 924 |
+
0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
|
| 925 |
+
0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
|
| 926 |
+
2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
|
| 927 |
+
-0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 928 |
+
4.09910463115149 * x**4 * xy
|
| 929 |
+
- 13.6636821038383 * xy**3
|
| 930 |
+
+ 4.09910463115149 * xy * y**4,
|
| 931 |
+
-2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 932 |
+
0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
|
| 933 |
+
0.00584892228263444
|
| 934 |
+
* y
|
| 935 |
+
* (3.0 * x2 - y2)
|
| 936 |
+
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
|
| 937 |
+
0.0701870673916132
|
| 938 |
+
* xy
|
| 939 |
+
* (
|
| 940 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 941 |
+
- 91.875 * z2
|
| 942 |
+
+ 13.125
|
| 943 |
+
),
|
| 944 |
+
0.221950995245231
|
| 945 |
+
* y
|
| 946 |
+
* (
|
| 947 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 948 |
+
+ 2.2
|
| 949 |
+
* z
|
| 950 |
+
* (
|
| 951 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 952 |
+
+ 9.375 * z2
|
| 953 |
+
- 1.875
|
| 954 |
+
)
|
| 955 |
+
- 4.8 * z
|
| 956 |
+
),
|
| 957 |
+
-1.48328138624466
|
| 958 |
+
* z
|
| 959 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 960 |
+
+ 1.86469659985043
|
| 961 |
+
* z
|
| 962 |
+
* (
|
| 963 |
+
-1.33333333333333 * z * (1.5 * z2 - 0.5)
|
| 964 |
+
+ 1.8
|
| 965 |
+
* z
|
| 966 |
+
* (
|
| 967 |
+
1.75
|
| 968 |
+
* z
|
| 969 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 970 |
+
- 1.125 * z2
|
| 971 |
+
+ 0.375
|
| 972 |
+
)
|
| 973 |
+
+ 0.533333333333333 * z
|
| 974 |
+
)
|
| 975 |
+
+ 0.953538034014426 * z2
|
| 976 |
+
- 0.317846011338142,
|
| 977 |
+
0.221950995245231
|
| 978 |
+
* x
|
| 979 |
+
* (
|
| 980 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 981 |
+
+ 2.2
|
| 982 |
+
* z
|
| 983 |
+
* (
|
| 984 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 985 |
+
+ 9.375 * z2
|
| 986 |
+
- 1.875
|
| 987 |
+
)
|
| 988 |
+
- 4.8 * z
|
| 989 |
+
),
|
| 990 |
+
0.0350935336958066
|
| 991 |
+
* (x2 - y2)
|
| 992 |
+
* (
|
| 993 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 994 |
+
- 91.875 * z2
|
| 995 |
+
+ 13.125
|
| 996 |
+
),
|
| 997 |
+
0.00584892228263444
|
| 998 |
+
* x
|
| 999 |
+
* (x2 - 3.0 * y2)
|
| 1000 |
+
* (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
|
| 1001 |
+
0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
|
| 1002 |
+
-2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 1003 |
+
0.683184105191914 * x2**3
|
| 1004 |
+
+ 10.2477615778787 * x2 * y4
|
| 1005 |
+
- 10.2477615778787 * x4 * y2
|
| 1006 |
+
- 0.683184105191914 * y2**3,
|
| 1007 |
+
-0.707162732524596
|
| 1008 |
+
* y
|
| 1009 |
+
* (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
|
| 1010 |
+
2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
|
| 1011 |
+
9.98394571852353e-5
|
| 1012 |
+
* y
|
| 1013 |
+
* (5197.5 - 67567.5 * z2)
|
| 1014 |
+
* (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 1015 |
+
0.00239614697244565
|
| 1016 |
+
* xy
|
| 1017 |
+
* (x2 - y2)
|
| 1018 |
+
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
|
| 1019 |
+
0.00397356022507413
|
| 1020 |
+
* y
|
| 1021 |
+
* (3.0 * x2 - y2)
|
| 1022 |
+
* (
|
| 1023 |
+
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
|
| 1024 |
+
+ 1063.125 * z2
|
| 1025 |
+
- 118.125
|
| 1026 |
+
),
|
| 1027 |
+
0.0561946276120613
|
| 1028 |
+
* xy
|
| 1029 |
+
* (
|
| 1030 |
+
-4.8 * z * (52.5 * z2 - 7.5)
|
| 1031 |
+
+ 2.6
|
| 1032 |
+
* z
|
| 1033 |
+
* (
|
| 1034 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 1035 |
+
- 91.875 * z2
|
| 1036 |
+
+ 13.125
|
| 1037 |
+
)
|
| 1038 |
+
+ 48.0 * z
|
| 1039 |
+
),
|
| 1040 |
+
0.206472245902897
|
| 1041 |
+
* y
|
| 1042 |
+
* (
|
| 1043 |
+
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1044 |
+
+ 2.16666666666667
|
| 1045 |
+
* z
|
| 1046 |
+
* (
|
| 1047 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 1048 |
+
+ 2.2
|
| 1049 |
+
* z
|
| 1050 |
+
* (
|
| 1051 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1052 |
+
+ 9.375 * z2
|
| 1053 |
+
- 1.875
|
| 1054 |
+
)
|
| 1055 |
+
- 4.8 * z
|
| 1056 |
+
)
|
| 1057 |
+
- 10.9375 * z2
|
| 1058 |
+
+ 2.1875
|
| 1059 |
+
),
|
| 1060 |
+
1.24862677781952 * z * (1.5 * z2 - 0.5)
|
| 1061 |
+
- 1.68564615005635
|
| 1062 |
+
* z
|
| 1063 |
+
* (
|
| 1064 |
+
1.75
|
| 1065 |
+
* z
|
| 1066 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 1067 |
+
- 1.125 * z2
|
| 1068 |
+
+ 0.375
|
| 1069 |
+
)
|
| 1070 |
+
+ 2.02901851395672
|
| 1071 |
+
* z
|
| 1072 |
+
* (
|
| 1073 |
+
-1.45833333333333
|
| 1074 |
+
* z
|
| 1075 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 1076 |
+
+ 1.83333333333333
|
| 1077 |
+
* z
|
| 1078 |
+
* (
|
| 1079 |
+
-1.33333333333333 * z * (1.5 * z2 - 0.5)
|
| 1080 |
+
+ 1.8
|
| 1081 |
+
* z
|
| 1082 |
+
* (
|
| 1083 |
+
1.75
|
| 1084 |
+
* z
|
| 1085 |
+
* (
|
| 1086 |
+
1.66666666666667 * z * (1.5 * z2 - 0.5)
|
| 1087 |
+
- 0.666666666666667 * z
|
| 1088 |
+
)
|
| 1089 |
+
- 1.125 * z2
|
| 1090 |
+
+ 0.375
|
| 1091 |
+
)
|
| 1092 |
+
+ 0.533333333333333 * z
|
| 1093 |
+
)
|
| 1094 |
+
+ 0.9375 * z2
|
| 1095 |
+
- 0.3125
|
| 1096 |
+
)
|
| 1097 |
+
- 0.499450711127808 * z,
|
| 1098 |
+
0.206472245902897
|
| 1099 |
+
* x
|
| 1100 |
+
* (
|
| 1101 |
+
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1102 |
+
+ 2.16666666666667
|
| 1103 |
+
* z
|
| 1104 |
+
* (
|
| 1105 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 1106 |
+
+ 2.2
|
| 1107 |
+
* z
|
| 1108 |
+
* (
|
| 1109 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1110 |
+
+ 9.375 * z2
|
| 1111 |
+
- 1.875
|
| 1112 |
+
)
|
| 1113 |
+
- 4.8 * z
|
| 1114 |
+
)
|
| 1115 |
+
- 10.9375 * z2
|
| 1116 |
+
+ 2.1875
|
| 1117 |
+
),
|
| 1118 |
+
0.0280973138060306
|
| 1119 |
+
* (x2 - y2)
|
| 1120 |
+
* (
|
| 1121 |
+
-4.8 * z * (52.5 * z2 - 7.5)
|
| 1122 |
+
+ 2.6
|
| 1123 |
+
* z
|
| 1124 |
+
* (
|
| 1125 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 1126 |
+
- 91.875 * z2
|
| 1127 |
+
+ 13.125
|
| 1128 |
+
)
|
| 1129 |
+
+ 48.0 * z
|
| 1130 |
+
),
|
| 1131 |
+
0.00397356022507413
|
| 1132 |
+
* x
|
| 1133 |
+
* (x2 - 3.0 * y2)
|
| 1134 |
+
* (
|
| 1135 |
+
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
|
| 1136 |
+
+ 1063.125 * z2
|
| 1137 |
+
- 118.125
|
| 1138 |
+
),
|
| 1139 |
+
0.000599036743111412
|
| 1140 |
+
* (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
|
| 1141 |
+
* (-6.0 * x2 * y2 + x4 + y4),
|
| 1142 |
+
9.98394571852353e-5
|
| 1143 |
+
* x
|
| 1144 |
+
* (5197.5 - 67567.5 * z2)
|
| 1145 |
+
* (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 1146 |
+
2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
|
| 1147 |
+
-0.707162732524596
|
| 1148 |
+
* x
|
| 1149 |
+
* (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
|
| 1150 |
+
5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
|
| 1151 |
+
-2.91570664069932
|
| 1152 |
+
* yz
|
| 1153 |
+
* (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
|
| 1154 |
+
7.87853281621404e-6
|
| 1155 |
+
* (1013512.5 * z2 - 67567.5)
|
| 1156 |
+
* (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
|
| 1157 |
+
5.10587282657803e-5
|
| 1158 |
+
* y
|
| 1159 |
+
* (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
|
| 1160 |
+
* (-10.0 * x2 * y2 + 5.0 * x4 + y4),
|
| 1161 |
+
0.00147275890257803
|
| 1162 |
+
* xy
|
| 1163 |
+
* (x2 - y2)
|
| 1164 |
+
* (
|
| 1165 |
+
3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
|
| 1166 |
+
- 14293.125 * z2
|
| 1167 |
+
+ 1299.375
|
| 1168 |
+
),
|
| 1169 |
+
0.0028519853513317
|
| 1170 |
+
* y
|
| 1171 |
+
* (3.0 * x2 - y2)
|
| 1172 |
+
* (
|
| 1173 |
+
-7.33333333333333 * z * (52.5 - 472.5 * z2)
|
| 1174 |
+
+ 3.0
|
| 1175 |
+
* z
|
| 1176 |
+
* (
|
| 1177 |
+
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
|
| 1178 |
+
+ 1063.125 * z2
|
| 1179 |
+
- 118.125
|
| 1180 |
+
)
|
| 1181 |
+
- 560.0 * z
|
| 1182 |
+
),
|
| 1183 |
+
0.0463392770473559
|
| 1184 |
+
* xy
|
| 1185 |
+
* (
|
| 1186 |
+
-4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 1187 |
+
+ 2.5
|
| 1188 |
+
* z
|
| 1189 |
+
* (
|
| 1190 |
+
-4.8 * z * (52.5 * z2 - 7.5)
|
| 1191 |
+
+ 2.6
|
| 1192 |
+
* z
|
| 1193 |
+
* (
|
| 1194 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 1195 |
+
- 91.875 * z2
|
| 1196 |
+
+ 13.125
|
| 1197 |
+
)
|
| 1198 |
+
+ 48.0 * z
|
| 1199 |
+
)
|
| 1200 |
+
+ 137.8125 * z2
|
| 1201 |
+
- 19.6875
|
| 1202 |
+
),
|
| 1203 |
+
0.193851103820053
|
| 1204 |
+
* y
|
| 1205 |
+
* (
|
| 1206 |
+
3.2 * z * (1.5 - 7.5 * z2)
|
| 1207 |
+
- 2.51428571428571
|
| 1208 |
+
* z
|
| 1209 |
+
* (
|
| 1210 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1211 |
+
+ 9.375 * z2
|
| 1212 |
+
- 1.875
|
| 1213 |
+
)
|
| 1214 |
+
+ 2.14285714285714
|
| 1215 |
+
* z
|
| 1216 |
+
* (
|
| 1217 |
+
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1218 |
+
+ 2.16666666666667
|
| 1219 |
+
* z
|
| 1220 |
+
* (
|
| 1221 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 1222 |
+
+ 2.2
|
| 1223 |
+
* z
|
| 1224 |
+
* (
|
| 1225 |
+
2.25
|
| 1226 |
+
* z
|
| 1227 |
+
* (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1228 |
+
+ 9.375 * z2
|
| 1229 |
+
- 1.875
|
| 1230 |
+
)
|
| 1231 |
+
- 4.8 * z
|
| 1232 |
+
)
|
| 1233 |
+
- 10.9375 * z2
|
| 1234 |
+
+ 2.1875
|
| 1235 |
+
)
|
| 1236 |
+
+ 5.48571428571429 * z
|
| 1237 |
+
),
|
| 1238 |
+
1.48417251362228
|
| 1239 |
+
* z
|
| 1240 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 1241 |
+
- 1.86581687426801
|
| 1242 |
+
* z
|
| 1243 |
+
* (
|
| 1244 |
+
-1.33333333333333 * z * (1.5 * z2 - 0.5)
|
| 1245 |
+
+ 1.8
|
| 1246 |
+
* z
|
| 1247 |
+
* (
|
| 1248 |
+
1.75
|
| 1249 |
+
* z
|
| 1250 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 1251 |
+
- 1.125 * z2
|
| 1252 |
+
+ 0.375
|
| 1253 |
+
)
|
| 1254 |
+
+ 0.533333333333333 * z
|
| 1255 |
+
)
|
| 1256 |
+
+ 2.1808249179756
|
| 1257 |
+
* z
|
| 1258 |
+
* (
|
| 1259 |
+
1.14285714285714 * z * (1.5 * z2 - 0.5)
|
| 1260 |
+
- 1.54285714285714
|
| 1261 |
+
* z
|
| 1262 |
+
* (
|
| 1263 |
+
1.75
|
| 1264 |
+
* z
|
| 1265 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 1266 |
+
- 1.125 * z2
|
| 1267 |
+
+ 0.375
|
| 1268 |
+
)
|
| 1269 |
+
+ 1.85714285714286
|
| 1270 |
+
* z
|
| 1271 |
+
* (
|
| 1272 |
+
-1.45833333333333
|
| 1273 |
+
* z
|
| 1274 |
+
* (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
|
| 1275 |
+
+ 1.83333333333333
|
| 1276 |
+
* z
|
| 1277 |
+
* (
|
| 1278 |
+
-1.33333333333333 * z * (1.5 * z2 - 0.5)
|
| 1279 |
+
+ 1.8
|
| 1280 |
+
* z
|
| 1281 |
+
* (
|
| 1282 |
+
1.75
|
| 1283 |
+
* z
|
| 1284 |
+
* (
|
| 1285 |
+
1.66666666666667 * z * (1.5 * z2 - 0.5)
|
| 1286 |
+
- 0.666666666666667 * z
|
| 1287 |
+
)
|
| 1288 |
+
- 1.125 * z2
|
| 1289 |
+
+ 0.375
|
| 1290 |
+
)
|
| 1291 |
+
+ 0.533333333333333 * z
|
| 1292 |
+
)
|
| 1293 |
+
+ 0.9375 * z2
|
| 1294 |
+
- 0.3125
|
| 1295 |
+
)
|
| 1296 |
+
- 0.457142857142857 * z
|
| 1297 |
+
)
|
| 1298 |
+
- 0.954110901614325 * z2
|
| 1299 |
+
+ 0.318036967204775,
|
| 1300 |
+
0.193851103820053
|
| 1301 |
+
* x
|
| 1302 |
+
* (
|
| 1303 |
+
3.2 * z * (1.5 - 7.5 * z2)
|
| 1304 |
+
- 2.51428571428571
|
| 1305 |
+
* z
|
| 1306 |
+
* (
|
| 1307 |
+
2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1308 |
+
+ 9.375 * z2
|
| 1309 |
+
- 1.875
|
| 1310 |
+
)
|
| 1311 |
+
+ 2.14285714285714
|
| 1312 |
+
* z
|
| 1313 |
+
* (
|
| 1314 |
+
-2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1315 |
+
+ 2.16666666666667
|
| 1316 |
+
* z
|
| 1317 |
+
* (
|
| 1318 |
+
-2.8 * z * (1.5 - 7.5 * z2)
|
| 1319 |
+
+ 2.2
|
| 1320 |
+
* z
|
| 1321 |
+
* (
|
| 1322 |
+
2.25
|
| 1323 |
+
* z
|
| 1324 |
+
* (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
|
| 1325 |
+
+ 9.375 * z2
|
| 1326 |
+
- 1.875
|
| 1327 |
+
)
|
| 1328 |
+
- 4.8 * z
|
| 1329 |
+
)
|
| 1330 |
+
- 10.9375 * z2
|
| 1331 |
+
+ 2.1875
|
| 1332 |
+
)
|
| 1333 |
+
+ 5.48571428571429 * z
|
| 1334 |
+
),
|
| 1335 |
+
0.0231696385236779
|
| 1336 |
+
* (x2 - y2)
|
| 1337 |
+
* (
|
| 1338 |
+
-4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 1339 |
+
+ 2.5
|
| 1340 |
+
* z
|
| 1341 |
+
* (
|
| 1342 |
+
-4.8 * z * (52.5 * z2 - 7.5)
|
| 1343 |
+
+ 2.6
|
| 1344 |
+
* z
|
| 1345 |
+
* (
|
| 1346 |
+
2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
|
| 1347 |
+
- 91.875 * z2
|
| 1348 |
+
+ 13.125
|
| 1349 |
+
)
|
| 1350 |
+
+ 48.0 * z
|
| 1351 |
+
)
|
| 1352 |
+
+ 137.8125 * z2
|
| 1353 |
+
- 19.6875
|
| 1354 |
+
),
|
| 1355 |
+
0.0028519853513317
|
| 1356 |
+
* x
|
| 1357 |
+
* (x2 - 3.0 * y2)
|
| 1358 |
+
* (
|
| 1359 |
+
-7.33333333333333 * z * (52.5 - 472.5 * z2)
|
| 1360 |
+
+ 3.0
|
| 1361 |
+
* z
|
| 1362 |
+
* (
|
| 1363 |
+
3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
|
| 1364 |
+
+ 1063.125 * z2
|
| 1365 |
+
- 118.125
|
| 1366 |
+
)
|
| 1367 |
+
- 560.0 * z
|
| 1368 |
+
),
|
| 1369 |
+
0.000368189725644507
|
| 1370 |
+
* (-6.0 * x2 * y2 + x4 + y4)
|
| 1371 |
+
* (
|
| 1372 |
+
3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
|
| 1373 |
+
- 14293.125 * z2
|
| 1374 |
+
+ 1299.375
|
| 1375 |
+
),
|
| 1376 |
+
5.10587282657803e-5
|
| 1377 |
+
* x
|
| 1378 |
+
* (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
|
| 1379 |
+
* (-10.0 * x2 * y2 + x4 + 5.0 * y4),
|
| 1380 |
+
7.87853281621404e-6
|
| 1381 |
+
* (1013512.5 * z2 - 67567.5)
|
| 1382 |
+
* (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
|
| 1383 |
+
-2.91570664069932
|
| 1384 |
+
* xz
|
| 1385 |
+
* (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
|
| 1386 |
+
-20.4099464848952 * x2**3 * y2
|
| 1387 |
+
- 20.4099464848952 * x2 * y2**3
|
| 1388 |
+
+ 0.72892666017483 * x4**2
|
| 1389 |
+
+ 51.0248662122381 * x4 * y4
|
| 1390 |
+
+ 0.72892666017483 * y4**2,
|
| 1391 |
+
],
|
| 1392 |
+
-1,
|
| 1393 |
+
)
|
| 1394 |
+
|
| 1395 |
+
|
| 1396 |
+
__all__ = [
|
| 1397 |
+
"rsh_cart_0",
|
| 1398 |
+
"rsh_cart_1",
|
| 1399 |
+
"rsh_cart_2",
|
| 1400 |
+
"rsh_cart_3",
|
| 1401 |
+
"rsh_cart_4",
|
| 1402 |
+
"rsh_cart_5",
|
| 1403 |
+
"rsh_cart_6",
|
| 1404 |
+
"rsh_cart_7",
|
| 1405 |
+
"rsh_cart_8",
|
| 1406 |
+
]
|
| 1407 |
+
|
| 1408 |
+
|
| 1409 |
+
from typing import Optional
|
| 1410 |
+
import torch
|
| 1411 |
+
|
| 1412 |
+
|
| 1413 |
+
class SphHarm(torch.nn.Module):
|
| 1414 |
+
def __init__(self, m, n, dtype=torch.float32) -> None:
|
| 1415 |
+
super().__init__()
|
| 1416 |
+
self.dtype = dtype
|
| 1417 |
+
m = torch.tensor(list(range(-m + 1, m)))
|
| 1418 |
+
n = torch.tensor(list(range(n)))
|
| 1419 |
+
self.is_normalized = False
|
| 1420 |
+
vals = torch.cartesian_prod(m, n).T
|
| 1421 |
+
vals = vals[:, vals[0] <= vals[1]]
|
| 1422 |
+
m, n = vals.unbind(0)
|
| 1423 |
+
|
| 1424 |
+
self.register_buffer("m", tensor=m)
|
| 1425 |
+
self.register_buffer("n", tensor=n)
|
| 1426 |
+
self.register_buffer("l_max", tensor=torch.max(self.n))
|
| 1427 |
+
|
| 1428 |
+
f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
|
| 1429 |
+
self.register_buffer("f_a", tensor=f_a)
|
| 1430 |
+
self.register_buffer("f_b", tensor=f_b)
|
| 1431 |
+
self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
|
| 1432 |
+
self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
|
| 1433 |
+
self.register_buffer("initial_value", tensor=initial_value)
|
| 1434 |
+
|
| 1435 |
+
@property
|
| 1436 |
+
def device(self):
|
| 1437 |
+
return next(self.buffers()).device
|
| 1438 |
+
|
| 1439 |
+
def forward(self, points: torch.Tensor) -> torch.Tensor:
|
| 1440 |
+
"""Computes the spherical harmonics."""
|
| 1441 |
+
# Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
|
| 1442 |
+
B, N, D = points.shape
|
| 1443 |
+
dtype = points.dtype
|
| 1444 |
+
theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
|
| 1445 |
+
cos_colatitude = torch.cos(phi)
|
| 1446 |
+
legendre = self._gen_associated_legendre(cos_colatitude)
|
| 1447 |
+
vals = torch.stack([self.m.abs(), self.n], dim=0)
|
| 1448 |
+
vals = torch.cat(
|
| 1449 |
+
[
|
| 1450 |
+
vals.repeat(1, theta.shape[0]),
|
| 1451 |
+
torch.arange(theta.shape[0], device=theta.device)
|
| 1452 |
+
.unsqueeze(0)
|
| 1453 |
+
.repeat_interleave(vals.shape[1], dim=1),
|
| 1454 |
+
],
|
| 1455 |
+
dim=0,
|
| 1456 |
+
)
|
| 1457 |
+
legendre_vals = legendre[vals[0], vals[1], vals[2]]
|
| 1458 |
+
legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
|
| 1459 |
+
angle = torch.outer(self.m.abs(), theta)
|
| 1460 |
+
vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
|
| 1461 |
+
harmonics = torch.complex(
|
| 1462 |
+
legendre_vals * torch.real(vandermonde),
|
| 1463 |
+
legendre_vals * torch.imag(vandermonde),
|
| 1464 |
+
)
|
| 1465 |
+
|
| 1466 |
+
# Negative order.
|
| 1467 |
+
m = self.m.unsqueeze(-1)
|
| 1468 |
+
harmonics = torch.where(
|
| 1469 |
+
m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
|
| 1470 |
+
)
|
| 1471 |
+
harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
|
| 1472 |
+
return harmonics
|
| 1473 |
+
|
| 1474 |
+
def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
|
| 1475 |
+
"""Generates mask for recurrence relation on the remaining entries.
|
| 1476 |
+
|
| 1477 |
+
The remaining entries are with respect to the diagonal and offdiagonal
|
| 1478 |
+
entries.
|
| 1479 |
+
|
| 1480 |
+
Args:
|
| 1481 |
+
l_max: see `gen_normalized_legendre`.
|
| 1482 |
+
Returns:
|
| 1483 |
+
torch.Tensors representing the mask used by the recurrence relations.
|
| 1484 |
+
"""
|
| 1485 |
+
|
| 1486 |
+
# Computes all coefficients.
|
| 1487 |
+
m_mat, l_mat = torch.meshgrid(
|
| 1488 |
+
torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
|
| 1489 |
+
torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
|
| 1490 |
+
indexing="ij",
|
| 1491 |
+
)
|
| 1492 |
+
if self.is_normalized:
|
| 1493 |
+
c0 = l_mat * l_mat
|
| 1494 |
+
c1 = m_mat * m_mat
|
| 1495 |
+
c2 = 2.0 * l_mat
|
| 1496 |
+
c3 = (l_mat - 1.0) * (l_mat - 1.0)
|
| 1497 |
+
d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
|
| 1498 |
+
d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
|
| 1499 |
+
else:
|
| 1500 |
+
d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
|
| 1501 |
+
d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
|
| 1502 |
+
|
| 1503 |
+
d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
|
| 1504 |
+
d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
|
| 1505 |
+
|
| 1506 |
+
d_zeros = torch.zeros(
|
| 1507 |
+
(self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
|
| 1508 |
+
)
|
| 1509 |
+
d_zeros[d0_mask_indices] = d0[d0_mask_indices]
|
| 1510 |
+
d0_mask = d_zeros
|
| 1511 |
+
|
| 1512 |
+
d_zeros = torch.zeros(
|
| 1513 |
+
(self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
|
| 1514 |
+
)
|
| 1515 |
+
d_zeros[d1_mask_indices] = d1[d1_mask_indices]
|
| 1516 |
+
d1_mask = d_zeros
|
| 1517 |
+
|
| 1518 |
+
# Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
|
| 1519 |
+
i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
|
| 1520 |
+
j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
|
| 1521 |
+
k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
|
| 1522 |
+
mask = (i + j - k == 0).to(self.dtype)
|
| 1523 |
+
d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
|
| 1524 |
+
d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
|
| 1525 |
+
return (d0_mask_3d, d1_mask_3d)
|
| 1526 |
+
|
| 1527 |
+
def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
|
| 1528 |
+
coeff_0 = self.d0_mask_3d[i]
|
| 1529 |
+
coeff_1 = self.d1_mask_3d[i]
|
| 1530 |
+
h = torch.einsum(
|
| 1531 |
+
"ij,ijk->ijk",
|
| 1532 |
+
coeff_0,
|
| 1533 |
+
torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
|
| 1534 |
+
) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
|
| 1535 |
+
p_val = p_val + h
|
| 1536 |
+
return p_val
|
| 1537 |
+
|
| 1538 |
+
def _init_legendre(self):
|
| 1539 |
+
a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
|
| 1540 |
+
b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
|
| 1541 |
+
if self.is_normalized:
|
| 1542 |
+
# The initial value p(0,0).
|
| 1543 |
+
initial_value: torch.Tensor = torch.tensor(
|
| 1544 |
+
0.5 / (torch.pi**0.5), device=self.device
|
| 1545 |
+
)
|
| 1546 |
+
f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
|
| 1547 |
+
f_b = torch.sqrt(2.0 * b_idx + 3.0)
|
| 1548 |
+
else:
|
| 1549 |
+
# The initial value p(0,0).
|
| 1550 |
+
initial_value = torch.tensor(1.0, device=self.device)
|
| 1551 |
+
f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
|
| 1552 |
+
f_b = 2.0 * b_idx + 1.0
|
| 1553 |
+
|
| 1554 |
+
d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
|
| 1555 |
+
return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
|
| 1556 |
+
|
| 1557 |
+
def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
|
| 1558 |
+
r"""Computes associated Legendre functions (ALFs) of the first kind.
|
| 1559 |
+
|
| 1560 |
+
The ALFs of the first kind are used in spherical harmonics. The spherical
|
| 1561 |
+
harmonic of degree `l` and order `m` can be written as
|
| 1562 |
+
`Y_l^m(ΞΈ, Ο) = N_l^m * P_l^m(cos(ΞΈ)) * exp(i m Ο)`, where `N_l^m` is the
|
| 1563 |
+
normalization factor and ΞΈ and Ο are the colatitude and longitude,
|
| 1564 |
+
repectively. `N_l^m` is chosen in the way that the spherical harmonics form
|
| 1565 |
+
a set of orthonormal basis function of L^2(S^2). For the computational
|
| 1566 |
+
efficiency of spherical harmonics transform, the normalization factor is
|
| 1567 |
+
used in the computation of the ALFs. In addition, normalizing `P_l^m`
|
| 1568 |
+
avoids overflow/underflow and achieves better numerical stability. Three
|
| 1569 |
+
recurrence relations are used in the computation.
|
| 1570 |
+
|
| 1571 |
+
Args:
|
| 1572 |
+
l_max: The maximum degree of the associated Legendre function. Both the
|
| 1573 |
+
degrees and orders are `[0, 1, 2, ..., l_max]`.
|
| 1574 |
+
x: A vector of type `float32`, `float64` containing the sampled points in
|
| 1575 |
+
spherical coordinates, at which the ALFs are computed; `x` is essentially
|
| 1576 |
+
`cos(ΞΈ)`. For the numerical integration used by the spherical harmonics
|
| 1577 |
+
transforms, `x` contains the quadrature points in the interval of
|
| 1578 |
+
`[-1, 1]`. There are several approaches to provide the quadrature points:
|
| 1579 |
+
Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
|
| 1580 |
+
method (`scipy.special.roots_chebyu`), and Driscoll & Healy
|
| 1581 |
+
method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
|
| 1582 |
+
transforms and convolutions on the 2-sphere." Advances in applied
|
| 1583 |
+
mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
|
| 1584 |
+
points are nearly equal-spaced along ΞΈ and provide exact discrete
|
| 1585 |
+
orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose
|
| 1586 |
+
operation, `W` is a diagonal matrix containing the quadrature weights,
|
| 1587 |
+
and `I` is the identity matrix. The Gauss-Chebyshev points are equally
|
| 1588 |
+
spaced, which only provide approximate discrete orthogonality. The
|
| 1589 |
+
Driscoll & Healy qudarture points are equally spaced and provide the
|
| 1590 |
+
exact discrete orthogonality. The number of sampling points is required to
|
| 1591 |
+
be twice as the number of frequency points (modes) in the Driscoll & Healy
|
| 1592 |
+
approach, which enables FFT and achieves a fast spherical harmonics
|
| 1593 |
+
transform.
|
| 1594 |
+
is_normalized: True if the associated Legendre functions are normalized.
|
| 1595 |
+
With normalization, `N_l^m` is applied such that the spherical harmonics
|
| 1596 |
+
form a set of orthonormal basis functions of L^2(S^2).
|
| 1597 |
+
|
| 1598 |
+
Returns:
|
| 1599 |
+
The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
|
| 1600 |
+
of the ALFs at `x`; the dimensions in the sequence of order, degree, and
|
| 1601 |
+
evalution points.
|
| 1602 |
+
"""
|
| 1603 |
+
p = torch.zeros(
|
| 1604 |
+
(self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
|
| 1605 |
+
)
|
| 1606 |
+
p[0, 0] = self.initial_value
|
| 1607 |
+
|
| 1608 |
+
# Compute the diagonal entries p(l,l) with recurrence.
|
| 1609 |
+
y = torch.cumprod(
|
| 1610 |
+
torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
|
| 1611 |
+
)
|
| 1612 |
+
p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
|
| 1613 |
+
# torch.diag_indices(l_max + 1)
|
| 1614 |
+
diag_indices = torch.stack(
|
| 1615 |
+
[torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
|
| 1616 |
+
)
|
| 1617 |
+
p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag
|
| 1618 |
+
|
| 1619 |
+
diag_indices = torch.stack(
|
| 1620 |
+
[torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
|
| 1621 |
+
)
|
| 1622 |
+
|
| 1623 |
+
# Compute the off-diagonal entries with recurrence.
|
| 1624 |
+
p_offdiag = torch.einsum(
|
| 1625 |
+
"ij,ij->ij",
|
| 1626 |
+
torch.einsum("i,j->ij", self.f_b, x),
|
| 1627 |
+
p[(diag_indices[0], diag_indices[1])],
|
| 1628 |
+
) # p[torch.diag_indices(l_max)])
|
| 1629 |
+
p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
|
| 1630 |
+
p_offdiag
|
| 1631 |
+
)
|
| 1632 |
+
|
| 1633 |
+
# Compute the remaining entries with recurrence.
|
| 1634 |
+
if self.l_max > 1:
|
| 1635 |
+
for i in range(2, self.l_max + 1):
|
| 1636 |
+
p = self._recursive(i, p, x)
|
| 1637 |
+
return p
|
flash3d/unidepth/utils/visualization.py
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Author: Luigi Piccinelli
|
| 3 |
+
Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
from PIL import Image
|
| 10 |
+
import matplotlib.cm
|
| 11 |
+
import wandb
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from unidepth.utils.misc import ssi_helper
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def colorize(
|
| 18 |
+
value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r"
|
| 19 |
+
):
|
| 20 |
+
# if already RGB, do nothing
|
| 21 |
+
if value.ndim > 2:
|
| 22 |
+
if value.shape[-1] > 1:
|
| 23 |
+
return value
|
| 24 |
+
value = value[..., 0]
|
| 25 |
+
invalid_mask = value < 0.0001
|
| 26 |
+
# normalize
|
| 27 |
+
vmin = value.min() if vmin is None else vmin
|
| 28 |
+
vmax = value.max() if vmax is None else vmax
|
| 29 |
+
value = (value - vmin) / (vmax - vmin) # vmin..vmax
|
| 30 |
+
|
| 31 |
+
# set color
|
| 32 |
+
cmapper = matplotlib.cm.get_cmap(cmap)
|
| 33 |
+
value = cmapper(value, bytes=True) # (nxmx4)
|
| 34 |
+
value[invalid_mask] = 0
|
| 35 |
+
img = value[..., :3]
|
| 36 |
+
return img
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray:
|
| 40 |
+
if not len(imgs):
|
| 41 |
+
return None
|
| 42 |
+
assert len(imgs) == rows * cols
|
| 43 |
+
h, w = imgs[0].shape[:2]
|
| 44 |
+
grid = Image.new("RGB", size=(cols * w, rows * h))
|
| 45 |
+
|
| 46 |
+
for i, img in enumerate(imgs):
|
| 47 |
+
grid.paste(
|
| 48 |
+
Image.fromarray(img.astype(np.uint8)).resize(
|
| 49 |
+
(w, h), resample=Image.BILINEAR
|
| 50 |
+
),
|
| 51 |
+
box=(i % cols * w, i // cols * h),
|
| 52 |
+
)
|
| 53 |
+
|
| 54 |
+
return np.array(grid)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def get_pointcloud_from_rgbd(
|
| 58 |
+
image: np.array,
|
| 59 |
+
depth: np.array,
|
| 60 |
+
mask: np.ndarray,
|
| 61 |
+
intrinsic_matrix: np.array,
|
| 62 |
+
extrinsic_matrix: np.array = None,
|
| 63 |
+
):
|
| 64 |
+
depth = np.array(depth).squeeze()
|
| 65 |
+
mask = np.array(mask).squeeze()
|
| 66 |
+
# Mask the depth array
|
| 67 |
+
masked_depth = np.ma.masked_where(mask == False, depth)
|
| 68 |
+
# masked_depth = np.ma.masked_greater(masked_depth, 8000)
|
| 69 |
+
# Create idx array
|
| 70 |
+
idxs = np.indices(masked_depth.shape)
|
| 71 |
+
u_idxs = idxs[1]
|
| 72 |
+
v_idxs = idxs[0]
|
| 73 |
+
# Get only non-masked depth and idxs
|
| 74 |
+
z = masked_depth[~masked_depth.mask]
|
| 75 |
+
compressed_u_idxs = u_idxs[~masked_depth.mask]
|
| 76 |
+
compressed_v_idxs = v_idxs[~masked_depth.mask]
|
| 77 |
+
image = np.stack(
|
| 78 |
+
[image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1
|
| 79 |
+
)
|
| 80 |
+
|
| 81 |
+
# Calculate local position of each point
|
| 82 |
+
# Apply vectorized math to depth using compressed arrays
|
| 83 |
+
cx = intrinsic_matrix[0, 2]
|
| 84 |
+
fx = intrinsic_matrix[0, 0]
|
| 85 |
+
x = (compressed_u_idxs - cx) * z / fx
|
| 86 |
+
cy = intrinsic_matrix[1, 2]
|
| 87 |
+
fy = intrinsic_matrix[1, 1]
|
| 88 |
+
# Flip y as we want +y pointing up not down
|
| 89 |
+
y = -((compressed_v_idxs - cy) * z / fy)
|
| 90 |
+
|
| 91 |
+
# # Apply camera_matrix to pointcloud as to get the pointcloud in world coords
|
| 92 |
+
# if extrinsic_matrix is not None:
|
| 93 |
+
# # Calculate camera pose from extrinsic matrix
|
| 94 |
+
# camera_matrix = np.linalg.inv(extrinsic_matrix)
|
| 95 |
+
# # Create homogenous array of vectors by adding 4th entry of 1
|
| 96 |
+
# # At the same time flip z as for eye space the camera is looking down the -z axis
|
| 97 |
+
# w = np.ones(z.shape)
|
| 98 |
+
# x_y_z_eye_hom = np.vstack((x, y, -z, w))
|
| 99 |
+
# # Transform the points from eye space to world space
|
| 100 |
+
# x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3]
|
| 101 |
+
# return x_y_z_world.T
|
| 102 |
+
# else:
|
| 103 |
+
x_y_z_local = np.stack((x, y, z), axis=-1)
|
| 104 |
+
return np.concatenate([x_y_z_local, image], axis=-1)
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def save_file_ply(xyz, rgb, pc_file):
|
| 108 |
+
if rgb.max() < 1.001:
|
| 109 |
+
rgb = rgb * 255.0
|
| 110 |
+
rgb = rgb.astype(np.uint8)
|
| 111 |
+
# print(rgb)
|
| 112 |
+
with open(pc_file, "w") as f:
|
| 113 |
+
# headers
|
| 114 |
+
f.writelines(
|
| 115 |
+
[
|
| 116 |
+
"ply\n" "format ascii 1.0\n",
|
| 117 |
+
"element vertex {}\n".format(xyz.shape[0]),
|
| 118 |
+
"property float x\n",
|
| 119 |
+
"property float y\n",
|
| 120 |
+
"property float z\n",
|
| 121 |
+
"property uchar red\n",
|
| 122 |
+
"property uchar green\n",
|
| 123 |
+
"property uchar blue\n",
|
| 124 |
+
"end_header\n",
|
| 125 |
+
]
|
| 126 |
+
)
|
| 127 |
+
|
| 128 |
+
for i in range(xyz.shape[0]):
|
| 129 |
+
str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format(
|
| 130 |
+
xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2]
|
| 131 |
+
)
|
| 132 |
+
f.write(str_v)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
# really awful fct... FIXME
|
| 136 |
+
def log_train_artifacts(rgbs, gts, preds, ds_name, step, infos={}):
|
| 137 |
+
rgbs = [
|
| 138 |
+
(127.5 * (rgb + 1))
|
| 139 |
+
.clip(0, 255)
|
| 140 |
+
.to(torch.uint8)
|
| 141 |
+
.cpu()
|
| 142 |
+
.detach()
|
| 143 |
+
.permute(1, 2, 0)
|
| 144 |
+
.numpy()
|
| 145 |
+
for rgb in rgbs
|
| 146 |
+
]
|
| 147 |
+
|
| 148 |
+
new_gts, new_preds = [], []
|
| 149 |
+
if len(gts) > 0:
|
| 150 |
+
for i, gt in enumerate(gts):
|
| 151 |
+
scale, shift = ssi_helper(
|
| 152 |
+
gts[i][gts[i] > 0].cpu().detach(), preds[i][gts[i] > 0].cpu().detach()
|
| 153 |
+
)
|
| 154 |
+
gt = gts[i].cpu().detach().squeeze().numpy()
|
| 155 |
+
pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy()
|
| 156 |
+
vmin = gt[gt > 0].min() if (gt > 0).any() else 0.0
|
| 157 |
+
vmax = gt.max() if (gt > 0).any() else 0.1
|
| 158 |
+
new_gts.append(colorize(gt, vmin=vmin, vmax=vmax))
|
| 159 |
+
new_preds.append(colorize(pred, vmin=vmin, vmax=vmax))
|
| 160 |
+
gts, preds = new_gts, new_preds
|
| 161 |
+
else:
|
| 162 |
+
preds = [
|
| 163 |
+
colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0)
|
| 164 |
+
for i, pred in enumerate(preds)
|
| 165 |
+
]
|
| 166 |
+
|
| 167 |
+
num_additional, additionals = 0, []
|
| 168 |
+
for name, info in infos.items():
|
| 169 |
+
num_additional += 1
|
| 170 |
+
if info.shape[1] == 3:
|
| 171 |
+
additionals.extend(
|
| 172 |
+
[
|
| 173 |
+
(127.5 * (x + 1))
|
| 174 |
+
.clip(0, 255)
|
| 175 |
+
.to(torch.uint8)
|
| 176 |
+
.cpu()
|
| 177 |
+
.detach()
|
| 178 |
+
.permute(1, 2, 0)
|
| 179 |
+
.numpy()
|
| 180 |
+
for x in info[:4]
|
| 181 |
+
]
|
| 182 |
+
)
|
| 183 |
+
else:
|
| 184 |
+
additionals.extend(
|
| 185 |
+
[
|
| 186 |
+
colorize(x.cpu().detach().squeeze().numpy())
|
| 187 |
+
for i, x in enumerate(info[:4])
|
| 188 |
+
]
|
| 189 |
+
)
|
| 190 |
+
|
| 191 |
+
num_rows = 2 + int(len(gts) > 0) + num_additional
|
| 192 |
+
artifacts_grid = image_grid(
|
| 193 |
+
[*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs)
|
| 194 |
+
)
|
| 195 |
+
try:
|
| 196 |
+
wandb.log({f"{ds_name}_training": [wandb.Image(artifacts_grid)]}, step=step)
|
| 197 |
+
except:
|
| 198 |
+
Image.fromarray(artifacts_grid).save(
|
| 199 |
+
os.path.join(os.environ["HOME"], "Workspace", f"art_grid{step}.png")
|
| 200 |
+
)
|
| 201 |
+
print("Logging training images failed")
|
flash3d/util/vis3d.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from jaxtyping import Float
|
| 3 |
+
import numpy as np
|
| 4 |
+
from scipy.spatial.transform import Rotation as R
|
| 5 |
+
from plyfile import PlyData, PlyElement
|
| 6 |
+
import torch
|
| 7 |
+
from torch import Tensor
|
| 8 |
+
from einops import rearrange, einsum
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def construct_list_of_attributes(num_rest: int) -> list[str]:
|
| 12 |
+
attributes = ["x", "y", "z", "nx", "ny", "nz"]
|
| 13 |
+
for i in range(3):
|
| 14 |
+
attributes.append(f"f_dc_{i}")
|
| 15 |
+
for i in range(num_rest):
|
| 16 |
+
attributes.append(f"f_rest_{i}")
|
| 17 |
+
attributes.append("opacity")
|
| 18 |
+
for i in range(3):
|
| 19 |
+
attributes.append(f"scale_{i}")
|
| 20 |
+
for i in range(4):
|
| 21 |
+
attributes.append(f"rot_{i}")
|
| 22 |
+
return attributes
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def export_ply(
|
| 26 |
+
means: Float[Tensor, "gaussian 3"],
|
| 27 |
+
scales: Float[Tensor, "gaussian 3"],
|
| 28 |
+
rotations: Float[Tensor, "gaussian 4"],
|
| 29 |
+
harmonics: Float[Tensor, "gaussian 3 d_sh"],
|
| 30 |
+
opacities: Float[Tensor, "gaussian"],
|
| 31 |
+
path: Path,
|
| 32 |
+
):
|
| 33 |
+
path = Path(path)
|
| 34 |
+
# Shift the scene so that the median Gaussian is at the origin.
|
| 35 |
+
means = means - means.median(dim=0).values
|
| 36 |
+
|
| 37 |
+
# Rescale the scene so that most Gaussians are within range [-1, 1].
|
| 38 |
+
scale_factor = means.abs().quantile(0.95, dim=0).max()
|
| 39 |
+
means = means / scale_factor
|
| 40 |
+
scales = scales / scale_factor
|
| 41 |
+
scales = scales * 4.0
|
| 42 |
+
scales = torch.clamp(scales, 0, 0.0075)
|
| 43 |
+
|
| 44 |
+
# Define a rotation that makes +Z be the world up vector.
|
| 45 |
+
# rotation = [
|
| 46 |
+
# [0, 0, 1],
|
| 47 |
+
# [-1, 0, 0],
|
| 48 |
+
# [0, -1, 0],
|
| 49 |
+
# ]
|
| 50 |
+
rotation = [
|
| 51 |
+
[1, 0, 0],
|
| 52 |
+
[0, 1, 0],
|
| 53 |
+
[0, 0, 1],
|
| 54 |
+
]
|
| 55 |
+
rotation = torch.tensor(rotation, dtype=torch.float32, device=means.device)
|
| 56 |
+
|
| 57 |
+
# The Polycam viewer seems to start at a 45 degree angle. Since we want to be
|
| 58 |
+
# looking directly at the object, we compose a 45 degree rotation onto the above
|
| 59 |
+
# rotation.
|
| 60 |
+
# adjustment = torch.tensor(
|
| 61 |
+
# R.from_rotvec([0, 0, -45], True).as_matrix(),
|
| 62 |
+
# dtype=torch.float32,
|
| 63 |
+
# device=means.device,
|
| 64 |
+
# )
|
| 65 |
+
# rotation = adjustment @ rotation
|
| 66 |
+
|
| 67 |
+
# We also want to see the scene in camera space (as the default view). We therefore
|
| 68 |
+
# compose the w2c rotation onto the above rotation.
|
| 69 |
+
# rotation = rotation @ extrinsics[:3, :3].inverse()
|
| 70 |
+
|
| 71 |
+
# Apply the rotation to the means (Gaussian positions).
|
| 72 |
+
means = einsum(rotation, means, "i j, ... j -> ... i")
|
| 73 |
+
|
| 74 |
+
# Apply the rotation to the Gaussian rotations.
|
| 75 |
+
rotations = R.from_quat(rotations.detach().cpu().numpy()).as_matrix()
|
| 76 |
+
rotations = rotation.detach().cpu().numpy() @ rotations
|
| 77 |
+
rotations = R.from_matrix(rotations).as_quat()
|
| 78 |
+
x, y, z, w = rearrange(rotations, "g xyzw -> xyzw g")
|
| 79 |
+
rotations = np.stack((w, x, y, z), axis=-1)
|
| 80 |
+
|
| 81 |
+
# Since our axes are swizzled for the spherical harmonics, we only export the DC
|
| 82 |
+
# band.
|
| 83 |
+
harmonics_view_invariant = harmonics
|
| 84 |
+
|
| 85 |
+
dtype_full = [(attribute, "f4") for attribute in construct_list_of_attributes(0)]
|
| 86 |
+
elements = np.empty(means.shape[0], dtype=dtype_full)
|
| 87 |
+
attributes = (
|
| 88 |
+
means.detach().cpu().numpy(),
|
| 89 |
+
torch.zeros_like(means).detach().cpu().numpy(),
|
| 90 |
+
harmonics_view_invariant.detach().cpu().contiguous().numpy(),
|
| 91 |
+
opacities.detach().cpu().numpy(),
|
| 92 |
+
scales.log().detach().cpu().numpy(),
|
| 93 |
+
rotations,
|
| 94 |
+
)
|
| 95 |
+
attributes = np.concatenate(attributes, axis=1)
|
| 96 |
+
elements[:] = list(map(tuple, attributes))
|
| 97 |
+
path.parent.mkdir(exist_ok=True, parents=True)
|
| 98 |
+
PlyData([PlyElement.describe(elements, "vertex")]).write(path)
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def save_ply(outputs, path, num_gauss=3):
|
| 102 |
+
pad = 32
|
| 103 |
+
|
| 104 |
+
def crop_r(t):
|
| 105 |
+
h, w = 256, 384
|
| 106 |
+
H = h + pad * 2
|
| 107 |
+
W = w + pad * 2
|
| 108 |
+
t = rearrange(t, "b c (h w) -> b c h w", h=H, w=W)
|
| 109 |
+
t = t[..., pad:H-pad, pad:W-pad]
|
| 110 |
+
t = rearrange(t, "b c h w -> b c (h w)")
|
| 111 |
+
return t
|
| 112 |
+
|
| 113 |
+
def crop(t):
|
| 114 |
+
h, w = 256, 384
|
| 115 |
+
H = h + pad * 2
|
| 116 |
+
W = w + pad * 2
|
| 117 |
+
t = t[..., pad:H-pad, pad:W-pad]
|
| 118 |
+
return t
|
| 119 |
+
|
| 120 |
+
# import pdb
|
| 121 |
+
# pdb.set_trace()
|
| 122 |
+
means = rearrange(crop_r(outputs[('gauss_means', 0, 0)]), "(b v) c n -> b (v n) c", v=num_gauss)[0, :, :3]
|
| 123 |
+
scales = rearrange(crop(outputs[('gauss_scaling', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|
| 124 |
+
rotations = rearrange(crop(outputs[('gauss_rotation', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|
| 125 |
+
opacities = rearrange(crop(outputs[('gauss_opacity', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|
| 126 |
+
harmonics = rearrange(crop(outputs[('gauss_features_dc', 0, 0)]), "(b v) c h w -> b (v h w) c", v=num_gauss)[0]
|
| 127 |
+
|
| 128 |
+
export_ply(
|
| 129 |
+
means,
|
| 130 |
+
scales,
|
| 131 |
+
rotations,
|
| 132 |
+
harmonics,
|
| 133 |
+
opacities,
|
| 134 |
+
path
|
| 135 |
+
)
|