import torch
from einops import rearrange, repeat

from ....geometry.projection import get_world_rays, sample_image_grid
from ....test.export_ply import save_point_cloud_to_ply
import torch.nn.functional as F
import MinkowskiEngine as ME


def project_features_to_me(intrinsics, extrinsics, out, depth, voxel_resolution, b, v):
    """Unproject per-pixel features to 3D and voxelize them into an ME sparse tensor.

    Every pixel is lifted to a world-space point along its camera ray using its
    depth, points are quantized to a regular voxel grid, and the features and
    positions of all points landing in the same voxel are averaged.

    Args:
        intrinsics: camera intrinsics, shape [b, v, i, j] — presumably [b, v, 3, 3];
            confirm against `get_world_rays`.
        extrinsics: camera extrinsics, shape [b, v, i, j] — presumably camera-to-world
            [b, v, 4, 4]; confirm against `get_world_rays`.
        out: per-view feature maps, shape [(b*v), c, h, w].
        depth: per-pixel depth, shape [b, v, h, w].
        voxel_resolution: edge length of one voxel in world units.
        b: batch size.
        v: number of views (leading dim of `out` is b*v).

    Returns:
        A tuple ``(sparse_tensor, aggregated_points, counts)``:
            sparse_tensor: ME.SparseTensor holding one mean feature per occupied voxel.
            aggregated_points: [num_voxels, 3] mean world position of each voxel's points.
            counts: [num_voxels] number of points that fell into each voxel.
    """
    device = out.device
    h, w = depth.shape[2:]
    _, c, _, _ = out.shape

    # Reshape camera parameters so they broadcast against the flattened pixel grid.
    intrinsics = rearrange(intrinsics, "b v i j -> b v () () () i j")
    extrinsics = rearrange(extrinsics, "b v i j -> b v () () () i j")
    depths = rearrange(depth, "b v h w -> b v (h w) () ()")

    # Unproject every pixel: world point = ray origin + direction * depth.
    uv_grid = sample_image_grid((h, w), device)[0]
    uv_grid = repeat(uv_grid, "h w c -> 1 v (h w) () () c", v=v)
    origins, directions = get_world_rays(uv_grid, extrinsics, intrinsics)
    world_coords = origins + directions * depths[..., None]
    world_coords = world_coords.squeeze(3).squeeze(3)  # [b, v, h*w, 3]

    # Flatten so each row pairs one pixel's feature with its world position.
    # (Single rearrange replaces the previous chain of three.)
    feats_flat = rearrange(out, "(b v) c h w -> (b v h w) c", b=b, v=v)  # [b*v*h*w, c]
    all_points = rearrange(world_coords, "b v n c -> (b v n) c")  # [b*v*h*w, 3]

    # Quantize to the voxel grid. round() (rather than floor()) centers voxels on
    # integer multiples of voxel_resolution. Index computation carries no gradient.
    with torch.no_grad():
        quantized_coords = torch.round(all_points / voxel_resolution).long()
        # Prepend the batch index so voxels from different batch items never merge.
        batch_indices = (
            torch.arange(b, device=device).repeat_interleave(v * h * w).unsqueeze(1)
        )
        combined_coords = torch.cat([batch_indices, quantized_coords], dim=1)
        # Unique voxel ids, a point->voxel map, and per-voxel occupancy.
        unique_coords, inverse_indices, counts = torch.unique(
            combined_coords, dim=0, return_inverse=True, return_counts=True
        )

    num_voxels = unique_coords.shape[0]
    occupancy = counts.view(-1, 1)

    # Mean feature per voxel: scatter-add then divide by occupancy. The accumulator
    # uses the features' dtype so scatter_add_ also works under fp16/bf16 inputs
    # (a float32 default would raise a dtype mismatch there).
    aggregated_feats = torch.zeros(
        num_voxels, c, device=device, dtype=feats_flat.dtype
    )
    aggregated_feats.scatter_add_(
        0, inverse_indices.unsqueeze(1).expand(-1, c), feats_flat
    )
    aggregated_feats = aggregated_feats / occupancy.to(aggregated_feats.dtype)

    # Mean world position per voxel, aggregated the same way.
    aggregated_points = torch.zeros(
        num_voxels, 3, device=device, dtype=all_points.dtype
    )
    aggregated_points.scatter_add_(
        0, inverse_indices.unsqueeze(1).expand(-1, 3), all_points
    )
    aggregated_points = aggregated_points / occupancy.to(aggregated_points.dtype)

    # MinkowskiEngine expects integer coordinates shaped [batch_idx, x, y, z];
    # unique_coords already has exactly that layout.
    sparse_tensor = ME.SparseTensor(
        features=aggregated_feats,
        coordinates=unique_coords.int(),
        tensor_stride=1,
        device=device,
    )
    return sparse_tensor, aggregated_points, counts