# Uploaded via huggingface_hub by Yeqing0814 (commit a6dd040, verified).
import torch
from einops import rearrange, repeat
from ....geometry.projection import get_world_rays
from ....geometry.projection import sample_image_grid
from ....test.export_ply import save_point_cloud_to_ply
import torch.nn.functional as F
import MinkowskiEngine as ME
def project_features_to_me(intrinsics, extrinsics, out, depth, voxel_resolution, b, v):
    """Unproject per-pixel features into a MinkowskiEngine sparse voxel grid.

    Every pixel of every view is lifted to a 3D world point along its camera
    ray using its depth; points that land in the same voxel (within the same
    batch element) have their features and world positions averaged.

    Args:
        intrinsics: per-view camera intrinsics, shape [b, v, i, j]
            (presumably [b, v, 3, 3] — TODO confirm against get_world_rays).
        extrinsics: per-view camera extrinsics, shape [b, v, i, j]
            (presumably camera-to-world [b, v, 4, 4] — TODO confirm).
        out: feature maps with views folded into the batch dim,
            shape [(b*v), c, h, w].
        depth: per-pixel depth, shape [b, v, h, w].
        voxel_resolution: edge length of one voxel in world units.
        b: batch size.
        v: number of views per batch element.

    Returns:
        sparse_tensor: ME.SparseTensor holding one mean feature per occupied
            voxel, coordinates formatted as [batch_idx, x, y, z].
        aggregated_points: [num_voxels, 3] mean world position of the points
            that fell into each voxel (same row order as the sparse tensor's
            coordinates).
        counts: [num_voxels] number of points per voxel.
    """
    device = out.device

    # 0. Basic dimensions.
    h, w = depth.shape[2:]
    _, c, _, _ = out.shape

    # 1. Reshape camera parameters so they broadcast against the flattened
    #    pixel grid produced below.
    intrinsics = rearrange(intrinsics, "b v i j -> b v () () () i j")
    extrinsics = rearrange(extrinsics, "b v i j -> b v () () () i j")
    depths = rearrange(depth, "b v h w -> b v (h w) () ()")

    # 2. Lift pixels to world coordinates along their camera rays.
    uv_grid = sample_image_grid((h, w), device)[0]
    uv_grid = repeat(uv_grid, "h w c -> 1 v (h w) () () c", v=v)
    origins, directions = get_world_rays(uv_grid, extrinsics, intrinsics)
    world_coords = origins + directions * depths[..., None]
    world_coords = world_coords.squeeze(3).squeeze(3)  # [B, V, N, 3]

    # 3. Reshape features to match the point layout.
    features = rearrange(out, "(b v) c h w -> b v (h w) c", b=b, v=v)  # [B, V, N, C]

    # 4. Flatten everything into one big point cloud.
    all_points = rearrange(world_coords, "b v n c -> (b v n) c")  # [B*V*N, 3]
    feats_flat = features.reshape(-1, c)                          # [B*V*N, C]

    # 5. Quantize outside the autograd graph: voxel ids are integers and
    #    cannot carry gradients, while the feature aggregation below must
    #    stay differentiable.
    with torch.no_grad():
        # Round-to-nearest (not floor) so each point snaps to its closest
        # voxel center.
        quantized_coords = torch.round(all_points / voxel_resolution).long()
        # Prepend the batch index so voxels from different batch elements
        # never merge even when their world coordinates coincide.
        batch_indices = (
            torch.arange(b, device=device).repeat_interleave(v * h * w).unsqueeze(1)
        )
        combined_coords = torch.cat([batch_indices, quantized_coords], dim=1)
        # Unique voxel ids, a point->voxel map, and per-voxel point counts.
        unique_coords, inverse_indices, counts = torch.unique(
            combined_coords,
            dim=0,
            return_inverse=True,
            return_counts=True,
        )

    num_voxels = unique_coords.shape[0]

    # 6. Scatter-mean the features into their voxels (differentiable).
    #    Accumulators inherit the input dtype so half-precision features do
    #    not trip scatter_add_'s dtype check.
    aggregated_feats = torch.zeros(
        num_voxels, c, device=device, dtype=feats_flat.dtype
    )
    aggregated_feats.scatter_add_(
        0, inverse_indices.unsqueeze(1).expand(-1, c), feats_flat
    )
    aggregated_feats = aggregated_feats / counts.view(-1, 1).to(feats_flat.dtype)

    # 7. Scatter-mean the raw world coordinates the same way.
    aggregated_points = torch.zeros(
        num_voxels, 3, device=device, dtype=all_points.dtype
    )
    aggregated_points.scatter_add_(
        0, inverse_indices.unsqueeze(1).expand(-1, 3), all_points
    )
    aggregated_points = aggregated_points / counts.view(-1, 1).to(all_points.dtype)

    # 8. Build the sparse tensor; ME expects int coordinates of the form
    #    [batch_idx, x, y, z], which is exactly combined_coords' layout.
    sparse_tensor = ME.SparseTensor(
        features=aggregated_feats,
        coordinates=unique_coords.int(),
        tensor_stride=1,
        device=device,
    )

    return sparse_tensor, aggregated_points, counts
# NOTE(review): earlier variant kept for reference only — it used floor-based
# quantization and per-(batch,view) indices instead of per-batch; do not call.
# def project_features_to_me(intrinsics, extrinsics, out, depth, voxel_resolution, b, v):
# device = out.device
# # 0. 获取基础维度信息
# b, v = intrinsics.shape[:2]
# h, w = depth.shape[2:]
# _, c, _, _ = out.shape
# # 1. 准备投影参数
# intrinsics = rearrange(intrinsics, "b v i j -> b v () () () i j")
# extrinsics = rearrange(extrinsics, "b v i j -> b v () () () i j")
# depths = rearrange(depth, "b v h w -> b v (h w) () ()")
# # 2. 获取世界坐标
# uv_grid = sample_image_grid((h, w), device)[0]
# uv_grid = repeat(uv_grid, "h w c -> 1 v (h w) () () c", v=v)
# origins, directions = get_world_rays(uv_grid, extrinsics, intrinsics)
# world_coords = origins + directions * depths[..., None]
# world_coords = world_coords.squeeze(3).squeeze(3) # [B, V, N, 3]
# # 3. 准备特征数据
# features = rearrange(out, "(b v) c h w -> b v c h w", b=b, v=v)
# features = rearrange(features, "b v c h w -> b v h w c")
# features = rearrange(features, "b v h w c -> b v (h w) c") # [B, V, N, C]
# # 4. 合并所有点云数据
# all_points = rearrange(world_coords, "b v n c -> (b v n) c") # [B*V*N, 3]
# feats_flat = features.reshape(-1, c) # [B*V*N, C]
# # 5. 分离量化操作
# with torch.no_grad():
# # 计算量化坐标
# quantized_coords = (all_points / voxel_resolution).floor().long()
# # 创建坐标矩阵:批处理索引 + 量化坐标
# batch_indices = torch.arange(b * v, device=device).repeat_interleave(h * w).unsqueeze(1)
# combined_coords = torch.cat([batch_indices, quantized_coords], dim=1)
# # 获取唯一体素ID和映射索引
# unique_coords, inverse_indices, counts = torch.unique(
# combined_coords,
# dim=0,
# return_inverse=True,
# return_counts=True
# )
# # 6. 创建聚合特征和坐标
# num_voxels = unique_coords.shape[0]
# # 7. 向量化聚合 - 特征求和
# aggregated_feats = torch.zeros(num_voxels, c, device=device)
# aggregated_feats.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, c), feats_flat)
# # 8. 向量化聚合 - 坐标平均
# aggregated_points = torch.zeros(num_voxels, 3, device=device)
# aggregated_points.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), all_points)
# aggregated_points = aggregated_points / counts.view(-1, 1).float()
# # 9. 创建稀疏张量
# # 关键修改:使用相同的索引顺序
# # 确保 sparse_tensor 和 aggregated_points 对应相同的体素
# sparse_tensor = ME.SparseTensor(
# features=aggregated_feats,
# coordinates=torch.cat([
# torch.arange(num_voxels, device=device).unsqueeze(1), # 添加批次索引
# unique_coords[:, 1:].float() # 量化坐标
# ], dim=1).int(),
# tensor_stride=1,
# device=device
# )
# # 10. 返回结果
# # 现在 sparse_tensor 和 aggregated_points 有相同的点数 (num_voxels)
# return sparse_tensor, aggregated_points, counts