import torch
from einops import rearrange, repeat
from ....geometry.projection import get_world_rays
from ....geometry.projection import sample_image_grid
from ....test.export_ply import save_point_cloud_to_ply
import torch.nn.functional as F
import MinkowskiEngine as ME
def project_features_to_me(intrinsics: torch.Tensor, extrinsics: torch.Tensor, out: torch.Tensor, depth: torch.Tensor, voxel_resolution, b: int, v: int):
    """Unproject per-view 2D feature maps into a MinkowskiEngine sparse voxel grid.

    Every pixel of every view is lifted to a world-space 3D point using its
    depth value and camera parameters, the points are quantized onto a
    regular grid of cell size ``voxel_resolution``, and the features and 3D
    positions of all points that fall into the same voxel (within the same
    batch element) are averaged.

    Args:
        intrinsics: per-view camera intrinsics, shape [b, v, i, j]
            (presumably 3x3 matrices -- not verifiable from this file).
        extrinsics: per-view camera extrinsics, shape [b, v, i, j].
        out: feature maps with views folded into the batch axis,
            shape [(b*v), c, h, w].
        depth: per-pixel depth, shape [b, v, h, w].
        voxel_resolution: edge length of one voxel in world units.
        b: batch size.
        v: number of views per batch element.

    Returns:
        Tuple of:
        - ME.SparseTensor with one row per occupied voxel holding the mean
          feature of the points in that voxel; its coordinates are integer
          [batch_idx, x, y, z] rows.
        - aggregated_points: mean world-space position per voxel,
          shape [num_voxels, 3], row-aligned with the sparse tensor's input.
        - counts: number of contributing points per voxel, shape [num_voxels].
    """
    device = out.device

    # 0. Basic dimension bookkeeping.
    h, w = depth.shape[2:]
    _, c, _, _ = out.shape

    # 1. Reshape camera parameters / depths so they broadcast against the
    #    per-pixel ray grid built below.
    intrinsics = rearrange(intrinsics, "b v i j -> b v () () () i j")
    extrinsics = rearrange(extrinsics, "b v i j -> b v () () () i j")
    depths = rearrange(depth, "b v h w -> b v (h w) () ()")

    # 2. Back-project every pixel to world space: point = origin + depth * direction.
    uv_grid = sample_image_grid((h, w), device)[0]
    uv_grid = repeat(uv_grid, "h w c -> 1 v (h w) () () c", v=v)
    origins, directions = get_world_rays(uv_grid, extrinsics, intrinsics)
    world_coords = origins + directions * depths[..., None]
    world_coords = world_coords.squeeze(3).squeeze(3)  # [B, V, N, 3] with N = h*w

    # 3. Bring features into point-cloud layout [B, V, N, C].
    features = rearrange(out, "(b v) c h w -> b v c h w", b=b, v=v)
    features = rearrange(features, "b v c h w -> b v h w c")
    features = rearrange(features, "b v h w c -> b v (h w) c")  # [B, V, N, C]

    # 4. Flatten all batches/views into one big point cloud.
    all_points = rearrange(world_coords, "b v n c -> (b v n) c")  # [B*V*N, 3]
    feats_flat = features.reshape(-1, c)  # [B*V*N, C]

    # 5. Quantization is kept out of the autograd graph: only integer voxel
    #    ids are produced here, so gradients can still flow through the
    #    feature/coordinate averaging below.
    #    NOTE(review): the indentation of steps 6-10 relative to this
    #    no_grad block could not be verified in the extracted source --
    #    confirm the aggregation really sits outside it.
    with torch.no_grad():
        # Quantize with round-to-nearest rather than floor.
        quantized_coords = torch.round(all_points / voxel_resolution).long()
        # Prefix each point with its batch index so voxels from different
        # batch elements never merge: rows become [batch, x, y, z].
        batch_indices = torch.arange(b, device=device).repeat_interleave(v * h * w).unsqueeze(1)
        combined_coords = torch.cat([batch_indices, quantized_coords], dim=1)
        # Unique occupied voxels, the point->voxel mapping, and per-voxel
        # point counts in one pass.
        unique_coords, inverse_indices, counts = torch.unique(
            combined_coords,
            dim=0,
            return_inverse=True,
            return_counts=True
        )

    # 6. One output row per occupied voxel.
    num_voxels = unique_coords.shape[0]

    # 7. Vectorized aggregation -- mean feature per voxel.
    aggregated_feats = torch.zeros(num_voxels, c, device=device)
    aggregated_feats.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, c), feats_flat)
    aggregated_feats = aggregated_feats / counts.view(-1, 1).float()  # sum -> mean

    # 8. Vectorized aggregation -- mean world position per voxel.
    aggregated_points = torch.zeros(num_voxels, 3, device=device)
    aggregated_points.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), all_points)
    aggregated_points = aggregated_points / counts.view(-1, 1).float()

    # 9. Build the sparse tensor; ME expects integer [batch, x, y, z]
    #    coordinate rows, which is exactly what unique_coords holds.
    sparse_tensor = ME.SparseTensor(
        features=aggregated_feats,
        coordinates=unique_coords.int(),
        tensor_stride=1,
        device=device
    )

    # 10. Return sparse features, per-voxel mean positions, and occupancy counts.
    return sparse_tensor, aggregated_points, counts
# def project_features_to_me(intrinsics, extrinsics, out, depth, voxel_resolution, b, v):
# device = out.device
# # 0. 获取基础维度信息
# b, v = intrinsics.shape[:2]
# h, w = depth.shape[2:]
# _, c, _, _ = out.shape
# # 1. 准备投影参数
# intrinsics = rearrange(intrinsics, "b v i j -> b v () () () i j")
# extrinsics = rearrange(extrinsics, "b v i j -> b v () () () i j")
# depths = rearrange(depth, "b v h w -> b v (h w) () ()")
# # 2. 获取世界坐标
# uv_grid = sample_image_grid((h, w), device)[0]
# uv_grid = repeat(uv_grid, "h w c -> 1 v (h w) () () c", v=v)
# origins, directions = get_world_rays(uv_grid, extrinsics, intrinsics)
# world_coords = origins + directions * depths[..., None]
# world_coords = world_coords.squeeze(3).squeeze(3) # [B, V, N, 3]
# # 3. 准备特征数据
# features = rearrange(out, "(b v) c h w -> b v c h w", b=b, v=v)
# features = rearrange(features, "b v c h w -> b v h w c")
# features = rearrange(features, "b v h w c -> b v (h w) c") # [B, V, N, C]
# # 4. 合并所有点云数据
# all_points = rearrange(world_coords, "b v n c -> (b v n) c") # [B*V*N, 3]
# feats_flat = features.reshape(-1, c) # [B*V*N, C]
# # 5. 分离量化操作
# with torch.no_grad():
# # 计算量化坐标
# quantized_coords = (all_points / voxel_resolution).floor().long()
# # 创建坐标矩阵:批处理索引 + 量化坐标
# batch_indices = torch.arange(b * v, device=device).repeat_interleave(h * w).unsqueeze(1)
# combined_coords = torch.cat([batch_indices, quantized_coords], dim=1)
# # 获取唯一体素ID和映射索引
# unique_coords, inverse_indices, counts = torch.unique(
# combined_coords,
# dim=0,
# return_inverse=True,
# return_counts=True
# )
# # 6. 创建聚合特征和坐标
# num_voxels = unique_coords.shape[0]
# # 7. 向量化聚合 - 特征求和
# aggregated_feats = torch.zeros(num_voxels, c, device=device)
# aggregated_feats.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, c), feats_flat)
# # 8. 向量化聚合 - 坐标平均
# aggregated_points = torch.zeros(num_voxels, 3, device=device)
# aggregated_points.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), all_points)
# aggregated_points = aggregated_points / counts.view(-1, 1).float()
# # 9. 创建稀疏张量
# # 关键修改:使用相同的索引顺序
# # 确保 sparse_tensor 和 aggregated_points 对应相同的体素
# sparse_tensor = ME.SparseTensor(
# features=aggregated_feats,
# coordinates=torch.cat([
# torch.arange(num_voxels, device=device).unsqueeze(1), # 添加批次索引
# unique_coords[:, 1:].float() # 量化坐标
# ], dim=1).int(),
# tensor_stride=1,
# device=device
# )
# # 10. 返回结果
# # 现在 sparse_tensor 和 aggregated_points 有相同的点数 (num_voxels)
# return sparse_tensor, aggregated_points, counts