|
|
import torch |
|
|
import torch.nn as nn |
|
|
import MinkowskiEngine as ME |
|
|
import torch.nn.functional as F |
|
|
|
|
|
class SparseGaussianHead(nn.Module):
    """Convert voxel features into Gaussian parameters with sparse 3D convolutions.

    Two 3x3x3, stride-1 sparse convolutions separated by a GELU, so the
    output keeps the coordinates of the input sparse tensor.
    """

    def __init__(self, in_channels=164, out_channels=38):
        """
        Args:
            in_channels: number of input feature channels (default 164).
            out_channels: number of output channels, i.e. the number of
                Gaussian parameters per voxel (default 38).
        """
        super().__init__()

        self.num_gaussian_parameters = out_channels

        self.conv1 = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3,
        )
        self.act = ME.MinkowskiGELU()
        self.conv2 = ME.MinkowskiConvolution(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3,
        )

        self.init_weights()

    def forward(self, sparse_input: ME.SparseTensor):
        """
        Forward pass.

        Args:
            sparse_input: sparse input tensor with ``in_channels`` features.

        Returns:
            Sparse tensor of Gaussian parameters (``out_channels`` features).
        """
        x = self.conv1(sparse_input)
        x = self.act(x)
        x = self.conv2(x)
        return x

    def init_weights(self):
        """Kaiming-initialise conv kernels, zero conv biases, reset BatchNorm affine params.

        Uses ``ME.utils.kaiming_normal_`` when MinkowskiEngine provides it,
        otherwise falls back to ``torch.nn.init.kaiming_normal_``. Errors are
        no longer silently swallowed, so a broken initialisation surfaces
        instead of quietly switching to a different scheme.
        """
        # Feature-detect the ME initializer once instead of catching every exception.
        me_kaiming = getattr(getattr(ME, 'utils', None), 'kaiming_normal_', None)

        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                # NOTE(review): 'relu' gain is used although the activation is GELU;
                # kept as-is to preserve the original initialisation.
                if me_kaiming is not None:
                    me_kaiming(m.kernel, mode='fan_out', nonlinearity='relu')
                else:
                    nn.init.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')

                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias, 0)

            elif isinstance(m, ME.MinkowskiBatchNorm):
                bn = getattr(m, 'bn', None)
                # With affine=False the weight/bias are None; skip instead of failing.
                if bn is not None:
                    if getattr(bn, 'weight', None) is not None:
                        nn.init.constant_(bn.weight, 1)
                    if getattr(bn, 'bias', None) is not None:
                        nn.init.constant_(bn.bias, 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class AttentionBlock(nn.Module):
    """Multi-head self-attention over the features of a sparse tensor.

    Each voxel attends to every other voxel of the *same* batch element
    (the batch index is column 0 of the sparse coordinates). Attention uses
    ``torch.nn.functional.scaled_dot_product_attention`` when available, which
    lets PyTorch dispatch to a Flash / memory-efficient kernel when the inputs
    are eligible. Otherwise it falls back to an explicit
    softmax(Q K^T / sqrt(d)) V. The result is projected and added back to the
    input (residual).
    """

    def __init__(self, channels, num_heads=1, num_head_channels=-1, use_checkpoint=False):
        """
        Args:
            channels: number of feature channels of the input sparse tensor.
            num_heads: number of attention heads; used only when
                ``num_head_channels == -1``.
            num_head_channels: if not -1, channels per head; the number of
                heads is then ``channels // num_head_channels``.
            use_checkpoint: stored on the module; not used by ``forward``.

        Raises:
            ValueError: if ``channels`` is not divisible by the resulting
                positive number of heads.
        """
        super().__init__()
        self.channels = channels
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0, (
                f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}")
            self.num_heads = channels // num_head_channels
        # Fail at construction time instead of with an obscure reshape error in forward.
        if self.num_heads <= 0 or channels % self.num_heads != 0:
            raise ValueError(
                f"channels {channels} must be divisible by a positive number of heads, "
                f"got num_heads={self.num_heads}")
        self.use_checkpoint = use_checkpoint

        self.norm = ME.MinkowskiBatchNorm(channels)

        self.qkv = ME.MinkowskiLinear(channels, channels * 3)
        self.proj_out = ME.MinkowskiLinear(channels, channels)

    def _attention(self, qkv: torch.Tensor) -> torch.Tensor:
        """Self-attention over a single sequence of points.

        Args:
            qkv: (N, 3 * channels) tensor; for each head the 3 * head_dim
                slice is laid out as [q | k | v].

        Returns:
            (N, channels) tensor of attended values.
        """
        length, width = qkv.shape
        ch = width // (3 * self.num_heads)
        # (N, 3C) -> (1, heads, N, 3*ch): add a batch dim and move heads before the sequence.
        qkv = qkv.reshape(length, self.num_heads, 3 * ch).unsqueeze(0)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)  # each (1, heads, N, ch)

        if hasattr(F, 'scaled_dot_product_attention'):
            # Let PyTorch choose the kernel (Flash / memory-efficient / math);
            # forbidding the math kernel breaks CPU and ineligible dtypes/shapes.
            values = F.scaled_dot_product_attention(q, k, v)[0]
        else:
            # Explicit fallback for PyTorch versions without SDPA (same default scale).
            scale = ch ** -0.5
            weights = torch.softmax(torch.matmul(q, k.transpose(-2, -1)) * scale, dim=-1)
            values = torch.matmul(weights, v)[0]

        # (heads, N, ch) -> (N, heads * ch)
        values = values.permute(1, 0, 2).reshape(length, -1)
        return values

    def _batched_attention(self, qkv: torch.Tensor, batch_index: torch.Tensor) -> torch.Tensor:
        """Apply ``_attention`` independently to the rows of each batch element.

        Args:
            qkv: (N, 3 * channels) tensor.
            batch_index: (N,) tensor with the batch index of every row.

        Returns:
            (N, channels) tensor, rows in the same order as ``qkv``.
        """
        batch_ids = torch.unique(batch_index)
        if batch_ids.numel() == 1:
            # Single sample: no masking needed.
            return self._attention(qkv)

        out = qkv.new_empty(qkv.shape[0], qkv.shape[1] // 3)
        for b in batch_ids:
            rows = batch_index == b
            out[rows] = self._attention(qkv[rows])
        return out

    def forward(self, x: ME.SparseTensor):
        """Return ``proj_out(attention(qkv(norm(x)))) + x`` on x's coordinates."""
        x_norm = self.norm(x)

        qkv = self.qkv(x_norm)

        # Column 0 of the sparse coordinates is the batch index; keep samples separate.
        feature_dense = self._batched_attention(qkv.F, qkv.C[:, 0])
        feature = ME.SparseTensor(
            features=feature_dense,
            coordinate_map_key=qkv.coordinate_map_key,
            coordinate_manager=qkv.coordinate_manager,
        )

        output = self.proj_out(feature)
        return output + x
|
|
|
|
|
class SparseConvBlock(nn.Module):
    """Sparse 3D block: convolution -> BatchNorm -> in-place ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1):
        """
        Args:
            in_channels: input feature channels.
            out_channels: output feature channels.
            kernel_size: convolution kernel size (default 3).
            stride: convolution stride (default 1).
        """
        super().__init__()
        spatial_dims = 3
        self.conv = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=spatial_dims,
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        """Apply conv, normalisation and activation in sequence."""
        y = self.conv(x)
        y = self.norm(y)
        return self.act(y)
|
|
|
|
|
class SparseUpConvBlock(nn.Module):
    """Sparse 3D upsampling block: transposed convolution -> BatchNorm -> in-place ReLU."""

    def __init__(self, in_channels, out_channels, kernel_size=3, stride=2):
        """
        Args:
            in_channels: input feature channels.
            out_channels: output feature channels.
            kernel_size: transposed-convolution kernel size (default 3).
            stride: upsampling stride (default 2).
        """
        super().__init__()
        spatial_dims = 3
        self.upconv = ME.MinkowskiConvolutionTranspose(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dimension=spatial_dims,
        )
        self.norm = ME.MinkowskiBatchNorm(out_channels)
        self.act = ME.MinkowskiReLU(inplace=True)

    def forward(self, x: ME.SparseTensor):
        """Upsample, normalise and activate."""
        y = self.upconv(x)
        y = self.norm(y)
        return self.act(y)
|
|
|
|
|
class SparseUNetWithAttention(nn.Module):
    """Sparse 3D U-Net with an optional attention block in the bottleneck.

    Layout, with tensor strides relative to the input (s = 2 ** num_blocks):
      * encoders: ``num_blocks`` stride-2 conv blocks; level i has
        ``base_channels * 2 ** i`` channels and tensor stride ``2 ** (i + 1)``.
      * bottleneck: stride-2 conv doubling the channels (stride 2s), an
        optional AttentionBlock, then a stride-1 conv.
      * decoders: per level, a stride-2 transposed conv to the matching encoder
        width, concatenation with that encoder's output (same tensor stride),
        and a 1x1 conv.
      * final_upsample: stride-2 transposed conv back to the input stride with
        ``out_channels`` channels.
    """

    def __init__(self, in_channels, out_channels, num_blocks=4, use_attention=False, base_channels=64):
        """
        Args:
            in_channels: input feature channels.
            out_channels: output feature channels.
            num_blocks: number of encoder/decoder levels (0 is allowed).
            use_attention: insert an AttentionBlock in the bottleneck.
            base_channels: channel width of the first encoder level (default
                64); level i uses ``base_channels * 2 ** i``.
        """
        super().__init__()
        self.encoders = nn.ModuleList()
        self.decoder_blocks = nn.ModuleList()
        self.use_attention = use_attention

        # Output width of each encoder level, shallowest first.
        self.encoder_channels = []

        current_ch = in_channels
        for i in range(num_blocks):
            out_ch = base_channels * (2 ** i)
            self.encoders.append(SparseConvBlock(current_ch, out_ch, kernel_size=3, stride=2))
            self.encoder_channels.append(out_ch)
            current_ch = out_ch

        bottleneck_in = current_ch
        bottleneck_out = bottleneck_in * 2

        self.bottleneck = nn.ModuleList()
        self.bottleneck.append(SparseConvBlock(bottleneck_in, bottleneck_out, kernel_size=3, stride=2))
        if use_attention:
            self.bottleneck.append(AttentionBlock(bottleneck_out))
        self.bottleneck.append(SparseConvBlock(bottleneck_out, bottleneck_out, kernel_size=3, stride=1))

        # Decoders go deepest-first, each mirroring one encoder level.
        current_decoder_ch = bottleneck_out
        for skip_ch in reversed(self.encoder_channels):
            upconv = SparseUpConvBlock(current_decoder_ch, skip_ch, kernel_size=3, stride=2)
            # Fuse upsampled features with the skip connection (both skip_ch wide).
            after_cat = ME.MinkowskiConvolution(
                skip_ch + skip_ch,
                skip_ch,
                kernel_size=1,
                stride=1,
                dimension=3,
            )
            self.decoder_blocks.append(nn.ModuleList([upconv, after_cat]))
            current_decoder_ch = skip_ch

        # current_decoder_ch equals encoder_channels[0] when num_blocks >= 1 and
        # bottleneck_out when num_blocks == 0 (avoids IndexError in that case).
        self.final_upsample = SparseUpConvBlock(
            current_decoder_ch,
            out_channels,
            kernel_size=3,
            stride=2,
        )

    def forward(self, x: ME.SparseTensor):
        """Run encoder, bottleneck, decoder with skip connections, and final upsample."""
        encoder_outputs = []
        for encoder in self.encoders:
            x = encoder(x)
            encoder_outputs.append(x)

        for layer in self.bottleneck:
            x = layer(x)

        # Decoder i pairs with the encoder output of the same tensor stride (deepest first).
        for (upconv, after_cat), skip in zip(self.decoder_blocks, reversed(encoder_outputs)):
            x = upconv(x)
            x = ME.cat(x, skip)
            x = after_cat(x)

        output = self.final_upsample(x)
        return output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Demo: build a random dense volume, convert it to a sparse tensor,
    # run the U-Net and compare input/output coordinates.
    batch_size, channels, depth, height, width = 1, 128, 40, 80, 80
    dense_feature = torch.randn(batch_size, channels, depth, height, width)

    # A voxel is occupied when any of its channels is non-zero.
    non_zero_mask = dense_feature.abs().sum(dim=1) > 0
    # nonzero() yields int64 (batch, d, h, w) indices: use them for indexing and
    # cast to int32 only for the MinkowskiEngine coordinates.
    nz_index = torch.nonzero(non_zero_mask)
    coordinates = nz_index.int().contiguous()
    # Advanced indices separated by a slice -> result shape (N, channels).
    features = dense_feature[nz_index[:, 0], :, nz_index[:, 1], nz_index[:, 2], nz_index[:, 3]]

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sparse_tensor = ME.SparseTensor(
        features=features.to(device),
        coordinates=coordinates.to(device),
        tensor_stride=1,
    )

    print(f"创建了稀疏张量: {coordinates.shape[0]}个体素")

    model = SparseUNetWithAttention(
        in_channels=channels,
        out_channels=channels,
        num_blocks=3,
        use_attention=True,
    ).to(device)

    print("模型结构:")
    print(model)

    # Inference-only demo: do not keep activations for autograd.
    with torch.no_grad():
        output = model(sparse_tensor)

    print("前向传播成功!")
    print("输出特征形状:", output.F.shape)
    print("输出坐标形状:", output.C.shape)

    input_coords = coordinates.cpu()
    output_coords = output.C.cpu().to(input_coords.dtype)

    # Row order of output.C is not guaranteed to match the input, so compare
    # the voxel sets (row-sorted unique coordinates) plus the row count.
    coord_equal = (
        input_coords.shape == output_coords.shape
        and torch.equal(torch.unique(input_coords, dim=0), torch.unique(output_coords, dim=0))
    )
    print(f"\n输入输出坐标是否完全一致: {coord_equal}")

    if coord_equal:
        print("模型保持了输入的空间分辨率")
    else:
        print("\n坐标范围比较:")
        print(f"输入深度范围: {input_coords[:,1].min().item()} - {input_coords[:,1].max().item()}")
        print(f"输出深度范围: {output_coords[:,1].min().item()} - {output_coords[:,1].max().item()}")

        print(f"输入高度范围: {input_coords[:,2].min().item()} - {input_coords[:,2].max().item()}")
        print(f"输出高度范围: {output_coords[:,2].min().item()} - {output_coords[:,2].max().item()}")

        print(f"输入宽度范围: {input_coords[:,3].min().item()} - {input_coords[:,3].max().item()}")
        print(f"输出宽度范围: {output_coords[:,3].min().item()} - {output_coords[:,3].max().item()}")

        print("\n体素数量比较:")
        print(f"输入体素数量: {input_coords.shape[0]}")
        print(f"输出体素数量: {output_coords.shape[0]}")