# Yeqing0814's picture
# Upload folder using huggingface_hub
# a6dd040 verified
import torch
import torch.nn as nn
import MinkowskiEngine as ME
import torch.nn.functional as F
class SparseGaussianHead(nn.Module):
    """Convert voxel features into Gaussian parameters with sparse 3D convolutions."""

    def __init__(self, in_channels=164, out_channels=38):
        """
        Args:
            in_channels: number of input channels (default 164).
            out_channels: number of output channels, i.e. the Gaussian
                parameter count (default 38).
        """
        super().__init__()
        # Gaussian parameter layout: 34 feature values + 3 position offsets + 1 opacity.
        self.num_gaussian_parameters = out_channels
        # Two-layer sparse 3D convolutional head with a GELU in between.
        self.conv1 = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3,
        )
        self.act = ME.MinkowskiGELU()
        self.conv2 = ME.MinkowskiConvolution(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3,
        )
        self.init_weights()

    def forward(self, sparse_input: ME.SparseTensor):
        """Run the head.

        Args:
            sparse_input: sparse input tensor.

        Returns:
            Sparse tensor of predicted Gaussian parameters.
        """
        hidden = self.act(self.conv1(sparse_input))
        return self.conv2(hidden)

    def init_weights(self):
        """Kaiming-initialize conv kernels and zero biases, best-effort across ME versions."""
        for module in self.modules():
            # MinkowskiConvolution: initialize the kernel, zero the bias if present.
            if isinstance(module, ME.MinkowskiConvolution):
                # `kernel` holds the weight tensor of a MinkowskiConvolution.
                try:
                    ME.utils.kaiming_normal_(module.kernel,
                                             mode='fan_out',
                                             nonlinearity='relu')
                except Exception:
                    # Safety net: ME versions differ here; fall back to torch's
                    # initializer instead of crashing.
                    nn.init.kaiming_normal_(module.kernel, mode='fan_out', nonlinearity='relu')
                # Zero the bias when the attribute exists.
                if hasattr(module, 'bias') and module.bias is not None:
                    try:
                        nn.init.constant_(module.bias, 0)
                    except Exception:
                        # Some ME versions store the bias elsewhere; ignore init errors.
                        pass
            # MinkowskiBatchNorm: initialize the wrapped BatchNorm's weight/bias.
            elif isinstance(module, ME.MinkowskiBatchNorm):
                # MinkowskiBatchNorm usually wraps an inner nn.BatchNorm named `bn`.
                if hasattr(module, 'bn'):
                    try:
                        nn.init.constant_(module.bn.weight, 1)
                        nn.init.constant_(module.bn.bias, 0)
                    except Exception:
                        pass
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the features of a sparse tensor.

    Uses PyTorch's fused scaled-dot-product attention when available and
    falls back to an explicit softmax implementation on older versions.
    """

    def __init__(self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False):
        """
        Args:
            channels: feature dimension of the input sparse tensor.
            num_heads: number of attention heads (used when num_head_channels == -1).
            num_head_channels: per-head channel count; when >= 0 it overrides
                num_heads and must divide `channels`.
            use_checkpoint: stored flag (not consulted in this block's forward).
        """
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0, (
                f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}")
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        # Pre-attention normalization.
        self.norm = ME.MinkowskiBatchNorm(channels)
        # Fused projection to queries/keys/values, plus the output projection.
        self.qkv = ME.MinkowskiLinear(channels, channels * 3)
        self.proj_out = ME.MinkowskiLinear(channels, channels)

    def _attention(self, qkv: torch.Tensor):
        """Apply multi-head attention to a stacked (length, 3 * channels) QKV tensor.

        Returns a (length, channels) tensor of attended values.
        """
        length, width = qkv.shape
        ch = width // (3 * self.num_heads)
        qkv = qkv.reshape(length, self.num_heads, 3 * ch).unsqueeze(0)
        qkv = qkv.permute(0, 2, 1, 3)  # (1, num_heads, length, 3 * ch)
        q, k, v = qkv.chunk(3, dim=-1)  # each (1, num_heads, length, ch)
        if hasattr(F, 'scaled_dot_product_attention'):
            # PyTorch >= 2.0: let the SDPA dispatcher choose the backend
            # (flash / memory-efficient / math).  BUGFIX: the old code wrapped
            # this in sdp_kernel(enable_math=False), which raises at runtime on
            # devices where no fused kernel is available (e.g. CPU) and uses a
            # deprecated API.
            values = F.scaled_dot_product_attention(q, k, v)[0]
        else:
            # BUGFIX: the old "compatibility" branch called the very function
            # whose absence was just detected, which would crash.  Compute
            # scaled dot-product attention explicitly instead.
            scores = q @ k.transpose(-2, -1) / (ch ** 0.5)
            values = (torch.softmax(scores, dim=-1) @ v)[0]
        # (num_heads, length, ch) -> (length, num_heads * ch)
        values = values.permute(1, 0, 2).reshape(length, -1)
        return values

    def forward(self, x: "ME.SparseTensor"):
        """Attend over all active sites of `x` and add a residual connection."""
        # Normalize.
        x_norm = self.norm(x)
        # Project to Q, K, V.
        qkv = self.qkv(x_norm)
        # Run attention on the dense feature matrix.
        feature_dense = self._attention(qkv.F)
        feature = ME.SparseTensor(
            features=feature_dense,
            coordinate_map_key=qkv.coordinate_map_key,
            coordinate_manager=qkv.coordinate_manager
        )
        # Project back to the original width.
        output = self.proj_out(feature)
        return output + x  # residual connection
class SparseConvBlock(nn.Module):
    """Sparse 3D convolution followed by batch normalization and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        self.conv = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=3,
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        """Apply conv -> norm -> activation to the sparse input."""
        out = self.conv(x)
        out = self.norm(out)
        return self.act(out)
class SparseUpConvBlock(nn.Module):
    """Sparse 3D transposed (upsampling) convolution with batch norm and ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        self.upconv = ME.MinkowskiConvolutionTranspose(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=3,
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        """Apply upconv -> norm -> activation to the sparse input."""
        out = self.upconv(x)
        out = self.norm(out)
        return self.act(out)
class SparseUNetWithAttention(nn.Module):
    """Sparse 3D U-Net with optional bottleneck attention (corrected version)."""
    def __init__(self, in_channels, out_channels, num_blocks=4, use_attention=False):
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()  # changed to a module list
        self.use_attention = use_attention
        # Record the output channel count of each encoder stage for the decoder.
        self.encoder_channels = []
        # Encoder path.
        current_ch = in_channels
        for i in range(num_blocks):
            out_ch = 64 * (2 ** i)  # channel count grows exponentially
            # out_ch = 128 * (2 ** i)  # channel count grows exponentially
            self.encoders.append(SparseConvBlock(current_ch, out_ch, kernel_size=3, stride=2))
            self.encoder_channels.append(out_ch)  # save the stage's output channels
            current_ch = out_ch
        # Bottleneck.
        bottleneck_in = current_ch
        bottleneck_out = bottleneck_in * 2
        self.bottleneck = nn.ModuleList()
        self.bottleneck.append(SparseConvBlock(bottleneck_in, bottleneck_out, kernel_size=3, stride=2))
        if use_attention:
            self.bottleneck.append(AttentionBlock(bottleneck_out))
        self.bottleneck.append(SparseConvBlock(bottleneck_out, bottleneck_out, kernel_size=3, stride=1))
        # Decoder path -- key change:
        # track the running channel count with a variable.
        current_decoder_ch = bottleneck_out
        for i in range(num_blocks):
            # Decoder output: channel count of the matching encoder stage.
            decoder_out = self.encoder_channels[-1-i]
            # Upsampling layer.
            upconv = SparseUpConvBlock(current_decoder_ch, decoder_out, kernel_size=3, stride=2)
            # Convolution applied after the skip-connection concatenation.
            after_cat = ME.MinkowskiConvolution(
                decoder_out + self.encoder_channels[-1-i],  # channels after concatenation
                decoder_out,  # output channels
                kernel_size=1,
                stride=1,
                dimension=3
            )
            # Bundle into a complete decoder block.
            self.decoder_blocks.append(nn.ModuleList([upconv, after_cat]))
            # Update the running channel count to this decoder's output.
            current_decoder_ch = decoder_out
        # Output layer.
        # self.output_conv = ME.MinkowskiConvolution(
        #     current_decoder_ch,  # output channels of the last decoder block
        #     out_channels,
        #     kernel_size=3,
        #     stride=1,
        #     dimension=3
        # )
        # Final upsampling layer.
        self.final_upsample = SparseUpConvBlock(
            self.encoder_channels[0],  # input channels
            out_channels,  # output channels
            kernel_size=3,
            stride=2
        )
    def forward(self, x: ME.SparseTensor):
        """Encode, pass through the bottleneck, then decode with skip connections."""
        encoder_outputs = []
        # Encoder path.
        for encoder in self.encoders:
            x = encoder(x)
            encoder_outputs.append(x)
        # Bottleneck.
        for layer in self.bottleneck:
            x = layer(x)
        # Decoder path.
        for i, decoder_block in enumerate(self.decoder_blocks):
            upconv, after_cat = decoder_block
            # Upsample.
            x = upconv(x)
            # Concatenate with the matching encoder output.
            enc_index = len(encoder_outputs) - i - 1
            if enc_index >= 0 and enc_index < len(encoder_outputs):
                x = ME.cat(x, encoder_outputs[enc_index])
            # Adjust the channel count after concatenation.
            x = after_cat(x)
        # # Final output convolution.
        # output = self.output_conv(x)
        # Final upsampling toward the original resolution.
        output = self.final_upsample(x)
        return output
# Smoke test
if __name__ == "__main__":
    # 1. Create dense input data.
    batch_size, channels, depth, height, width = 1, 128, 40, 80, 80
    dense_feature = torch.randn(batch_size, channels, depth, height, width)
    # 2. Build a sparse tensor from the non-zero voxels.
    non_zero_mask = dense_feature.abs().sum(dim=1) > 0
    coordinates = torch.nonzero(non_zero_mask).int().contiguous()
    features = dense_feature[coordinates[:, 0], :, coordinates[:, 1], coordinates[:, 2], coordinates[:, 3]]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sparse_tensor = ME.SparseTensor(
        features=features.to(device),
        coordinates=coordinates.to(device),
        tensor_stride=1
    )
    print(f"创建了稀疏张量: {coordinates.shape[0]}个体素")
    # 3. Create the model.
    model = SparseUNetWithAttention(
        in_channels=channels,
        out_channels=channels,
        num_blocks=3,
        use_attention=True
    ).to(device)
    # Print the model structure.
    print("模型结构:")
    print(model)
    # 4. Forward pass.
    output = model(sparse_tensor)
    print("前向传播成功!")
    print("输出特征形状:", output.F.shape)
    print("输出坐标形状:", output.C.shape)
    # 5. Check whether input and output coordinates agree.
    input_coords = coordinates.cpu()
    output_coords = output.C.cpu()
    # Direct element-wise comparison of the coordinate tensors.
    coord_equal = torch.equal(input_coords, output_coords)
    print(f"\n输入输出坐标是否完全一致: {coord_equal}")
    # If they match exactly, no further checks are needed.
    if coord_equal:
        print("模型保持了输入的空间分辨率")
    else:
        # Otherwise inspect the coordinate ranges in more detail.
        print("\n坐标范围比较:")
        print(f"输入深度范围: {input_coords[:,1].min().item()} - {input_coords[:,1].max().item()}")
        print(f"输出深度范围: {output_coords[:,1].min().item()} - {output_coords[:,1].max().item()}")
        print(f"输入高度范围: {input_coords[:,2].min().item()} - {input_coords[:,2].max().item()}")
        print(f"输出高度范围: {output_coords[:,2].min().item()} - {output_coords[:,2].max().item()}")
        print(f"输入宽度范围: {input_coords[:,3].min().item()} - {input_coords[:,3].max().item()}")
        print(f"输出宽度范围: {output_coords[:,3].min().item()} - {output_coords[:,3].max().item()}")
        # Compare voxel counts.
        print(f"\n体素数量比较:")
        print(f"输入体素数量: {input_coords.shape[0]}")
        print(f"输出体素数量: {output_coords.shape[0]}")