# depthsplat/src/model/encoder/common/mink_resnet.py
# (uploaded via huggingface_hub by Yeqing0814, revision a6dd040)
import MinkowskiEngine as ME
import torch
import torch.nn as nn
class SparseResidualBlock(nn.Module):
    """Sparse residual block, analogous to ResNet's basic block."""

    def __init__(self, in_channels, out_channels, stride=1, expansion=1):
        super().__init__()
        self.expansion = expansion
        hidden = out_channels // expansion

        # Main path: conv -> BN -> ReLU -> conv -> BN.
        self.conv1 = ME.MinkowskiConvolution(
            in_channels, hidden, kernel_size=3, stride=stride, dimension=3
        )
        self.bn1 = ME.MinkowskiBatchNorm(hidden)
        self.relu = ME.MinkowskiReLU(inplace=True)
        self.conv2 = ME.MinkowskiConvolution(
            hidden, out_channels, kernel_size=3, stride=1, dimension=3
        )
        self.bn2 = ME.MinkowskiBatchNorm(out_channels)

        # Shortcut path: identity unless the spatial stride or channel
        # count changes, in which case project with a 1x1 conv + BN.
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                ME.MinkowskiConvolution(
                    in_channels, out_channels, kernel_size=1,
                    stride=stride, dimension=3
                ),
                ME.MinkowskiBatchNorm(out_channels),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        residual = self.shortcut(x)
        y = self.relu(self.bn1(self.conv1(x)))
        y = self.bn2(self.conv2(y))
        y = y + residual
        return self.relu(y)
class SparseGaussianHead(nn.Module):
    """Convert voxel features to Gaussian parameters with sparse 3D convolutions."""

    def __init__(self, in_channels=64, out_channels=38):
        """
        Args:
            in_channels: number of input channels (default 64).
            out_channels: number of output channels, i.e. the number of
                Gaussian parameters per voxel (default 38).
        """
        super().__init__()
        # Number of Gaussian parameters: 34 features + 3 position offsets + 1 opacity.
        self.num_gaussian_parameters = out_channels
        # Two-layer sparse 3D conv network with a GELU in between.
        self.conv1 = ME.MinkowskiConvolution(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3
        )
        self.act = ME.MinkowskiGELU()
        self.conv2 = ME.MinkowskiConvolution(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            dimension=3
        )
        self.init_weights()

    def forward(self, sparse_input: ME.SparseTensor) -> ME.SparseTensor:
        """Run conv -> GELU -> conv on the sparse input.

        Args:
            sparse_input: sparse input tensor.
        Returns:
            Sparse tensor of Gaussian parameters (same coordinates as input).
        """
        x = self.conv1(sparse_input)
        x = self.act(x)
        return self.conv2(x)

    def init_weights(self):
        """Kaiming-initialize conv kernels and reset batch-norm affine params."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                try:
                    ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
                except Exception:
                    # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
                    # are not swallowed. Fall back to torch's initializer, which
                    # works directly on the raw kernel parameter.
                    nn.init.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
                if getattr(m, 'bias', None) is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, ME.MinkowskiBatchNorm):
                # MinkowskiBatchNorm wraps a torch BatchNorm in its `.bn` attribute.
                if hasattr(m, 'bn'):
                    nn.init.constant_(m.bn.weight, 1)
                    nn.init.constant_(m.bn.bias, 0)
class MultiScaleSparseHead(nn.Module):
    """Multi-scale sparse Gaussian head.

    Produces Gaussian-parameter tensors at four scales of the input
    resolution: 1/2, 1/4, 1/8 and 1/16.
    """

    def __init__(self, in_channels=164, base_channels=64,
                 num_blocks=(2, 2, 2, 2), gaussian_out_channels=38):
        """
        Args:
            in_channels: number of input channels.
            base_channels: base channel width used throughout the head.
            num_blocks: residual-block count for each of the four stages.
                (Tuple default avoids the shared-mutable-default pitfall;
                lists are still accepted.)
            gaussian_out_channels: number of Gaussian parameters per voxel.
        """
        super().__init__()
        self.in_channels = in_channels
        self.base_channels = base_channels

        # Stem: initial downsampling to 1/2 resolution.
        self.conv1 = ME.MinkowskiConvolution(
            in_channels, base_channels, kernel_size=7,
            stride=2, dimension=3
        )
        self.bn1 = ME.MinkowskiBatchNorm(base_channels)
        self.relu = ME.MinkowskiReLU(inplace=True)

        # Four residual stages, each downsampling by a factor of 2.
        # Stage 1: 1/4 resolution.
        self.stage1 = self._make_stage(
            base_channels, base_channels * 1, num_blocks[0], stride=2
        )
        # Stage 2: 1/8 resolution.
        self.stage2 = self._make_stage(
            base_channels * 1, base_channels * 2, num_blocks[1], stride=2
        )
        # Stage 3: 1/16 resolution.
        self.stage3 = self._make_stage(
            base_channels * 2, base_channels * 4, num_blocks[2], stride=2
        )
        # Stage 4: 1/32 resolution (its output is fused back into the 1/16 path).
        self.stage4 = self._make_stage(
            base_channels * 4, base_channels * 8, num_blocks[3], stride=2
        )

        # 1x1 projections bringing each output scale to `base_channels` width.
        self.conv_half = ME.MinkowskiConvolution(
            base_channels, base_channels, kernel_size=1, stride=1, dimension=3
        )
        self.conv_eighth = ME.MinkowskiConvolution(
            base_channels * 2, base_channels, kernel_size=1, stride=1, dimension=3
        )
        self.conv_sixteenth = ME.MinkowskiConvolution(
            base_channels * 4, base_channels, kernel_size=1, stride=1, dimension=3
        )
        # NOTE(review): despite its name this is a stride-1 convolution; it only
        # reduces the 1/32 features to `base_channels`, it does not upsample.
        self.upsample4 = ME.MinkowskiConvolution(
            base_channels * 8, base_channels, kernel_size=3, stride=1, dimension=3
        )
        # NOTE(review): fuse_layers are never used in forward(); kept so that
        # existing checkpoints containing these parameters still load.
        self.fuse_layers = nn.ModuleList([
            ME.MinkowskiConvolution(
                base_channels * 2, base_channels, kernel_size=1, stride=1, dimension=3
            ) for _ in range(2)
        ])
        # One Gaussian-parameter head per output scale.
        self.gaussian_heads = nn.ModuleList([
            SparseGaussianHead(in_channels=base_channels, out_channels=gaussian_out_channels)
            for _ in range(4)
        ])
        self.init_weights()

    def _make_stage(self, in_channels, out_channels, num_blocks, stride):
        """Build one residual stage: a (possibly) strided first block followed
        by `num_blocks - 1` stride-1 blocks at the same resolution."""
        blocks = [SparseResidualBlock(in_channels, out_channels, stride)]
        blocks.extend(
            SparseResidualBlock(out_channels, out_channels, stride=1)
            for _ in range(1, num_blocks)
        )
        return nn.Sequential(*blocks)

    def forward(self, x: ME.SparseTensor):
        """Compute Gaussian parameters at four scales (1/2, 1/4, 1/8, 1/16).

        Returns:
            list: four sparse tensors of Gaussian parameters, one per scale.
        """
        # 1/2 resolution features (stem).
        x_half = self.relu(self.bn1(self.conv1(x)))
        # 1/4 resolution features.
        x_quarter = self.stage1(x_half)
        # 1/8 resolution features.
        x_eighth = self.stage2(x_quarter)
        # 1/16 resolution features.
        x_sixteenth = self.stage3(x_eighth)
        # 1/32 resolution features.
        x_thirtysecond = self.stage4(x_sixteenth)

        # Project 1/32 features down to base_channels (resolution unchanged).
        x_sixteenth2 = self.upsample4(x_thirtysecond)
        # Project 1/8 features to base_channels.
        x_eighth_proc = self.conv_eighth(x_eighth)
        # Fuse the 1/16 path: project to base_channels, then add the 1/32 branch.
        # NOTE(review): the two operands carry different tensor strides (16 vs 32);
        # confirm MinkowskiEngine's addition semantics for this case.
        x_sixteenth_adjusted = self.conv_sixteenth(x_sixteenth)
        x_sixteenth_final = x_sixteenth_adjusted + x_sixteenth2

        features = [
            self.conv_half(x_half),   # 1/2
            x_quarter,                # 1/4 (already base_channels wide)
            x_eighth_proc,            # 1/8
            x_sixteenth_final,        # 1/16
        ]
        # Map each scale's features to Gaussian parameters.
        return [head(feat) for head, feat in zip(self.gaussian_heads, features)]

    def init_weights(self):
        """Kaiming-initialize all convolution kernels and reset batch norms."""

        def _init_conv(m):
            # Prefer ME's initializer when available; otherwise fall back to
            # torch's, which works on the raw kernel parameter.
            try:
                ME.utils.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
            except Exception:
                # Narrowed from a bare `except:` so SystemExit/KeyboardInterrupt
                # are not swallowed.
                nn.init.kaiming_normal_(m.kernel, mode='fan_out', nonlinearity='relu')
            if getattr(m, 'bias', None) is not None:
                nn.init.constant_(m.bias, 0)

        for m in self.modules():
            # Regular and transposed sparse convs get identical treatment.
            if isinstance(m, (ME.MinkowskiConvolution, ME.MinkowskiConvolutionTranspose)):
                _init_conv(m)
            elif isinstance(m, ME.MinkowskiBatchNorm):
                # MinkowskiBatchNorm wraps a torch BatchNorm in its `.bn` attribute.
                if hasattr(m, 'bn'):
                    nn.init.constant_(m.bn.weight, 1)
                    nn.init.constant_(m.bn.bias, 0)
# Smoke-test / demo script.
if __name__ == "__main__":
    # 1. Create dense input data.
    batch_size, channels, depth, height, width = 1, 164, 100, 50, 90
    dense_feature = torch.randn(batch_size, channels, depth, height, width)

    # 2. Build a sparse tensor from the non-zero voxels.
    # NOTE(review): with randn input virtually every voxel is non-zero, so this
    # mask keeps (almost) all of them; it only matters for truly sparse inputs.
    non_zero_mask = dense_feature.abs().sum(dim=1) > 0
    coordinates = torch.nonzero(non_zero_mask).int().contiguous()
    features = dense_feature[coordinates[:, 0], :, coordinates[:, 1], coordinates[:, 2], coordinates[:, 3]]
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    sparse_tensor = ME.SparseTensor(
        features=features.to(device),
        coordinates=coordinates.to(device),
        tensor_stride=1
    )
    print(f"创建了稀疏张量: {coordinates.shape[0]}个体素")

    # 3. Build the model.
    model = MultiScaleSparseHead(
        in_channels=channels,
        base_channels=64,
        num_blocks=[2, 2, 2, 2],
        gaussian_out_channels=38  # number of Gaussian parameter output channels
    ).to(device)
    print("模型结构:")
    print(model)

    # 4. Forward pass.
    outputs = model(sparse_tensor)
    print("\n前向传播成功!")
    print(f"输出包含 {len(outputs)} 个尺度的特征")

    # 5. Inspect each scale's output: shapes, coordinate ranges, voxel counts.
    resolutions = ["1/2", "1/4", "1/8", "1/16"]
    for i, (output, res) in enumerate(zip(outputs, resolutions)):
        print(f"\n尺度 {i+1} ({res}分辨率):")
        print(f"特征形状: {output.F.shape}")
        print(f"坐标形状: {output.C.shape}")
        coords = output.C.cpu()
        print(f"深度范围: {coords[:,1].min().item()} - {coords[:,1].max().item()}")
        print(f"高度范围: {coords[:,2].min().item()} - {coords[:,2].max().item()}")
        print(f"宽度范围: {coords[:,3].min().item()} - {coords[:,3].max().item()}")
        print(f"特征通道数: {output.F.shape[1]}")
        print(f"体素数量: {coords.shape[0]}")

    # 6. Verify all outputs live on the same device.
    all_on_device = all(out.F.device == device for out in outputs)
    print(f"\n所有输出都在同一设备({device})上: {all_on_device}")

    # 7. Merge all scales into a single sparse tensor.
    # NOTE(review): coordinates from different scales can coincide; ME's default
    # quantization mode then subsamples duplicates — confirm this is intended.
    all_coords = torch.cat([out.C for out in outputs], dim=0)
    all_feats = torch.cat([out.F for out in outputs], dim=0)
    combined_tensor = ME.SparseTensor(
        features=all_feats,
        coordinates=all_coords,
        tensor_stride=1
    )
    print("\n合并后的稀疏张量:")
    print(f"特征形状: {combined_tensor.F.shape}")
    print(f"坐标形状: {combined_tensor.C.shape}")
    print(f"体素总数: {combined_tensor.C.shape[0]}")

    # 8. Check that gradients flow back to the model parameters.
    try:
        loss = combined_tensor.F.sum()
        loss.backward()
        print("\n反向传播成功!")
        has_gradients = any(p.grad is not None for p in model.parameters())
        print(f"模型参数有梯度: {has_gradients}")
    except Exception as e:
        print(f"\n反向传播失败: {str(e)}")