import math

import torch
import torch.nn as nn
import torch.nn.functional as F

import MinkowskiEngine as ME


class SparseGaussianHead(nn.Module):
    """Map per-voxel features to Gaussian parameters with sparse 3D convolutions.

    The default output layout (38 channels) is documented by the original
    author as 34 feature values + 3 position offsets + 1 opacity.
    """

    def __init__(self, in_channels: int = 164, out_channels: int = 38):
        """
        Args:
            in_channels: number of input feature channels (default 164).
            out_channels: number of Gaussian parameters per voxel (default 38).
        """
        super().__init__()
        self.num_gaussian_parameters = out_channels

        # Two sparse 3D convolutions with a GELU in between; stride 1
        # everywhere, so the output keeps the input's coordinate map.
        self.conv1 = ME.MinkowskiConvolution(
            in_channels, out_channels, kernel_size=3, stride=1, dimension=3
        )
        self.act = ME.MinkowskiGELU()
        self.conv2 = ME.MinkowskiConvolution(
            out_channels, out_channels, kernel_size=3, stride=1, dimension=3
        )
        self.init_weights()

    def forward(self, sparse_input: ME.SparseTensor) -> ME.SparseTensor:
        """Return a sparse tensor of Gaussian parameters, one row per voxel."""
        x = self.conv1(sparse_input)
        x = self.act(x)
        x = self.conv2(x)
        return x

    def init_weights(self):
        """Initialize weights for modules used in this head."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                # MinkowskiConvolution stores its weight in `m.kernel`, a plain
                # torch tensor, so the standard torch initializer applies.
                # (The original first tried a non-existent
                # ME.utils.kaiming_normal_ and always fell through to this.)
                nn.init.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
                if getattr(m, 'bias', None) is not None:
                    try:
                        nn.init.constant_(m.bias, 0)
                    except Exception:
                        # Some ME versions store the bias differently; a failed
                        # bias init is not worth crashing construction over.
                        pass
            elif isinstance(m, ME.MinkowskiBatchNorm):
                # MinkowskiBatchNorm wraps an nn.BatchNorm instance as `bn`.
                if hasattr(m, 'bn'):
                    try:
                        nn.init.constant_(m.bn.weight, 1)
                        nn.init.constant_(m.bn.bias, 0)
                    except Exception:
                        pass


class AttentionBlock(nn.Module):
    """Multi-head self-attention over the voxels of a sparse tensor.

    NOTE(review): attention is computed over *all* voxels in the sparse
    tensor at once, i.e. voxels from different batch items attend to each
    other — confirm this is intended for batch sizes > 1.
    """

    def __init__(self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False):
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0, (
                f"q,k,v channels {channels} is not divisible by "
                f"num_head_channels {num_head_channels}"
            )
            self.num_heads = channels // num_head_channels
        # Kept for interface compatibility; not used in this implementation.
        self.use_checkpoint = use_checkpoint

        self.norm = ME.MinkowskiBatchNorm(channels)
        self.qkv = ME.MinkowskiLinear(channels, channels * 3)
        self.proj_out = ME.MinkowskiLinear(channels, channels)

    def _attention(self, qkv: torch.Tensor) -> torch.Tensor:
        """Scaled dot-product attention over a (length, 3 * channels) QKV matrix.

        Returns a (length, channels) tensor of attended features.
        """
        length, width = qkv.shape
        ch = width // (3 * self.num_heads)
        # (length, 3C) -> (1, num_heads, length, 3 * ch)
        qkv = qkv.reshape(length, self.num_heads, 3 * ch).unsqueeze(0)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)  # each (1, num_heads, length, ch)

        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch >= 2.0: fused (Flash / memory-efficient) attention.
            values = F.scaled_dot_product_attention(q, k, v)[0]
        else:
            # Older PyTorch: compute attention manually. The original fallback
            # called F.scaled_dot_product_attention — the very function whose
            # absence the hasattr check had just established — and would crash.
            scores = (q @ k.transpose(-2, -1)) / math.sqrt(ch)
            values = (scores.softmax(dim=-1) @ v)[0]

        # (num_heads, length, ch) -> (length, num_heads * ch)
        return values.permute(1, 0, 2).reshape(length, -1)

    def forward(self, x: ME.SparseTensor) -> ME.SparseTensor:
        x_norm = self.norm(x)
        qkv = self.qkv(x_norm)
        feature_dense = self._attention(qkv.F)
        # Re-wrap the dense result on the same coordinate map as the input.
        feature = ME.SparseTensor(
            features=feature_dense,
            coordinate_map_key=qkv.coordinate_map_key,
            coordinate_manager=qkv.coordinate_manager,
        )
        return self.proj_out(feature) + x  # residual connection


class SparseConvBlock(nn.Module):
    """Sparse 3D convolution -> batch norm -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.conv = ME.MinkowskiConvolution(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, dimension=3
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor) -> ME.SparseTensor:
        return self.act(self.norm(self.conv(x)))


class SparseUpConvBlock(nn.Module):
    """Sparse 3D transposed convolution (upsampling) -> batch norm -> ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        self.upconv = ME.MinkowskiConvolutionTranspose(
            in_channels, out_channels,
            kernel_size=kernel_size, stride=stride, dimension=3
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor) -> ME.SparseTensor:
        return self.act(self.norm(self.upconv(x)))


class SparseUNetWithAttention(nn.Module):
    """Sparse 3D U-Net with optional attention in the bottleneck.

    Encoder: `num_blocks` strided conv blocks with channels 64 * 2**i.
    Bottleneck: one more strided conv (doubling channels), optional
    attention, and a stride-1 conv.
    Decoder: transposed convs mirrored against the encoder, with skip
    connections concatenated and reduced by a 1x1x1 convolution, followed
    by a final upsampling block back to the input stride.
    """

    def __init__(self, in_channels, out_channels, num_blocks=4, use_attention=False):
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        self.use_attention = use_attention
        # Output channels of each encoder stage, used to size the decoder.
        self.encoder_channels = []

        # --- Encoder path: channels grow as 64 * 2**i, stride 2 each stage.
        current_ch = in_channels
        for i in range(num_blocks):
            out_ch = 64 * (2 ** i)
            self.encoders.append(
                SparseConvBlock(current_ch, out_ch, kernel_size=3, stride=2))
            self.encoder_channels.append(out_ch)
            current_ch = out_ch

        # --- Bottleneck: one more downsample, optional attention, then a
        # stride-1 conv.
        bottleneck_in = current_ch
        bottleneck_out = bottleneck_in * 2
        self.bottleneck = nn.ModuleList()
        self.bottleneck.append(
            SparseConvBlock(bottleneck_in, bottleneck_out, kernel_size=3, stride=2))
        if use_attention:
            self.bottleneck.append(AttentionBlock(bottleneck_out))
        self.bottleneck.append(
            SparseConvBlock(bottleneck_out, bottleneck_out, kernel_size=3, stride=1))

        # --- Decoder path: each stage upsamples, concatenates the matching
        # encoder output, and reduces channels with a 1x1x1 convolution.
        current_decoder_ch = bottleneck_out
        for i in range(num_blocks):
            decoder_out = self.encoder_channels[-1 - i]
            upconv = SparseUpConvBlock(
                current_decoder_ch, decoder_out, kernel_size=3, stride=2)
            after_cat = ME.MinkowskiConvolution(
                decoder_out + self.encoder_channels[-1 - i],  # channels after concat
                decoder_out,
                kernel_size=1,
                stride=1,
                dimension=3
            )
            self.decoder_blocks.append(nn.ModuleList([upconv, after_cat]))
            current_decoder_ch = decoder_out

        # Final upsample back to the input's tensor stride. After the loop,
        # current_decoder_ch == encoder_channels[0].
        self.final_upsample = SparseUpConvBlock(
            self.encoder_channels[0],
            out_channels,
            kernel_size=3,
            stride=2
        )

    def forward(self, x: ME.SparseTensor) -> ME.SparseTensor:
        encoder_outputs = []

        # Encoder path — keep each stage's output for the skip connections.
        for encoder in self.encoders:
            x = encoder(x)
            encoder_outputs.append(x)

        # Bottleneck.
        for layer in self.bottleneck:
            x = layer(x)

        # Decoder path with skip connections.
        for i, decoder_block in enumerate(self.decoder_blocks):
            upconv, after_cat = decoder_block
            x = upconv(x)
            enc_index = len(encoder_outputs) - i - 1
            if 0 <= enc_index < len(encoder_outputs):
                x = ME.cat(x, encoder_outputs[enc_index])
            x = after_cat(x)  # reduce concatenated channels

        # Final upsample to the original resolution.
        return self.final_upsample(x)


if __name__ == "__main__":
    # 1. Create input data.
    batch_size, channels, depth, height, width = 1, 128, 40, 80, 80
    dense_feature = torch.randn(batch_size, channels, depth, height, width)

    # 2. Build the sparse tensor from the non-zero voxels.
    non_zero_mask = dense_feature.abs().sum(dim=1) > 0
    coordinates = torch.nonzero(non_zero_mask).int().contiguous()
    features = dense_feature[
        coordinates[:, 0], :, coordinates[:, 1], coordinates[:, 2], coordinates[:, 3]]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sparse_tensor = ME.SparseTensor(
        features=features.to(device),
        coordinates=coordinates.to(device),
        tensor_stride=1
    )
    print(f"创建了稀疏张量: {coordinates.shape[0]}个体素")

    # 3. Create the model.
    model = SparseUNetWithAttention(
        in_channels=channels,
        out_channels=channels,
        num_blocks=3,
        use_attention=True
    ).to(device)
    print("模型结构:")
    print(model)

    # 4. Forward pass.
    output = model(sparse_tensor)
    print("前向传播成功!")
    print("输出特征形状:", output.F.shape)
    print("输出坐标形状:", output.C.shape)

    # 5. Check whether output coordinates match the input coordinates.
    input_coords = coordinates.cpu()
    output_coords = output.C.cpu()
    coord_equal = torch.equal(input_coords, output_coords)
    print(f"\n输入输出坐标是否完全一致: {coord_equal}")
    if coord_equal:
        print("模型保持了输入的空间分辨率")
    else:
        # Coordinates differ — report per-axis ranges for debugging.
        print("\n坐标范围比较:")
        print(f"输入深度范围: {input_coords[:,1].min().item()} - {input_coords[:,1].max().item()}")
        print(f"输出深度范围: {output_coords[:,1].min().item()} - {output_coords[:,1].max().item()}")
        print(f"输入高度范围: {input_coords[:,2].min().item()} - {input_coords[:,2].max().item()}")
        print(f"输出高度范围: {output_coords[:,2].min().item()} - {output_coords[:,2].max().item()}")
        print(f"输入宽度范围: {input_coords[:,3].min().item()} - {input_coords[:,3].max().item()}")
        print(f"输出宽度范围: {output_coords[:,3].min().item()} - {output_coords[:,3].max().item()}")
        print(f"\n体素数量比较:")
        print(f"输入体素数量: {input_coords.shape[0]}")
        print(f"输出体素数量: {output_coords.shape[0]}")