init project
modules/dust3r/cloud_opt/base_opt.py
CHANGED
@@ -121,12 +121,11 @@ class BasePCOptimizer (nn.Module):
         self.fix_imgs = rgb(ori_imgs)
         self.smoothed_imgs = rgb(smoothed_imgs)
 
-        self.cogs = [torch.zeros((h, w, 1024), device="cuda") for h, w in self.imshapes]
-        semantic_feats = semantic_feats.to("cuda")
-        self.segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
-        self.rev_segmaps = [-torch.ones((h, w), device="cuda") for h, w in self.imshapes]
-
-        # self.conf_2 = [torch.zeros((h, w), device="cuda") for h, w in self.imshapes]
+        self.cogs = [torch.zeros((h, w, 1024)) for h, w in self.imshapes]
+        # semantic_feats = semantic_feats.to("cuda")
+        self.segmaps = [-torch.ones((h, w)) for h, w in self.imshapes]
+        self.rev_segmaps = [-torch.ones((h, w)) for h, w in self.imshapes]
+
         for v in range(len(self.edges)):
             idx = view1['idx'][v]
 
@@ -142,8 +141,8 @@ class BasePCOptimizer (nn.Module):
             seg = cog_seg_map[y, x].squeeze(-1).long()
 
             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-            self.segmaps[idx] = cog_seg_map.cuda()
-            self.rev_segmaps[idx] = rev_seg_map.cuda()
+            self.segmaps[idx] = cog_seg_map#.cuda()
+            self.rev_segmaps[idx] = rev_seg_map#.cuda()
 
             idx = view2['idx'][v]
             h, w = self.cogs[idx].shape[0], self.cogs[idx].shape[1]
@@ -158,8 +157,8 @@ class BasePCOptimizer (nn.Module):
             seg = cog_seg_map[y, x].squeeze(-1).long()
 
             self.cogs[idx] = semantic_feats[seg].reshape(h, w, -1)
-            self.segmaps[idx] = cog_seg_map.cuda()
-            self.rev_segmaps[idx] = rev_seg_map.cuda()
+            self.segmaps[idx] = cog_seg_map#.cuda()
+            self.rev_segmaps[idx] = rev_seg_map#.cuda()
 
             self.rendered_imgs = []
 
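Aside: the new buffers above implement a simple table lookup. Each pixel's segment id in cog_seg_map selects one row of the semantic_feats table, which yields the dense (h, w, 1024) map stored in self.cogs[idx]. A minimal sketch of that lookup (the shapes and random data are illustrative assumptions, not values from this repo):

import torch

n_segments, h, w = 8, 4, 6
semantic_feats = torch.randn(n_segments, 1024)      # one 1024-d embedding per segment
cog_seg_map = torch.randint(0, n_segments, (h, w))  # per-pixel segment ids

seg = cog_seg_map.reshape(-1).long()                # (h*w,) indices into the table
dense = semantic_feats[seg].reshape(h, w, -1)       # (h, w, 1024) dense feature map
assert dense.shape == (h, w, 1024)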
modules/dust3r/cloud_opt/optimizer.py.bak.1216
DELETED
@@ -1,533 +0,0 @@
-# Copyright (C) 2024-present Naver Corporation. All rights reserved.
-# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
-#
-# --------------------------------------------------------
-# Main class for the implementation of the global alignment
-# --------------------------------------------------------
-import numpy as np
-import torch
-import torch.nn as nn
-
-from dust3r.cloud_opt.base_opt import BasePCOptimizer
-from dust3r.utils.geometry import xy_grid, geotrf
-from dust3r.utils.device import to_cpu, to_numpy
-import torch.nn.functional as F
-
-class PointCloudOptimizer(BasePCOptimizer):
-    """ Optimize a global scene, given a list of pairwise observations.
-    Graph node: images
-    Graph edges: observations = (pred1, pred2)
-    """
-
-    def __init__(self, *args, optimize_pp=False, focal_break=20, **kwargs):
-        super().__init__(*args, **kwargs)
-
-        self.has_im_poses = True  # by definition of this class
-        self.focal_break = focal_break
-
-        # adding things to optimize
-        self.im_depthmaps = nn.ParameterList(torch.randn(H, W)/10-3 for H, W in self.imshapes)  # log(depth)
-        self.im_poses = nn.ParameterList(self.rand_pose(self.POSE_DIM) for _ in range(self.n_imgs))  # camera poses
-        self.im_focals = nn.ParameterList(torch.FloatTensor(
-            [self.focal_break*np.log(max(H, W))]) for H, W in self.imshapes)  # camera intrinsics
-        self.im_pp = nn.ParameterList(torch.zeros((2,)) for _ in range(self.n_imgs))  # camera intrinsics
-        self.im_pp.requires_grad_(optimize_pp)
-
-        self.imshape = self.imshapes[0]
-        im_areas = [h*w for h, w in self.imshapes]
-        self.max_area = max(im_areas)
-
-        # adding things to optimize
-        self.im_depthmaps = ParameterStack(self.im_depthmaps, is_param=True, fill=self.max_area)
-        self.im_poses = ParameterStack(self.im_poses, is_param=True)
-        self.im_focals = ParameterStack(self.im_focals, is_param=True)
-        self.im_pp = ParameterStack(self.im_pp, is_param=True)
-        self.register_buffer('_pp', torch.tensor([(w/2, h/2) for h, w in self.imshapes]))
-        self.register_buffer('_grid', ParameterStack(
-            [xy_grid(W, H, device=self.device) for H, W in self.imshapes], fill=self.max_area))
-
-        # pre-compute pixel weights
-        self.register_buffer('_weight_i', ParameterStack(
-            [self.conf_trf(self.conf_i[i_j]) for i_j in self.str_edges], fill=self.max_area))
-        self.register_buffer('_weight_j', ParameterStack(
-            [self.conf_trf(self.conf_j[i_j]) for i_j in self.str_edges], fill=self.max_area))
-
-        # precompute aa
-        self.register_buffer('_stacked_pred_i', ParameterStack(self.pred_i, self.str_edges, fill=self.max_area))
-        self.register_buffer('_stacked_pred_j', ParameterStack(self.pred_j, self.str_edges, fill=self.max_area))
-        self.register_buffer('_ei', torch.tensor([i for i, j in self.edges]))
-        self.register_buffer('_ej', torch.tensor([j for i, j in self.edges]))
-        self.total_area_i = sum([im_areas[i] for i, j in self.edges])
-        self.total_area_j = sum([im_areas[j] for i, j in self.edges])
-
-    def _check_all_imgs_are_selected(self, msk):
-        assert np.all(self._get_msk_indices(msk) == np.arange(self.n_imgs)), 'incomplete mask!'
-
-    def preset_pose(self, known_poses, pose_msk=None):  # cam-to-world
-        self._check_all_imgs_are_selected(pose_msk)
-
-        if isinstance(known_poses, torch.Tensor) and known_poses.ndim == 2:
-            known_poses = [known_poses]
-        for idx, pose in zip(self._get_msk_indices(pose_msk), known_poses):
-            if self.verbose:
-                print(f' (setting pose #{idx} = {pose[:3,3]})')
-            self._no_grad(self._set_pose(self.im_poses, idx, torch.tensor(pose)))
-
-        # normalize scale if there's less than 1 known pose
-        n_known_poses = sum((p.requires_grad is False) for p in self.im_poses)
-        self.norm_pw_scale = (n_known_poses <= 1)
-
-        self.im_poses.requires_grad_(False)
-        self.norm_pw_scale = False
-
-    def preset_focal(self, known_focals, msk=None):
-        self._check_all_imgs_are_selected(msk)
-
-        for idx, focal in zip(self._get_msk_indices(msk), known_focals):
-            if self.verbose:
-                print(f' (setting focal #{idx} = {focal})')
-            self._no_grad(self._set_focal(idx, focal))
-
-        self.im_focals.requires_grad_(False)
-
-    def preset_principal_point(self, known_pp, msk=None):
-        self._check_all_imgs_are_selected(msk)
-
-        for idx, pp in zip(self._get_msk_indices(msk), known_pp):
-            if self.verbose:
-                print(f' (setting principal point #{idx} = {pp})')
-            self._no_grad(self._set_principal_point(idx, pp))
-
-        self.im_pp.requires_grad_(False)
-
-    def _get_msk_indices(self, msk):
-        if msk is None:
-            return range(self.n_imgs)
-        elif isinstance(msk, int):
-            return [msk]
-        elif isinstance(msk, (tuple, list)):
-            return self._get_msk_indices(np.array(msk))
-        elif msk.dtype in (bool, torch.bool, np.bool_):
-            assert len(msk) == self.n_imgs
-            return np.where(msk)[0]
-        elif np.issubdtype(msk.dtype, np.integer):
-            return msk
-        else:
-            raise ValueError(f'bad {msk=}')
-
-    def _no_grad(self, tensor):
-        assert tensor.requires_grad, 'it must be True at this point, otherwise no modification occurs'
-
-    def _set_focal(self, idx, focal, force=False):
-        param = self.im_focals[idx]
-        if param.requires_grad or force:  # can only init a parameter not already initialized
-            param.data[:] = self.focal_break * np.log(focal)
-        return param
-
-    def get_focals(self):
-        log_focals = torch.stack(list(self.im_focals), dim=0)
-        return (log_focals / self.focal_break).exp()
-
-    def get_known_focal_mask(self):
-        return torch.tensor([not (p.requires_grad) for p in self.im_focals])
-
-    def _set_principal_point(self, idx, pp, force=False):
-        param = self.im_pp[idx]
-        H, W = self.imshapes[idx]
-        if param.requires_grad or force:  # can only init a parameter not already initialized
-            param.data[:] = to_cpu(to_numpy(pp) - (W/2, H/2)) / 10
-        return param
-
-    def get_principal_points(self):
-        return self._pp + 10 * self.im_pp
-
-    def get_intrinsics(self):
-        K = torch.zeros((self.n_imgs, 3, 3), device=self.device)
-        focals = self.get_focals().flatten()
-        K[:, 0, 0] = K[:, 1, 1] = focals
-        K[:, :2, 2] = self.get_principal_points()
-        K[:, 2, 2] = 1
-        return K
-
-    def get_im_poses(self):  # cam to world
-        cam2world = self._get_poses(self.im_poses)
-        return cam2world
-
-    def _set_depthmap(self, idx, depth, force=False):
-        depth = _ravel_hw(depth, self.max_area)
-
-        param = self.im_depthmaps[idx]
-        if param.requires_grad or force:  # can only init a parameter not already initialized
-            param.data[:] = depth.log().nan_to_num(neginf=0)
-        return param
-
-    def get_depthmaps(self, raw=False):
-        res = self.im_depthmaps.exp()
-        if not raw:
-            res = [dm[:h*w].view(h, w) for dm, (h, w) in zip(res, self.imshapes)]
-        return res
-
-    def depth_to_pts3d(self):
-        # Get depths and projection params if not provided
-        focals = self.get_focals()
-        pp = self.get_principal_points()
-        im_poses = self.get_im_poses()
-        depth = self.get_depthmaps(raw=True)
-
-        # get pointmaps in camera frame
-        rel_ptmaps = _fast_depthmap_to_pts3d(depth, self._grid, focals, pp=pp)
-        # project to world frame
-        return geotrf(im_poses, rel_ptmaps)
-
-    def get_pts3d(self, raw=False):
-        res = self.depth_to_pts3d()
-        if not raw:
-            res = [dm[:h*w].view(h, w, 3) for dm, (h, w) in zip(res, self.imshapes)]
-        return res
-
-    # def cosine_similarity_batch(self, semantic_features, query_pixels):
-    #     # Expand dims to compute cosine similarity by broadcasting
-    #     query_pixels = query_pixels.unsqueeze(1)  # [B, 1, C]
-    #     semantic_features = semantic_features.unsqueeze(0)  # [1, H, W, C]
-    #     cos_sim = F.cosine_similarity(query_pixels, semantic_features, dim=-1)  # [B, H, W]
-    #     return cos_sim
-
-    # def semantic_loss(self, semantic_features, predicted_depth, window_size=32, stride=16, lambda_semantic=0.1):
-    #     # Get the image dimensions
-    #     height, width, channels = semantic_features.shape
-    #     # Process the image window by window, in matrix form
-    #     ret_loss = 0.0
-    #     cnt = 0
-    #     for i in range(0, height, stride):
-    #         for j in range(0, width, stride):
-    #             window_semantic = semantic_features[i:min(i+window_size,height), j:min(j+window_size,width), :]
-    #             window_depth = predicted_depth[i:min(i+window_size,height), j:min(j+window_size,width)]
-    #             # print(window_semantic.shape, window_depth.shape)
-
-    #             window_semantic = window_semantic.reshape(-1, channels)
-    #             window_depth = window_depth.reshape(-1, 1)
-
-    #             cos_sim = torch.matmul(window_semantic, window_semantic.t())
-    #             dep_dif = torch.abs(window_depth - window_depth.reshape(1, -1))
-
-    #             # print(torch.sum(cos_sim * dep_dif))
-    #             ret_loss += torch.mean(cos_sim * dep_dif)
-    #             cnt += 1
-
-    #     return ret_loss / cnt
-
-    # def segmap_loss(self, predicted_depth, seg_map):
-    #     ret_loss = 0.0
-    #     cnt = 0
-    #     seg_map = seg_map.view(-1)
-    #     predicted_depth = predicted_depth.view(-1, 1)
-    #     unique_groups = torch.unique(seg_map)
-    #     for group in unique_groups:
-    #         # print(group)
-    #         if group == -1:
-    #             continue
-    #         group_indices = (seg_map == group).nonzero(as_tuple=True)[0]
-    #         if len(group_indices) > 0:
-    #             now_feat = predicted_depth[group_indices]
-
-    #             dep_dif = torch.abs(now_feat - now_feat.reshape(1, -1))
-
-    #             ret_loss += torch.mean(dep_dif)
-    #             cnt += 1
-
-    #     return ret_loss / cnt if cnt > 0 else ret_loss
-
-    # def spatial_smoothness_loss(self, point_map, semantic_map):
-    #     """
-    #     Compute a spatial smoothness loss so that neighboring pixels of the same
-    #     semantic class do not change 3D position abruptly. Uses the 8-neighborhood.
-
-    #     Args:
-    #     - point_map: (H, W, 3), the spatial coordinates (x, y, z) of each pixel
-    #     - semantic_map: (H, W, 1), the semantic label of each pixel
-
-    #     Returns:
-    #     - the total loss value
-    #     """
-
-    #     # Get the image height and width
-    #     H, W = semantic_map.shape
-
-    #     # Flatten the point map and the semantic map
-    #     point_map = point_map.view(-1, 3)  # (H * W, 3)
-    #     semantic_map = semantic_map.view(-1)  # (H * W,)
-
-    #     # Build pixel indices
-    #     row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W))
-    #     row_idx = row_idx.flatten()
-    #     col_idx = col_idx.flatten()
-
-    #     # Define the 8-neighborhood offsets
-    #     neighbor_offsets = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1],
-    #                                      [-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.long)
-
-    #     # Accumulated loss
-    #     total_loss = 0.0
-
-    #     # Compute for every pixel
-    #     for offset in neighbor_offsets:
-    #         # Neighbor positions
-    #         neighbor_row = row_idx + offset[0]
-    #         neighbor_col = col_idx + offset[1]
-
-    #         # Keep only neighbors that fall inside the image
-    #         valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
-    #         valid_row = neighbor_row[valid_mask]
-    #         valid_col = neighbor_col[valid_mask]
-
-    #         # Indices of the valid pixels
-    #         idx = valid_mask.nonzero(as_tuple=True)[0]
-    #         neighbor_idx = valid_row * W + valid_col
-
-    #         # Semantic labels and spatial coordinates of the neighboring pixels
-    #         sem_i = semantic_map[idx]
-    #         sem_j = semantic_map[neighbor_idx]
-    #         p_i = point_map[idx]
-    #         p_j = point_map[neighbor_idx]
-
-    #         # Squared difference of the spatial coordinates
-    #         distance = torch.sum((p_i - p_j) ** 2, dim=1)
-
-    #         # If neighboring pixels share the same semantic class, add to the loss
-    #         loss_mask = (sem_i == sem_j)
-    #         total_loss += torch.sum(loss_mask * distance)
-
-    #     # Average loss
-    #     return total_loss / point_map.size(0)
-
-
-    def spatial_smoothness_loss_multi_image(self, point_maps, semantic_maps, confidence_maps):
-        """
-        Compute a spatial smoothness loss over multiple images, taking into account
-        the smoothness of pixels that belong to the same object across images.
-
-        Args:
-        - point_maps: (B, H, W, 3), the spatial coordinates (x, y, z) of each image; B is the batch size
-        - semantic_maps: (B, H, W, 1), the semantic labels of each image
-
-        Returns:
-        - the total loss value
-        """
-
-        B, H, W = semantic_maps.shape
-
-        # Flatten the point maps and semantic maps
-        point_maps = point_maps.view(B, -1, 3)  # (B, H*W, 3)
-        semantic_maps = semantic_maps.view(B, -1)  # (B, H*W)
-        confidence_maps = confidence_maps.view(B, -1)  # (B, H*W)
-
-        # Accumulated loss
-        total_loss = 0.0
-
-        # Compute for every pixel of every image
-        for b in range(B):
-            # Point map and semantic map of the current image
-            point_map = point_maps[b]
-            semantic_map = semantic_maps[b]
-            confidence_map = confidence_maps[b]
-
-            # Build pixel indices
-            row_idx, col_idx = torch.meshgrid(torch.arange(H), torch.arange(W))
-            row_idx = row_idx.flatten()
-            col_idx = col_idx.flatten()
-
-            # Define the 8-neighborhood offsets
-            neighbor_offsets = torch.tensor([[-1, 0], [1, 0], [0, -1], [0, 1],
-                                             [-1, -1], [-1, 1], [1, -1], [1, 1]], dtype=torch.long)
-
-            # Compute for every pixel (neighborhood relations only within the current image)
-            for offset in neighbor_offsets:
-                # Neighbor positions
-                neighbor_row = row_idx + offset[0]
-                neighbor_col = col_idx + offset[1]
-
-                # Keep only neighbors that fall inside the image
-                valid_mask = (neighbor_row >= 0) & (neighbor_row < H) & (neighbor_col >= 0) & (neighbor_col < W)
-                valid_row = neighbor_row[valid_mask]
-                valid_col = neighbor_col[valid_mask]
-
-                # Indices of the valid pixels
-                idx = valid_mask.nonzero(as_tuple=True)[0]
-                neighbor_idx = valid_row * W + valid_col
-
-                # Semantic labels and spatial coordinates of the neighboring pixels
-                sem_i = semantic_map[idx]
-                sem_j = semantic_map[neighbor_idx]
-                p_i = point_map[idx]
-                p_j = point_map[neighbor_idx]
-                conf_i = confidence_map[idx]
-                conf_j = confidence_map[neighbor_idx]
-
-                # Squared difference of the spatial coordinates
-                distance = torch.sum((p_i - p_j)**2, dim=1)
-
-                # If neighboring pixels share the same semantic class, compute a weighted loss
-                loss_mask = (sem_i == sem_j)
-
-                # Inverse weighting: low-confidence points would get a higher weight
-                # inverse_weight_i = 1.0 / (conf_i)  # avoid division by zero
-                # inverse_weight_j = 1.0 / (conf_j)
-                weighted_distance = loss_mask * distance  # weighted loss:  * inverse_weight_i * inverse_weight_j
-                total_loss += torch.sum(weighted_distance)
-
-            # Cross-image term: for pixels of the same semantic class, compare only
-            # their means to avoid pairwise computation
-            # for b2 in range(B):
-            #     if b == b2:
-            #         continue  # skip comparing the image with itself
-            #     point_map_b2 = point_maps[b2]
-            #     semantic_map_b2 = semantic_maps[b2]
-            #     confidence_map_b2 = confidence_maps[b2]
-
-            #     for sem_id in torch.unique(semantic_map):
-            #         sem_mask_a = (semantic_map == sem_id)
-            #         sem_mask_b2 = (semantic_map_b2 == sem_id)
-
-            #         # Extract the pixels of this semantic class
-            #         shared_points_a = point_map[sem_mask_a]
-            #         shared_points_b2 = point_map_b2[sem_mask_b2]
-            #         shared_conf_a = confidence_map[sem_mask_a]
-            #         shared_conf_b2 = confidence_map_b2[sem_mask_b2]
-
-            #         if shared_points_a.shape[0] > 0 and shared_points_b2.shape[0] > 0:
-            #             # Means of these pixels
-            #             mean_a = shared_points_a.mean(dim=0)  # mean of this semantic class in the current image
-            #             mean_b2 = shared_points_b2.mean(dim=0)  # mean of this semantic class in image b2
-            #             mean_conf_a = shared_conf_a.mean()  # mean confidence of this semantic class in the current image
-            #             mean_conf_b2 = shared_conf_b2.mean()  # mean confidence of this semantic class in image b2
-
-            #             # Spatial difference between the means, weighted by confidence
-            #             distance_cross = torch.sum((mean_a - mean_b2) ** 2)
-            #             weighted_distance_cross = distance_cross * mean_conf_a * mean_conf_b2
-            #             total_loss += weighted_distance_cross
-
-        # Average loss
-        return total_loss / (B * H * W)
-
-
-
-    def forward(self, cur_iter=0):
-        pw_poses = self.get_pw_poses()  # cam-to-world
-        pw_adapt = self.get_adaptors().unsqueeze(1)
-        proj_pts3d = self.get_pts3d(raw=True)
-
-        loss = 0.0
-
-        # depth = self.get_depthmaps(raw=True)
-        # print(depth.shape)
-        # if cur_iter < 100:
-        #     # for i, pointmap in enumerate(proj_pts3d):
-        #     #     loss += self.spatial_smoothness_loss(pointmap, seg_maps[i].cuda())
-
-        #     # depths = self.get_depthmaps()
-        #     # # cogs = self.cogs
-        #     # seg_maps = self.segmaps
-        #     # im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
-
-        #     # for i, depth in enumerate(depths):
-        #     #     # print(seg_maps[i].shape)
-        #     #     # H, W = depth.shape
-        #     #     # tmp = cogs[i].reshape(-1, 1024)
-        #     #     # tmp = torch.matmul(tmp, self.cog_matrix.detach().t())
-        #     #     # tmp / (tmp.norm(dim=-1, keepdim=True)+0.000000000001)
-        #     #     # tmp = tmp.reshape(H, W, 3)
-        #     #     loss += self.segmap_loss(depth, seg_maps[i], im_conf[i])
-        #     #     loss += self.semantic_loss(cogs[i], depth)
-
-        #     # im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
-
-        #     # cogs = self.cogs.permute(0, 3, 1, 2)
-        #     # cogs = F.interpolate(cogs, scale_factor=2, mode='nearest')
-        #     # cogs = cogs.permute(0, 2, 3, 1)
-        #     # cogs = torch.stack(self.cogs).view(-1, 1024)
-        #     # proj = proj_pts3d.view(-1, 3)
-        #     # proj = proj / proj.norm(dim=-1, keepdim=True)
-        #     # img_conf = im_conf.view(-1,1)
-
-        #     # selected_indices = torch.where(img_conf > 2.0)[0]
-        #     # img_conf = img_conf[selected_indices]
-        #     # cogs = cogs[selected_indices]
-        #     # proj = proj[selected_indices]
-        #     # print(img_conf.shape, cogs.shape, proj.shape)
-        #     # proj_dis = torch.matmul(proj, proj.t())
-        #     # cogs_dis = torch.matmul(cogs, cogs.t())
-        #     # loss += (im_conf * F.mse_loss(proj_dis, cogs_dis, reduction='none')).mean()
-
-        #     # if cur_iter % 2 == 0:
-        #     #     tmp = torch.matmul(cogs.detach(), self.cog_matrix.detach().t())
-        #     #     tmp = tmp / (tmp.norm(dim=-1, keepdim=True)+0.000000000001)
-        #     #     loss += 0/1*(img_conf * F.mse_loss(proj, tmp, reduction='none')).mean()
-        #     # if cur_iter % 2 == 1:
-        #     #     tmp = torch.matmul(cogs.view(-1, 1024), self.cog_matrix.detach().t())
-        #     #     tmp = tmp / tmp.norm(dim=-1, keepdim=True)
-        #     #     loss += (im_conf.view(-1,1) * F.mse_loss(proj.detach(), tmp, reduction='none')).mean()
-        #     # if cur_iter % 3 == 2:
-        #     #     tmp = torch.matmul(cogs.view(-1, 1024).detach(), self.cog_matrix.t())
-        #     #     tmp = tmp / tmp.norm(dim=-1, keepdim=True)
-        #     #     loss += (im_conf.view(-1,1) * F.mse_loss(proj.detach(), tmp, reduction='none')).mean()
-        seg_maps = torch.stack(self.segmaps).cuda()
-        im_conf = self.conf_trf(torch.stack([param_tensor for param_tensor in self.im_conf]))
-        loss += self.spatial_smoothness_loss_multi_image(proj_pts3d, seg_maps, im_conf)
-        # # if cur_iter > 100:
-        # # rotate pairwise prediction according to pw_poses
-        # aligned_pred_i = geotrf(pw_poses, pw_adapt * self._stacked_pred_i)
-        # aligned_pred_j = geotrf(pw_poses, pw_adapt * self._stacked_pred_j)
-
-        # loss += self.spatial_smoothness_loss_multi_image(aligned_pred_i, seg_maps[self._ei], im_conf[self._ei])
-        # loss += self.spatial_smoothness_loss_multi_image(aligned_pred_j, seg_maps[self._ej], im_conf[self._ej])
-
-        # # compute the loss
-        # loss += self.dist(proj_pts3d[self._ei], aligned_pred_i, weight=self._weight_i).sum() / self.total_area_i
-        # loss += self.dist(proj_pts3d[self._ej], aligned_pred_j, weight=self._weight_j).sum() / self.total_area_j
-
-        return loss
-
-
-def _fast_depthmap_to_pts3d(depth, pixel_grid, focal, pp):
-    pp = pp.unsqueeze(1)
-    focal = focal.unsqueeze(1)
-    assert focal.shape == (len(depth), 1, 1)
-    assert pp.shape == (len(depth), 1, 2)
-    assert pixel_grid.shape == depth.shape + (2,)
-    depth = depth.unsqueeze(-1)
-    return torch.cat((depth * (pixel_grid - pp) / focal, depth), dim=-1)
-
-
-def ParameterStack(params, keys=None, is_param=None, fill=0):
-    if keys is not None:
-        params = [params[k] for k in keys]
-
-    if fill > 0:
-        params = [_ravel_hw(p, fill) for p in params]
-
-    requires_grad = params[0].requires_grad
-    assert all(p.requires_grad == requires_grad for p in params)
-
-    params = torch.stack(list(params)).float().detach()
-    if is_param or requires_grad:
-        params = nn.Parameter(params)
-        params.requires_grad_(requires_grad)
-    return params
-
-
-def _ravel_hw(tensor, fill=0):
-    # ravel H,W
-    tensor = tensor.view((tensor.shape[0] * tensor.shape[1],) + tensor.shape[2:])
-
-    if len(tensor) < fill:
-        tensor = torch.cat((tensor, tensor.new_zeros((fill - len(tensor),)+tensor.shape[1:])))
-    return tensor
-
-
-def acceptable_focal_range(H, W, minf=0.5, maxf=3.5):
-    focal_base = max(H, W) / (2 * np.tan(np.deg2rad(60) / 2))  # size / 1.1547005383792515
-    return minf*focal_base, maxf*focal_base
-
-
-def apply_mask(img, msk):
-    img = img.copy()
-    img[msk] = 0
-    return img
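Note: the active term in forward() is spatial_smoothness_loss_multi_image, which loops over the batch and the eight neighbor offsets and gathers pixels by flat index. The same intra-image term can be computed with array shifts instead of gathers; the sketch below is an assumed vectorization for illustration, not code from this commit (the confidence weighting is omitted, matching the commented-out weighting in the method):

import torch

def smoothness_shifted(point_maps, semantic_maps):
    # point_maps: (B, H, W, 3) 3D points; semantic_maps: (B, H, W) segment ids
    B, H, W, _ = point_maps.shape
    total = point_maps.new_zeros(())
    for dy, dx in [(-1, 0), (1, 0), (0, -1), (0, 1),
                   (-1, -1), (-1, 1), (1, -1), (1, 1)]:
        # region whose (y+dy, x+dx) neighbor stays in-bounds, and the matching neighbor region
        ys, yn = slice(max(0, -dy), H - max(0, dy)), slice(max(0, dy), H + min(0, dy))
        xs, xn = slice(max(0, -dx), W - max(0, dx)), slice(max(0, dx), W + min(0, dx))
        same = semantic_maps[:, ys, xs] == semantic_maps[:, yn, xn]            # same segment?
        dist = ((point_maps[:, ys, xs] - point_maps[:, yn, xn]) ** 2).sum(-1)  # squared 3D gap
        total = total + (same * dist).sum()
    return total / (B * H * W)

# smoke test on random data
print(smoothness_shifted(torch.randn(2, 8, 8, 3), torch.randint(0, 3, (2, 8, 8))))

As in the loop version, each offset penalizes the squared 3D distance between a pixel and its neighbor only where the two segment ids agree, and the sum is normalized by B*H*W.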
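_fast_depthmap_to_pts3d above is the standard pinhole back-projection: a pixel (u, v) with depth d lifts to ((u - cx) * d / f, (v - cy) * d / f, d), i.e. depth * (pixel_grid - pp) / focal concatenated with depth. A quick numeric cross-check against the K^-1 formulation (the numbers are arbitrary test values, not from this repo):

import torch

f, cx, cy = 500.0, 320.0, 240.0  # assumed intrinsics for the check
u, v, d = 400.0, 300.0, 2.0      # pixel coordinates and depth

# closed form used by _fast_depthmap_to_pts3d
pt = torch.tensor([(u - cx) * d / f, (v - cy) * d / f, d])

# reference: X = d * K^-1 @ [u, v, 1]
K = torch.tensor([[f, 0.0, cx], [0.0, f, cy], [0.0, 0.0, 1.0]])
pt_ref = d * torch.linalg.inv(K) @ torch.tensor([u, v, 1.0])
assert torch.allclose(pt, pt_ref)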