File size: 6,354 Bytes
a6dd040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import torch
from einops import rearrange, repeat
from ....geometry.projection import get_world_rays
from ....geometry.projection import sample_image_grid
from ....test.export_ply import save_point_cloud_to_ply
import torch.nn.functional as F
import MinkowskiEngine as ME



def project_features_to_me(intrinsics, extrinsics, out, depth, voxel_resolution, b, v):
    """Back-project per-pixel features into a MinkowskiEngine sparse voxel grid.

    Every pixel is lifted to a 3D world point via its depth and camera rays,
    quantized (round-to-nearest) to a voxel of edge ``voxel_resolution``, and
    the features / world positions of pixels landing in the same voxel are
    averaged.

    Args:
        intrinsics: [B, V, 3, 3] per-view intrinsics (conventions follow
            ``get_world_rays`` — assumed normalized UV, TODO confirm).
        extrinsics: [B, V, 4, 4] per-view camera-to-world extrinsics.
        out: [(B*V), C, H, W] per-view feature maps.
        depth: [B, V, H, W] per-pixel depth along each ray.
        voxel_resolution: voxel edge length in world units.
        b: batch size B.
        v: number of views V.

    Returns:
        sparse_tensor: ``ME.SparseTensor`` with one mean C-dim feature per
            occupied voxel; coordinates are [batch_idx, x, y, z] ints.
        aggregated_points: [num_voxels, 3] mean world position per voxel,
            in the same row order as ``sparse_tensor``'s features.
        counts: [num_voxels] number of pixels contributing to each voxel.
    """
    device = out.device
    # 0. Basic dimensions.
    h, w = depth.shape[2:]
    _, c, _, _ = out.shape

    # 1. Reshape projection parameters so they broadcast against the
    #    per-pixel ray grid below.
    intrinsics = rearrange(intrinsics, "b v i j -> b v () () () i j")
    extrinsics = rearrange(extrinsics, "b v i j -> b v () () () i j")
    depths = rearrange(depth, "b v h w -> b v (h w) () ()")

    # 2. World coordinates: point = ray origin + depth * ray direction.
    uv_grid = sample_image_grid((h, w), device)[0]
    uv_grid = repeat(uv_grid, "h w c -> 1 v (h w) () () c", v=v)
    origins, directions = get_world_rays(uv_grid, extrinsics, intrinsics)
    world_coords = origins + directions * depths[..., None]
    world_coords = world_coords.squeeze(3).squeeze(3)  # [B, V, N, 3], N = H*W

    # 3. Flatten features to [B, V, N, C] in a single rearrange
    #    (equivalent to splitting batch/view, permuting C last, merging H*W).
    features = rearrange(out, "(b v) c h w -> b v (h w) c", b=b, v=v)

    # 4. Merge all views of all batch items into one flat point cloud.
    all_points = rearrange(world_coords, "b v n c -> (b v n) c")  # [B*V*N, 3]
    feats_flat = features.reshape(-1, c)  # [B*V*N, C]

    # 5. Voxel quantization is pure index bookkeeping — no gradients needed.
    with torch.no_grad():
        # Round to the nearest voxel center (deliberately round, not floor).
        quantized_coords = torch.round(all_points / voxel_resolution).long()

        # Prepend the batch index so voxels from different batch items
        # never merge, even when their world coordinates coincide.
        batch_indices = torch.arange(b, device=device).repeat_interleave(v * h * w).unsqueeze(1)
        combined_coords = torch.cat([batch_indices, quantized_coords], dim=1)

        # Unique voxel ids, a per-point voxel index, and per-voxel counts.
        unique_coords, inverse_indices, counts = torch.unique(
            combined_coords,
            dim=0,
            return_inverse=True,
            return_counts=True,
        )

    num_voxels = unique_coords.shape[0]

    # 6. Mean feature per voxel. The accumulator matches the feature dtype:
    #    scatter_add_ requires src and self to share a dtype, so a default
    #    float32 zeros tensor would fail under half/double precision.
    aggregated_feats = torch.zeros(num_voxels, c, device=device, dtype=feats_flat.dtype)
    aggregated_feats.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, c), feats_flat)
    aggregated_feats = aggregated_feats / counts.view(-1, 1).to(aggregated_feats.dtype)

    # 7. Mean world position per voxel (same dtype-matching rationale).
    aggregated_points = torch.zeros(num_voxels, 3, device=device, dtype=all_points.dtype)
    aggregated_points.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), all_points)
    aggregated_points = aggregated_points / counts.view(-1, 1).to(aggregated_points.dtype)

    # 8. Build the sparse tensor with batch-indexed integer coordinates.
    sparse_tensor = ME.SparseTensor(
        features=aggregated_feats,
        coordinates=unique_coords.int(),
        tensor_stride=1,
        device=device,
    )

    return sparse_tensor, aggregated_points, counts



# def project_features_to_me(intrinsics, extrinsics, out, depth, voxel_resolution, b, v):
#     device = out.device
#     # 0. 获取基础维度信息
#     b, v = intrinsics.shape[:2]
#     h, w = depth.shape[2:]
#     _, c, _, _ = out.shape
    
#     # 1. 准备投影参数
#     intrinsics = rearrange(intrinsics, "b v i j -> b v () () () i j")
#     extrinsics = rearrange(extrinsics, "b v i j -> b v () () () i j")
#     depths = rearrange(depth, "b v h w -> b v (h w) () ()")

#     # 2. 获取世界坐标
#     uv_grid = sample_image_grid((h, w), device)[0]
#     uv_grid = repeat(uv_grid, "h w c -> 1 v (h w) () () c", v=v)
#     origins, directions = get_world_rays(uv_grid, extrinsics, intrinsics)
#     world_coords = origins + directions * depths[..., None]
#     world_coords = world_coords.squeeze(3).squeeze(3)  # [B, V, N, 3]

#     # 3. 准备特征数据
#     features = rearrange(out, "(b v) c h w -> b v c h w", b=b, v=v)
#     features = rearrange(features, "b v c h w -> b v h w c")
#     features = rearrange(features, "b v h w c -> b v (h w) c")  # [B, V, N, C]
    
#     # 4. 合并所有点云数据
#     all_points = rearrange(world_coords, "b v n c -> (b v n) c")  # [B*V*N, 3]
#     feats_flat = features.reshape(-1, c)  # [B*V*N, C]
    
#     # 5. 分离量化操作
#     with torch.no_grad():
#         # 计算量化坐标
#         quantized_coords = (all_points / voxel_resolution).floor().long()
        
#         # 创建坐标矩阵:批处理索引 + 量化坐标
#         batch_indices = torch.arange(b * v, device=device).repeat_interleave(h * w).unsqueeze(1)
#         combined_coords = torch.cat([batch_indices, quantized_coords], dim=1)
        
#         # 获取唯一体素ID和映射索引
#         unique_coords, inverse_indices, counts = torch.unique(
#             combined_coords, 
#             dim=0, 
#             return_inverse=True, 
#             return_counts=True
#         )
    
#     # 6. 创建聚合特征和坐标
#     num_voxels = unique_coords.shape[0]
    
#     # 7. 向量化聚合 - 特征求和
#     aggregated_feats = torch.zeros(num_voxels, c, device=device)
#     aggregated_feats.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, c), feats_flat)
    
#     # 8. 向量化聚合 - 坐标平均
#     aggregated_points = torch.zeros(num_voxels, 3, device=device)
#     aggregated_points.scatter_add_(0, inverse_indices.unsqueeze(1).expand(-1, 3), all_points)
#     aggregated_points = aggregated_points / counts.view(-1, 1).float()
    
#     # 9. 创建稀疏张量
#     # 关键修改:使用相同的索引顺序
#     # 确保 sparse_tensor 和 aggregated_points 对应相同的体素
#     sparse_tensor = ME.SparseTensor(
#         features=aggregated_feats,
#         coordinates=torch.cat([
#             torch.arange(num_voxels, device=device).unsqueeze(1),  # 添加批次索引
#             unique_coords[:, 1:].float()  # 量化坐标
#         ], dim=1).int(),
#         tensor_stride=1,
#         device=device
#     )
    
#     # 10. 返回结果
#     # 现在 sparse_tensor 和 aggregated_points 有相同的点数 (num_voxels)
#     return sparse_tensor, aggregated_points, counts