File size: 12,863 Bytes
a6dd040
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
import torch
import torch.nn as nn
import MinkowskiEngine as ME
import torch.nn.functional as F

class SparseGaussianHead(nn.Module):
    """Sparse 3D convolutional head mapping voxel features to Gaussian parameters."""

    def __init__(self, in_channels=164, out_channels=38):
        """
        Args:
            in_channels: number of input channels (default 164).
            out_channels: number of output channels, i.e. the number of
                Gaussian parameters per voxel (default 38).
        """
        super().__init__()

        # Gaussian parameter count: 34 feature values + 3 position offsets + 1 opacity.
        self.num_gaussian_parameters = out_channels

        # Two-layer sparse 3D convolutional network with a GELU in between.
        self.conv1 = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3
        )
        self.act = ME.MinkowskiGELU()
        self.conv2 = ME.MinkowskiConvolution(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3
        )

        self.init_weights()

    def forward(self, sparse_input: ME.SparseTensor):
        """
        Forward pass.

        Args:
            sparse_input: sparse input tensor.
        Returns:
            Sparse tensor of Gaussian parameters.
        """
        x = self.conv1(sparse_input)
        x = self.act(x)
        x = self.conv2(x)
        return x

    def init_weights(self):
        """Initialize weights for modules used in this head."""
        for m in self.modules():
            # MinkowskiConvolution: kaiming-init the kernel, zero the bias if present.
            if isinstance(m, ME.MinkowskiConvolution):
                # m.kernel holds the convolution weight tensor.
                try:
                    ME.utils.kaiming_normal_(m.kernel,
                                             mode='fan_out',
                                             nonlinearity='relu')
                except AttributeError:
                    # Some ME versions do not ship this helper; fall back to torch's init,
                    # which works directly on the kernel parameter.
                    nn.init.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
                # Zero the bias when the module has one.
                if getattr(m, 'bias', None) is not None:
                    try:
                        nn.init.constant_(m.bias, 0)
                    except Exception:
                        # Some ME versions store the bias differently; best effort only.
                        pass

            # MinkowskiBatchNorm wraps a plain nn.BatchNorm under the `bn` attribute.
            elif isinstance(m, ME.MinkowskiBatchNorm):
                if hasattr(m, 'bn'):
                    try:
                        nn.init.constant_(m.bn.weight, 1)
                        nn.init.constant_(m.bn.bias, 0)
                    except Exception:
                        pass






class AttentionBlock(nn.Module):
    """Self-attention over sparse-tensor features, using PyTorch's fused
    scaled-dot-product attention (Flash attention) when available."""

    def __init__(self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False):
        """
        Args:
            channels: feature channels of the input sparse tensor.
            num_heads: number of attention heads (used when num_head_channels == -1).
            num_head_channels: channels per head; -1 means "use num_heads as given".
            use_checkpoint: stored flag; not currently used in forward.
        """
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0, (
                f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}")
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint

        # Normalization layer.
        self.norm = ME.MinkowskiBatchNorm(channels)

        # QKV projection and output projection.
        self.qkv = ME.MinkowskiLinear(channels, channels * 3)
        self.proj_out = ME.MinkowskiLinear(channels, channels)

    def _attention(self, qkv: torch.Tensor):
        """Multi-head attention over a (length, 3 * channels) QKV matrix.

        Returns:
            A (length, channels) tensor of attended values.
        """
        length, width = qkv.shape
        ch = width // (3 * self.num_heads)
        qkv = qkv.reshape(length, self.num_heads, 3 * ch).unsqueeze(0)
        qkv = qkv.permute(0, 2, 1, 3)  # (1, num_heads, length, 3 * ch)
        q, k, v = qkv.chunk(3, dim=-1)  # each (1, num_heads, length, ch)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Modern PyTorch: fused attention kernel (math backend disabled on CUDA).
            with torch.backends.cuda.sdp_kernel(enable_math=False):
                values = F.scaled_dot_product_attention(q, k, v)[0]
        else:
            # Older PyTorch without the fused kernel: explicit attention.
            # BUGFIX: the previous fallback called the very function whose
            # absence triggered this branch, so it always crashed.
            scale = ch ** -0.5
            attn = torch.softmax((q * scale) @ k.transpose(-2, -1), dim=-1)
            values = (attn @ v)[0]

        # (num_heads, length, ch) -> (length, num_heads * ch)
        values = values.permute(1, 0, 2).reshape(length, -1)
        return values

    def forward(self, x: ME.SparseTensor):
        # Normalize, project to QKV, attend over the dense feature matrix,
        # wrap the result back into a sparse tensor on the same coordinates,
        # project, and add the residual.
        x_norm = self.norm(x)
        qkv = self.qkv(x_norm)
        feature_dense = self._attention(qkv.F)
        feature = ME.SparseTensor(
            features=feature_dense,
            coordinate_map_key=qkv.coordinate_map_key,
            coordinate_manager=qkv.coordinate_manager
        )
        output = self.proj_out(feature)
        return output + x  # residual connection

class SparseConvBlock(nn.Module):
    """Sparse 3D convolution -> batch norm -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.conv = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=3,
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        out = self.conv(x)
        out = self.norm(out)
        return self.act(out)

class SparseUpConvBlock(nn.Module):
    """Sparse 3D transpose convolution (upsampling) -> batch norm -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        self.upconv = ME.MinkowskiConvolutionTranspose(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=3,
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        out = self.upconv(x)
        out = self.norm(out)
        return self.act(out)

class SparseUNetWithAttention(nn.Module):
    """Sparse 3D U-Net with an optional attention block in the bottleneck.

    The encoder downsamples `num_blocks` times (stride 2 each), the bottleneck
    downsamples once more, the decoder upsamples back with skip connections,
    and a final transpose convolution restores the input resolution.
    """

    def __init__(self, in_channels, out_channels, num_blocks=4, use_attention=False):
        """
        Args:
            in_channels: channels of the input sparse tensor.
            out_channels: channels of the output sparse tensor.
            num_blocks: number of encoder (and decoder) stages.
            use_attention: if True, insert an AttentionBlock in the bottleneck.
        """
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        self.use_attention = use_attention

        # Output channel count of each encoder stage, used for skip connections.
        self.encoder_channels = []

        # Encoder path: channels grow exponentially (64, 128, 256, ...).
        current_ch = in_channels
        for i in range(num_blocks):
            out_ch = 64 * (2 ** i)
            self.encoders.append(SparseConvBlock(current_ch, out_ch, kernel_size=3, stride=2))
            self.encoder_channels.append(out_ch)
            current_ch = out_ch

        # Bottleneck: one more stride-2 stage, optional attention, then a
        # stride-1 refinement convolution.
        bottleneck_in = current_ch
        bottleneck_out = bottleneck_in * 2

        self.bottleneck = nn.ModuleList()
        self.bottleneck.append(SparseConvBlock(bottleneck_in, bottleneck_out, kernel_size=3, stride=2))

        if use_attention:
            self.bottleneck.append(AttentionBlock(bottleneck_out))

        self.bottleneck.append(SparseConvBlock(bottleneck_out, bottleneck_out, kernel_size=3, stride=1))

        # Decoder path: each stage upsamples, concatenates the matching encoder
        # output, and reduces channels with a 1x1x1 convolution.
        current_decoder_ch = bottleneck_out

        for i in range(num_blocks):
            # The decoder stage mirrors the encoder stage it skips from.
            decoder_out = self.encoder_channels[-1 - i]

            upconv = SparseUpConvBlock(current_decoder_ch, decoder_out, kernel_size=3, stride=2)

            # 1x1x1 convolution applied after the skip concatenation.
            after_cat = ME.MinkowskiConvolution(
                decoder_out + self.encoder_channels[-1 - i],  # channels after concat
                decoder_out,
                kernel_size=1,
                stride=1,
                dimension=3
            )

            self.decoder_blocks.append(nn.ModuleList([upconv, after_cat]))
            current_decoder_ch = decoder_out

        # Final upsampling back to the input resolution (undoes the first
        # encoder's stride-2 downsampling). After the decoder loop,
        # current_decoder_ch == self.encoder_channels[0].
        self.final_upsample = SparseUpConvBlock(
            self.encoder_channels[0],
            out_channels,
            kernel_size=3,
            stride=2
        )

    def forward(self, x: ME.SparseTensor):
        encoder_outputs = []

        # Encoder path: keep each stage's output for the skip connections.
        for encoder in self.encoders:
            x = encoder(x)
            encoder_outputs.append(x)

        # Bottleneck.
        for layer in self.bottleneck:
            x = layer(x)

        # Decoder path with skip connections.
        for i, decoder_block in enumerate(self.decoder_blocks):
            upconv, after_cat = decoder_block

            x = upconv(x)

            # Concatenate with the matching encoder output, then reduce channels.
            enc_index = len(encoder_outputs) - i - 1
            if 0 <= enc_index < len(encoder_outputs):
                x = ME.cat(x, encoder_outputs[enc_index])
                x = after_cat(x)

        # Restore the original spatial resolution.
        output = self.final_upsample(x)

        return output









    
# Test / demo script: builds a random sparse tensor, runs the U-Net, and
# checks whether the output coordinates match the input coordinates.
if __name__ == "__main__":
    # 1. Create dense input data
    batch_size, channels, depth, height, width = 1, 128, 40, 80, 80
    dense_feature = torch.randn(batch_size, channels, depth, height, width)
    
    # 2. Build a sparse tensor from the non-zero voxels
    non_zero_mask = dense_feature.abs().sum(dim=1) > 0
    coordinates = torch.nonzero(non_zero_mask).int().contiguous()
    features = dense_feature[coordinates[:, 0], :, coordinates[:, 1], coordinates[:, 2], coordinates[:, 3]]
    
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sparse_tensor = ME.SparseTensor(
        features=features.to(device),
        coordinates=coordinates.to(device),
        tensor_stride=1
    )
    
    print(f"创建了稀疏张量: {coordinates.shape[0]}个体素")
    
    # 3. Create the model
    model = SparseUNetWithAttention(
        in_channels=channels, 
        out_channels=channels,
        num_blocks=3, 
        use_attention=True
    ).to(device)
    
    # Print the model structure
    print("模型结构:")
    print(model)
    
    # 4. Forward pass
    output = model(sparse_tensor)
    
    print("前向传播成功!")
    print("输出特征形状:", output.F.shape)
    print("输出坐标形状:", output.C.shape)
    
    # 5. Check whether the coordinates are preserved
    input_coords = coordinates.cpu()
    output_coords = output.C.cpu()
    
    # Directly compare whether the coordinate sets are exactly equal
    # NOTE(review): torch.equal is order-sensitive; identical voxel sets in a
    # different row order would report False here — confirm this is intended.
    coord_equal = torch.equal(input_coords, output_coords)
    print(f"\n输入输出坐标是否完全一致: {coord_equal}")
    
    # If the coordinates match exactly, no further checks are needed
    if coord_equal:
        print("模型保持了输入的空间分辨率")
    else:
        # Otherwise, report per-axis coordinate ranges for debugging
        print("\n坐标范围比较:")
        print(f"输入深度范围: {input_coords[:,1].min().item()} - {input_coords[:,1].max().item()}")
        print(f"输出深度范围: {output_coords[:,1].min().item()} - {output_coords[:,1].max().item()}")
        
        print(f"输入高度范围: {input_coords[:,2].min().item()} - {input_coords[:,2].max().item()}")
        print(f"输出高度范围: {output_coords[:,2].min().item()} - {output_coords[:,2].max().item()}")
        
        print(f"输入宽度范围: {input_coords[:,3].min().item()} - {input_coords[:,3].max().item()}")
        print(f"输出宽度范围: {output_coords[:,3].min().item()} - {output_coords[:,3].max().item()}")
        
        # Compare voxel counts
        print(f"\n体素数量比较:")
        print(f"输入体素数量: {input_coords.shape[0]}")
        print(f"输出体素数量: {output_coords.shape[0]}")