File size: 14,656 Bytes
f06f310 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 |
import os
import sys
import numpy as np
sys.path.insert(0,'Metric3D')
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as T
from attrdict import AttrDict
from core.extractor import ResidualBlock
from depth_anything_v2.dpt import DepthAnythingV2
from core.utils.utils import sv_intermediate_results
def resize_tensor(tensor, target_size=512, ratio=16):
    """Bicubically resize a (B, C, H, W) tensor so its longer side matches `target_size`.

    The shorter side is scaled by the same factor (truncated to an int), and
    both sides are then rounded *up* to the nearest multiple of `ratio`.

    Args:
        tensor: 4-D input tensor laid out as (B, C, H, W).
        target_size: Length given to the longer spatial side before rounding.
        ratio: Both output sides are rounded up to a multiple of this value.

    Returns:
        The resized tensor.
    """
    _batch, _channels, height, width = tensor.shape

    # The longer side is pinned to target_size; the other follows the same scale.
    scale = target_size / max(height, width)
    if height > width:
        out_h, out_w = target_size, int(width * scale)
    else:
        out_h, out_w = int(height * scale), target_size

    # Round each side up to the next multiple of `ratio`.
    out_h, out_w = (
        (np.ceil(side / ratio) * ratio).astype(int) for side in (out_h, out_w)
    )

    return F.interpolate(tensor, size=(out_h, out_w), mode='bicubic', align_corners=False)
def resize_to_quarter(tensor, original_size, ratio):
    """Bilinearly resize `tensor` to `original_size` floor-divided by `ratio`.

    Despite the name, the reduction factor is whatever `ratio` is; a quarter
    corresponds to ``ratio=4``.

    Args:
        tensor: 4-D input tensor laid out as (B, C, H, W).
        original_size: Indexable (height, width) of the reference resolution.
        ratio: Integer divisor applied to both reference dimensions.

    Returns:
        The tensor resized to ``(original_size[0] // ratio, original_size[1] // ratio)``.
    """
    ref_h, ref_w = original_size[0], original_size[1]
    target_hw = (ref_h // ratio, ref_w // ratio)
    return F.interpolate(tensor, size=target_hw, mode='bilinear', align_corners=False)
from mono.utils.comm import get_func
class Metric3DExtractor(nn.Module):
    """Frozen Metric3D depth network plus small trainable adapters.

    The encoder/decoder are built through ``get_func`` from the ``mono.model``
    package, loaded from ``./pretrained/metric3d/metric_depth_vit_large_800k.pth``
    and frozen. ``forward`` turns their depth prediction into a disparity map
    and passes the decoder's hidden-state / context features through the
    trainable ``net_convs`` / ``inp_convs`` stacks.
    """

    def __init__(self, args) -> None:
        """Build the Metric3D encoder/decoder, load weights, and create the adapters.

        Args:
            args: Run configuration; only stored on the instance here.
        """
        super(Metric3DExtractor, self).__init__()
        self.args = args
        cfg = dict(
            model=dict(
                type='DensePredModel',
                backbone=dict(
                    type='vit_large_reg',
                    prefix='backbones.',
                    out_channels=[1024, 1024, 1024, 1024],
                    drop_path_rate=0.0,
                    checkpoint="./pretrained/metric3d/dinov2_vitl14_reg4_pretrain.pth",
                ),
                decode_head=dict(
                    type='RAFTDepthNormalDPT5',
                    # type='RAFTDepthDPT',
                    prefix='decode_heads.',
                    in_channels=[1024, 1024, 1024, 1024],
                    use_cls_token=True,
                    feature_channels=[256, 512, 1024, 1024],  # [2/7, 1/7, 1/14, 1/14]
                    decoder_channels=[128, 256, 512, 1024, 1024],  # [4/7, 2/7, 1/7, 1/14, 1/14]
                    up_scale=7,
                    hidden_channels=[128, 128, 128, 128],  # [x_4, x_8, x_16, x_32] [192, 384, 768, 1536]
                    n_gru_layers=3,
                    n_downsample=2,
                    iters=8,
                    slow_fast_gru=True,
                    num_register_tokens=4,
                    # detach=False
                ),
            ),
            data_basic=dict(
                canonical_space=dict(
                    # img_size=(540, 960),
                    focal_length=1000.0,
                ),
                depth_range=(0, 1),
                depth_normalize=(0.1, 200),
                crop_size=(616, 1064),  # %28 = 0
                clip_depth_range=(0.1, 200),
                vit_size=(616, 1064),
            ),
        )
        self.cfg = AttrDict(cfg)

        backbone_cfg = self.cfg.model.backbone
        head_cfg = self.cfg.model.decode_head
        self.encoder = get_func('mono.model.' + backbone_cfg.prefix + backbone_cfg.type)(**backbone_cfg)
        self.decoder = get_func('mono.model.' + head_cfg.prefix + head_cfg.type)(self.cfg)

        self.hidden_dims = head_cfg.hidden_channels
        self.n_gru_layers = head_cfg.n_gru_layers

        # Trainable adapters for the context features (three chunks concatenated
        # along channels, hence the *3) of each GRU level.
        self.inp_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.hidden_dims[i] * 3, self.hidden_dims[i] * 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.hidden_dims[i] * 3, self.hidden_dims[i] * 3, kernel_size=3, stride=1, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.hidden_dims[i] * 3, self.hidden_dims[i] * 3, kernel_size=3, stride=1, padding=1),
            ) for i in range(self.n_gru_layers)
        ])
        # Trainable adapters for the hidden states of each GRU level.
        self.net_convs = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(self.hidden_dims[i], self.hidden_dims[i], 3, padding=3 // 2),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.hidden_dims[i], self.hidden_dims[i], 3, padding=3 // 2),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.hidden_dims[i], self.hidden_dims[i], 3, padding=3 // 2),
            ) for i in range(self.n_gru_layers)
        ])

        load_path = "./pretrained/metric3d/metric_depth_vit_large_800k.pth"
        checkpoint = torch.load(load_path, map_location="cpu")
        state_dict = checkpoint['model_state_dict']
        encoder_state_dict = {k.replace("depth_model.encoder.", ""): v for k, v in state_dict.items() if k.startswith("depth_model.encoder")}
        decoder_state_dict = {k.replace("depth_model.decoder.", ""): v for k, v in state_dict.items() if k.startswith("depth_model.decoder")}
        self.encoder.load_state_dict(encoder_state_dict)
        self.decoder.load_state_dict(decoder_state_dict)

        # Use CUDA when present (as before) but fall back to CPU instead of
        # raising at construction time on hosts without a GPU.
        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        self.encoder = self.encoder.to(device)
        self.decoder = self.decoder.to(device)

        # Freeze all parameters of the pretrained depth encoder/decoder.
        for param in self.encoder.parameters():
            param.requires_grad = False
        for param in self.decoder.parameters():
            param.requires_grad = False

        # Per-channel normalization statistics on a 0-255 scale.
        mean = [123.675, 116.28, 103.53]
        std = [58.395, 57.12, 57.375]
        self.mean = torch.tensor(mean).view(1, 3, 1, 1).to(device)
        self.std = torch.tensor(std).view(1, 3, 1, 1).to(device)
        self.pad_val = torch.tensor(mean).view(1, 3, 1, 1).to(device)

    def forward(self, rgb, intrinsic, baseline=1):
        """Predict a downsampled disparity map and adapted hidden/context features.

        Args:
            rgb: Image batch (B, 3, H, W); presumably scaled to [-1, 1] given the
                ``(rgb + 1) / 2 * 255`` rescale in ``aug_data`` (confirm with caller).
            intrinsic: Tensor indexed as ``intrinsic[:, 0..3]``; columns 0 and 1
                are averaged into the focal length.
            baseline: Multiplier applied to the focal length when converting
                depth to disparity.

        Returns:
            Tuple ``(net_list, inp_list, pred_disp_down)``:
            ``net_list`` holds one adapted hidden state per level, ``inp_list``
            holds per level a list of three context chunks, and
            ``pred_disp_down`` is the disparity at ``1 / 2**n_downsample`` of the
            input resolution with its values scaled by the same factor.
        """
        down = self.cfg.model.decode_head.n_downsample

        # NOTE(review): the source's indentation was lost; the frozen network and
        # the depth->disparity conversion are assumed to sit under no_grad while
        # the trainable convs below run with gradients enabled. Confirm.
        with torch.no_grad():
            focal_length = (intrinsic[:, 0] + intrinsic[:, 1]) / 2
            rgb_input, cam_model_stacks, pad, label_scale_factor, (ori_h, ori_w) = self.aug_data(rgb, intrinsic)

            # [f_32, f_16, f_8, f_4]
            features = self.encoder(rgb_input)
            # Per the original author's notes the decoder returns a dict with keys:
            # prediction, predictions_list, confidence, confidence_list, pred_logit,
            # prediction_normal, normal_out_list, low_resolution_init, net_list, inp_list.
            output = self.decoder(features, cam_model=cam_model_stacks)
            pred_depth = output['prediction']
            net_list, inp_list = output['net_list'], output['inp_list']

            _, _, H_new, W_new = pred_depth.shape
            normalize_scale = self.cfg.data_basic.depth_range[1]
            # Remove the padded border, then go back to the original resolution.
            pred_depth = pred_depth[:, :, pad[0]: H_new - pad[1], pad[2]: W_new - pad[3]]
            pred_depth = F.interpolate(pred_depth, [ori_h, ori_w], mode='bilinear')
            # Undo the canonical-space / resize scaling applied in aug_data.
            pred_depth = pred_depth * normalize_scale / label_scale_factor.unsqueeze(1).unsqueeze(1).unsqueeze(1)

            pred_disp = (baseline * focal_length).unsqueeze(1).unsqueeze(1).unsqueeze(1) / pred_depth
            # Disparity values shrink with the resolution, hence the extra factor.
            pred_disp_down = F.interpolate(pred_disp, scale_factor=1 / 2**down, mode='bilinear') * (1 / 2**down)

        def _level_size(level):
            # Spatial size of pyramid level `level` relative to the original image.
            stride = 2 ** (down + level)
            return (ori_h // stride, ori_w // stride)

        # NOTE(review): unlike pred_depth, these features are resized without
        # first removing the padded border.
        net_list = [
            F.interpolate(x, size=_level_size(i), mode='bilinear', align_corners=False)
            for i, x in enumerate(net_list)
        ]
        inp_list = [
            F.interpolate(torch.cat(x, dim=1), size=_level_size(i), mode='bilinear', align_corners=False)
            for i, x in enumerate(inp_list)
        ]

        # Update the hidden states and context features
        net_list = [conv(x) for x, conv in zip(net_list, self.net_convs)]
        inp_list = [list(conv(x).chunk(3, dim=1)) for x, conv in zip(inp_list, self.inp_convs)]
        return net_list, inp_list, pred_disp_down

    def aug_data(self, rgb, intrinsic):
        """Resize, pad and normalize `rgb` and build the multi-scale camera model.

        Args:
            rgb: Image batch of shape (B, C, H, W).
            intrinsic: Tensor indexed as ``intrinsic[:, 0..3]``.

        Returns:
            Tuple ``(rgb, cam_model_stacks, pad, label_scale_factor, (ori_h, ori_w))``
            where ``cam_model_stacks`` holds the camera model at 1/2, 1/4, 1/8,
            1/16 and 1/32 resolution, ``pad`` is the padding reported by
            ``resize_for_input`` and ``label_scale_factor`` has shape (B,).
        """
        B, C, ori_h, ori_w = rgb.shape
        device = rgb.device

        ori_focal = (intrinsic[:, 0] + intrinsic[:, 1]) / 2
        canonical_focal = self.cfg.data_basic['canonical_space']['focal_length']
        cano_label_scale_ratio = canonical_focal / ori_focal  # Shape: (B,)
        canonical_intrinsic = torch.stack([
            intrinsic[:, 0] * cano_label_scale_ratio,
            intrinsic[:, 1] * cano_label_scale_ratio,
            intrinsic[:, 2],
            intrinsic[:, 3],
        ], dim=1)

        # Keep the constants on the same device as the input rather than on
        # whichever CUDA device happened to be current at construction time.
        mean = self.mean.to(device)
        std = self.std.to(device)
        pad_val = self.pad_val.to(device)

        # resize
        # NOTE(review): pad_val holds 0-255 channel means but is applied before the
        # (rgb + 1) / 2 * 255 rescale below, so the padded border does not
        # normalize to zero. Left as-is because trained weights may depend on it.
        rgb, cam_model, pad, resize_label_scale_ratio = resize_for_input(
            rgb, self.cfg.data_basic.crop_size, canonical_intrinsic, [ori_h, ori_w], 1.0, pad_val)

        # label scale factor
        label_scale_factor = cano_label_scale_ratio * resize_label_scale_ratio  # Shape: (B,)

        rgb = torch.div(((rgb + 1) / 2 * 255 - mean), std)

        # Channels-last -> channels-first.
        cam_model = cam_model.permute((0, 3, 1, 2)).float()
        cam_model = cam_model.to(device)
        cam_model_stacks = [
            torch.nn.functional.interpolate(cam_model, size=(cam_model.shape[2] // i, cam_model.shape[3] // i), mode='bilinear', align_corners=False)
            for i in [2, 4, 8, 16, 32]
        ]
        return rgb, cam_model_stacks, pad, label_scale_factor, (ori_h, ori_w)
def resize_for_input(image, output_shape, intrinsic, canonical_shape, to_canonical_ratio, pad_values):
    """
    Resize a batched image to fit `output_shape`, pad it, and build the matching camera model.

    Args:
        image: Tensor of shape (B, C, H, W).
        output_shape: Target (height, width) after padding.
        intrinsic: Tensor of shape (B, 4) holding fx, fy, u0, v0. Not modified.
        canonical_shape: (height, width) used to derive the resize ratio.
        to_canonical_ratio: Extra multiplier applied to the resize ratio.
        pad_values: Per-channel fill values for the image border.

    Returns:
        Tuple ``(image, cam_model, pad, label_scale_factor)`` where ``cam_model``
        is channels-last ``(B, H_out, W_out, 4)`` with its border filled with -1,
        ``pad`` is ``[top, bottom, left, right]`` and ``label_scale_factor`` is
        the inverse of the resize ratio derived from the two shapes.
    """
    h, w = image.shape[-2:]

    # Use the smaller ratio so the resized image fits inside output_shape.
    resize_ratio_h = output_shape[0] / canonical_shape[0]
    resize_ratio_w = output_shape[1] / canonical_shape[1]
    to_scale_ratio = min(resize_ratio_h, resize_ratio_w)
    resize_ratio = to_canonical_ratio * to_scale_ratio

    reshape_h = int(resize_ratio * h)
    reshape_w = int(resize_ratio * w)

    # Split the leftover space between both sides; the odd pixel goes to bottom/right.
    pad_h = max(output_shape[0] - reshape_h, 0)
    pad_w = max(output_shape[1] - reshape_w, 0)
    pad_top = pad_h // 2
    pad_bottom = pad_h - pad_top
    pad_left = pad_w // 2
    pad_right = pad_w - pad_left

    # Resize image
    image = F.interpolate(image, size=(reshape_h, reshape_w), mode='bilinear', align_corners=False)

    # Padding
    image = pad_with_channel_values(image, (pad_left, pad_right, pad_top, pad_bottom), pad_values)

    # Rescale the principal point. Work on a copy so the caller's tensor is
    # not modified in place.
    intrinsic = intrinsic.clone()
    intrinsic[:, 2] *= to_scale_ratio  # u0
    intrinsic[:, 3] *= to_scale_ratio  # v0

    cam_model = build_camera_model(reshape_h, reshape_w, intrinsic)
    # cam_model is channels-last (B, H, W, 4) and F.pad consumes its spec from the
    # last dimension backwards, so the channel axis gets an explicit (0, 0);
    # otherwise the padding lands on the channel and width axes.
    cam_model = F.pad(cam_model, (0, 0, pad_left, pad_right, pad_top, pad_bottom), value=-1)

    pad = [pad_top, pad_bottom, pad_left, pad_right]
    label_scale_factor = 1 / to_scale_ratio
    return image, cam_model, pad, label_scale_factor
def pad_with_channel_values(input_tensor, padding, pad_values):
    """Pad a (B, C, H, W) tensor with a different constant per channel.

    Args:
        input_tensor: Tensor of shape (B, C, H, W).
        padding: Either one int applied to all four sides, or a
            ``(left, right, top, bottom)`` sequence.
        pad_values: C fill values, as a tensor viewable as (1, C, 1, 1) or a
            plain sequence of numbers.

    Returns:
        A new tensor of shape ``(B, C, H + top + bottom, W + left + right)``
        with the dtype and device of `input_tensor`.
    """
    if isinstance(padding, int):
        pad_left = pad_right = pad_top = pad_bottom = padding
    else:
        pad_left, pad_right, pad_top, pad_bottom = padding

    B, C, H, W = input_tensor.shape
    new_H = H + pad_top + pad_bottom
    new_W = W + pad_left + pad_right

    # Align the fill values with the input so the result does not silently
    # take the dtype/device of `pad_values` (e.g. under mixed precision).
    fill = torch.as_tensor(pad_values, dtype=input_tensor.dtype, device=input_tensor.device)
    fill = fill.reshape(1, C, 1, 1)
    padded_tensor = fill.expand(B, C, new_H, new_W).clone()

    # Copy the original data into the interior region.
    padded_tensor[:, :, pad_top:new_H - pad_bottom, pad_left:new_W - pad_right] = input_tensor
    return padded_tensor
def build_camera_model(H: int, W: int, intrinsics: torch.Tensor) -> torch.Tensor:
    """
    Encode camera intrinsics (focal length and principal point) as a 4-channel map.

    The channels are, in order: the column offset from u0 divided by W, the
    row offset from v0 divided by H, and the arctangent of each of those
    divided by the mean focal length normalized by W and H respectively.

    Args:
        H (int): Image height
        W (int): Image width
        intrinsics (torch.Tensor): Tensor of shape (B, 4) containing fx, fy, u0, v0

    Returns:
        torch.Tensor: Camera model tensor of shape (B, H, W, 4)
    """
    batch = intrinsics.shape[0]
    device = intrinsics.device

    # Keep every intrinsic as a (B, 1) column so it broadcasts against the grids.
    fx, fy, u0, v0 = (intrinsics[:, k:k + 1] for k in range(4))
    focal = ((fx + fy) / 2.0).unsqueeze(1)  # Shape: (B, 1, 1)

    # Pixel index vectors along each axis.
    cols = torch.arange(W, dtype=torch.float32, device=device).view(1, W)
    rows = torch.arange(H, dtype=torch.float32, device=device).view(1, H)

    # Offsets from the principal point, normalized by image size and
    # broadcast to the full (B, H, W) grid.
    x_center = ((cols - u0) / W).unsqueeze(1).expand(batch, H, W)
    y_center = ((rows - v0) / H).unsqueeze(2).expand(batch, H, W)

    # Field-of-view angles.
    fov_x = torch.atan(x_center / (focal / W))
    fov_y = torch.atan(y_center / (focal / H))

    return torch.stack((x_center, y_center, fov_x, fov_y), dim=-1)
|