import torch
import torch.nn as nn
import MinkowskiEngine as ME
import torch.nn.functional as F
class SparseGaussianHead(nn.Module):
    """Sparse 3D convolutional head mapping voxel features to Gaussian parameters.

    Two stacked MinkowskiConvolutions with a GELU in between; the output
    channel count equals the number of Gaussian parameters per voxel.
    """

    def __init__(self, in_channels=164, out_channels=38):
        """
        Args:
            in_channels: number of input channels (default 164).
            out_channels: number of output channels, i.e. the number of
                Gaussian parameters per voxel (default 38: 34 feature values
                + 3 position offsets + 1 opacity).
        """
        super().__init__()
        # Number of Gaussian parameters produced per voxel.
        self.num_gaussian_parameters = out_channels
        # Sparse 3D convolution network: conv -> GELU -> conv.
        self.conv1 = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3
        )
        self.act = ME.MinkowskiGELU()
        self.conv2 = ME.MinkowskiConvolution(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3
        )
        self.init_weights()

    def forward(self, sparse_input: ME.SparseTensor):
        """Run conv -> GELU -> conv on the sparse input.

        Args:
            sparse_input: sparse input tensor.

        Returns:
            Sparse tensor of Gaussian parameters (``out_channels`` feature
            channels on the same coordinate map).
        """
        x = self.conv1(sparse_input)
        x = self.act(x)
        x = self.conv2(x)
        return x

    def init_weights(self):
        """Initialize weights for modules used in this head.

        BUGFIX: the previous version first tried ``ME.utils.kaiming_normal_``,
        which does not exist in MinkowskiEngine, and papered over it with a
        blanket ``except Exception``. ``torch.nn.init`` operates directly on
        the ``kernel`` parameter, so it is called unconditionally here and the
        silent exception swallowing is removed.
        """
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                # ``m.kernel`` holds the convolution weights of a
                # MinkowskiConvolution (kernel_volume, in_ch, out_ch).
                nn.init.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
                # Some ME versions expose a bias parameter; zero it if present.
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, ME.MinkowskiBatchNorm):
                # MinkowskiBatchNorm wraps a regular nn.BatchNorm as ``bn``.
                # (This head has no BN today; branch kept for extensions.)
                if hasattr(m, 'bn') and m.bn.weight is not None:
                    nn.init.constant_(m.bn.weight, 1)
                    nn.init.constant_(m.bn.bias, 0)
class AttentionBlock(nn.Module):
    """Self-attention block for sparse tensors using PyTorch SDPA when available.

    Applies BatchNorm -> QKV projection -> multi-head attention over all
    active voxels -> output projection, with a residual connection.
    """

    def __init__(self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False):
        """
        Args:
            channels: feature channels of the input sparse tensor.
            num_heads: number of attention heads (used when
                ``num_head_channels`` is -1).
            num_head_channels: channels per head; when set, the head count
                becomes ``channels // num_head_channels``.
            use_checkpoint: stored flag; not used in the forward pass here.
        """
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0, (
                f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}")
            self.num_heads = channels // num_head_channels
        self.use_checkpoint = use_checkpoint
        # Normalization layer.
        self.norm = ME.MinkowskiBatchNorm(channels)
        # QKV projection and output projection.
        self.qkv = ME.MinkowskiLinear(channels, channels * 3)
        self.proj_out = ME.MinkowskiLinear(channels, channels)

    def _attention(self, qkv: torch.Tensor) -> torch.Tensor:
        """Multi-head scaled dot-product attention over a flat (L, 3*C) QKV tensor.

        BUGFIX: the previous "old PyTorch" fallback branch called
        ``F.scaled_dot_product_attention`` itself -- the very function whose
        absence selected that branch -- so it could never work. The fallback
        now computes attention manually. The
        ``sdp_kernel(enable_math=False)`` context was also removed: it forced
        a hardware-specific kernel and raised on devices without one; letting
        PyTorch choose a backend is both correct and portable.

        Args:
            qkv: tensor of shape (length, 3 * channels).

        Returns:
            Tensor of shape (length, channels).
        """
        length, width = qkv.shape
        ch = width // (3 * self.num_heads)
        # (L, 3C) -> (1, num_heads, L, 3*ch)
        qkv = qkv.reshape(length, self.num_heads, 3 * ch).unsqueeze(0)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)  # each (1, num_heads, L, ch)
        if hasattr(F, 'scaled_dot_product_attention'):
            # Modern PyTorch fused attention; backend chosen automatically.
            values = F.scaled_dot_product_attention(q, k, v)[0]
        else:
            # Manual attention for old PyTorch versions.
            attn = torch.softmax((q @ k.transpose(-2, -1)) * (ch ** -0.5), dim=-1)
            values = (attn @ v)[0]
        # (num_heads, L, ch) -> (L, num_heads * ch)
        values = values.permute(1, 0, 2).reshape(length, -1)
        return values

    def forward(self, x: "ME.SparseTensor"):
        """Attention with a residual connection on a sparse tensor."""
        # Normalize, then project to QKV.
        x_norm = self.norm(x)
        qkv = self.qkv(x_norm)
        # Run attention on the dense feature matrix of the sparse tensor.
        feature_dense = self._attention(qkv.F)
        feature = ME.SparseTensor(
            features=feature_dense,
            coordinate_map_key=qkv.coordinate_map_key,
            coordinate_manager=qkv.coordinate_manager
        )
        # Project back to the original width; residual connection.
        output = self.proj_out(feature)
        return output + x
class SparseConvBlock(nn.Module):
    """Sparse 3D convolution block: convolution, batch norm, then ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        super().__init__()
        # Conv -> BN -> ReLU, all operating on sparse tensors.
        self.conv = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=3
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        """Apply convolution, normalization, and activation in sequence."""
        out = self.conv(x)
        out = self.norm(out)
        return self.act(out)
class SparseUpConvBlock(nn.Module):
    """Sparse 3D upsampling block: transposed convolution, batch norm, ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        super().__init__()
        # Transposed conv doubles (for stride=2) the spatial resolution.
        self.upconv = ME.MinkowskiConvolutionTranspose(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=3
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        """Upsample, normalize, and activate in sequence."""
        out = self.upconv(x)
        out = self.norm(out)
        return self.act(out)
class SparseUNetWithAttention(nn.Module):
    """Sparse 3D U-Net with optional bottleneck attention (fixed version).

    Encoder stages halve the resolution and double the channels; the decoder
    mirrors them with transposed convolutions and skip connections, and a
    final upsampling layer restores the input resolution.
    """

    def __init__(self, in_channels, out_channels, num_blocks=4, use_attention=False):
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        self.use_attention = use_attention
        # Output channel count of each encoder stage (needed by the decoder).
        self.encoder_channels = []

        # --- Encoder path: stride-2 blocks with exponentially growing width.
        ch = in_channels
        for level in range(num_blocks):
            stage_out = 64 * (2 ** level)
            self.encoders.append(SparseConvBlock(ch, stage_out, kernel_size=3, stride=2))
            self.encoder_channels.append(stage_out)
            ch = stage_out

        # --- Bottleneck: one more downsampling step, optional attention.
        bottleneck_ch = ch * 2
        self.bottleneck = nn.ModuleList()
        self.bottleneck.append(SparseConvBlock(ch, bottleneck_ch, kernel_size=3, stride=2))
        if use_attention:
            self.bottleneck.append(AttentionBlock(bottleneck_ch))
        self.bottleneck.append(SparseConvBlock(bottleneck_ch, bottleneck_ch, kernel_size=3, stride=1))

        # --- Decoder path: upsample, concatenate the matching skip tensor,
        # then squeeze channels back with a 1x1x1 convolution.
        dec_ch = bottleneck_ch
        for level in range(num_blocks):
            skip_ch = self.encoder_channels[-1 - level]
            upconv = SparseUpConvBlock(dec_ch, skip_ch, kernel_size=3, stride=2)
            fuse = ME.MinkowskiConvolution(
                skip_ch + skip_ch,  # upsampled features + skip features
                skip_ch,
                kernel_size=1,
                stride=1,
                dimension=3
            )
            self.decoder_blocks.append(nn.ModuleList([upconv, fuse]))
            dec_ch = skip_ch

        # --- Final upsampling back to the input resolution.
        self.final_upsample = SparseUpConvBlock(
            self.encoder_channels[0],
            out_channels,
            kernel_size=3,
            stride=2
        )

    def forward(self, x: ME.SparseTensor):
        """Encode, pass through the bottleneck, decode with skips, upsample."""
        skips = []
        # Encoder path; keep each stage's output for the skip connections.
        for encoder in self.encoders:
            x = encoder(x)
            skips.append(x)
        # Bottleneck layers (including the optional attention block).
        for layer in self.bottleneck:
            x = layer(x)
        # Decoder path.
        for idx, block in enumerate(self.decoder_blocks):
            upconv, fuse = block
            x = upconv(x)
            # Concatenate with the matching encoder output, if any.
            skip_idx = len(skips) - idx - 1
            if skip_idx >= 0 and skip_idx < len(skips):
                x = ME.cat(x, skips[skip_idx])
            # Reduce the concatenated channels.
            x = fuse(x)
        # Upsample back to the original resolution.
        return self.final_upsample(x)
# Smoke test: build a dense random volume, sparsify it, run the U-Net, and
# check whether the output coordinates match the input coordinates.
if __name__ == "__main__":
    # 1. Create input data (batch, channels, depth, height, width).
    batch_size, channels, depth, height, width = 1, 128, 40, 80, 80
    dense_feature = torch.randn(batch_size, channels, depth, height, width)
    # 2. Build the sparse tensor: voxels whose channel-sum magnitude is
    #    non-zero (with randn, effectively every voxel) become coordinates.
    non_zero_mask = dense_feature.abs().sum(dim=1) > 0
    coordinates = torch.nonzero(non_zero_mask).int().contiguous()
    features = dense_feature[coordinates[:, 0], :, coordinates[:, 1], coordinates[:, 2], coordinates[:, 3]]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sparse_tensor = ME.SparseTensor(
        features=features.to(device),
        coordinates=coordinates.to(device),
        tensor_stride=1
    )
    print(f"创建了稀疏张量: {coordinates.shape[0]}个体素")
    # 3. Create the model.
    model = SparseUNetWithAttention(
        in_channels=channels,
        out_channels=channels,
        num_blocks=3,
        use_attention=True
    ).to(device)
    # Print the model structure.
    print("模型结构:")
    print(model)
    # 4. Forward pass.
    output = model(sparse_tensor)
    print("前向传播成功!")
    print("输出特征形状:", output.F.shape)
    print("输出坐标形状:", output.C.shape)
    # 5. Check whether input and output coordinates agree.
    input_coords = coordinates.cpu()
    output_coords = output.C.cpu()
    # Direct element-wise equality of the two coordinate tensors.
    coord_equal = torch.equal(input_coords, output_coords)
    print(f"\n输入输出坐标是否完全一致: {coord_equal}")
    # If the coordinates match exactly, no further checks are needed.
    if coord_equal:
        print("模型保持了输入的空间分辨率")
    else:
        # Otherwise report per-axis coordinate ranges in detail.
        print("\n坐标范围比较:")
        print(f"输入深度范围: {input_coords[:,1].min().item()} - {input_coords[:,1].max().item()}")
        print(f"输出深度范围: {output_coords[:,1].min().item()} - {output_coords[:,1].max().item()}")
        print(f"输入高度范围: {input_coords[:,2].min().item()} - {input_coords[:,2].max().item()}")
        print(f"输出高度范围: {output_coords[:,2].min().item()} - {output_coords[:,2].max().item()}")
        print(f"输入宽度范围: {input_coords[:,3].min().item()} - {input_coords[:,3].max().item()}")
        print(f"输出宽度范围: {output_coords[:,3].min().item()} - {output_coords[:,3].max().item()}")
        # Compare voxel counts as well.
        print(f"\n体素数量比较:")
        print(f"输入体素数量: {input_coords.shape[0]}")
        print(f"输出体素数量: {output_coords.shape[0]}")