File size: 5,010 Bytes
3a85408 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 | import torch
import torch.nn as nn
import torchvision.models as models
import resnet1d
# Public API of this module. print_model_info is not exported; it is only
# invoked from the ``__main__`` guard at the bottom of the file.
__all__ = ['ResNet1D_MultiTask', 'get_model']
class ResNet1D_MultiTask(resnet1d.ResNet1D):
    """1-D ResNet trunk with a two-hidden-layer MLP prediction head.

    All constructor arguments are forwarded unchanged to
    ``resnet1d.ResNet1D``. The base class's single ``dense`` layer is
    removed and replaced by ``prediction_head``:

        Linear(F, F//2) -> BatchNorm1d -> ReLU -> Dropout(0.3)
        Linear(F//2, F//4) -> BatchNorm1d -> ReLU -> Dropout(0.3)
        Linear(F//4, n_outputs)

    where ``F = dense.in_features`` and ``n_outputs = dense.out_features``
    of the removed layer. The output width therefore follows whatever
    ``n_classes`` the base class was built with, instead of being fixed at 8.
    Every configuration in ``get_model`` passes ``n_classes=8``.

    NOTE(review): the head uses ``BatchNorm1d``. PyTorch rejects a
    training-mode forward pass whose batch has a single sample.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Feature width entering the original predictor and its output count.
        # NOTE(review): assumes the base class sizes ``dense`` with n_classes
        # outputs (it is used as an nn.Linear here) -- confirm in resnet1d.
        in_features = self.dense.in_features
        n_outputs = self.dense.out_features
        # Drop the original single-layer predictor; it is replaced below.
        del self.dense

        hidden1 = in_features // 2
        hidden2 = in_features // 4
        # The layer order (and thus the Sequential indices 0..8 used as
        # state_dict keys) is unchanged, so existing checkpoints still load.
        self.prediction_head = nn.Sequential(
            nn.Linear(in_features, hidden1),
            nn.BatchNorm1d(hidden1),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(hidden1, hidden2),
            nn.BatchNorm1d(hidden2),
            nn.ReLU(),
            nn.Dropout(p=0.3),
            nn.Linear(hidden2, n_outputs),
        )

    def _extract_features(self, x):
        """Run the convolutional trunk and average-pool over the last axis.

        Applies the stem conv, optional BN, ReLU, every residual block,
        optional final BN, and ReLU, then takes ``mean(-1)``.
        """
        out = self.first_block_conv(x)
        if self.use_bn:
            out = self.first_block_bn(out)
        out = self.first_block_relu(out)

        for i_block in range(self.n_block):
            out = self.basicblock_list[i_block](out)

        if self.use_bn:
            out = self.final_bn(out)
        out = self.final_relu(out)
        # Global average pooling over the sequence dimension.
        return out.mean(-1)

    def forward(self, x):
        """Predict the regression targets for a batch.

        Args:
            x: input fed to ``first_block_conv``; presumably shaped
                (batch, in_channels, length) -- TODO confirm.

        Returns:
            The ``prediction_head`` output, one value per target
            (8 for the ``get_model`` configurations).
        """
        return self.prediction_head(self._extract_features(x))
# Number of residual blocks for each predefined model type. All other
# hyper-parameters are shared by every configuration (see get_model).
_MODEL_N_BLOCKS = {
    'A': 8,   # labelled "ResNet18"
    'B': 16,  # labelled "ResNet34"
    'C': 24,  # labelled "ResNet50"
}


def get_model(model_type):
    """Build one of the predefined ``ResNet1D_MultiTask`` configurations.

    The configurations differ only in depth (``n_block``). All of them use
    one input channel, ``base_filters=32``, a width-3 1-D kernel,
    ``stride=2``, ``groups=1`` and 8 outputs.

    Args:
        model_type: 'A' (ResNet18, 8 blocks), 'B' (ResNet34, 16 blocks)
            or 'C' (ResNet50, 24 blocks).

    Returns:
        A newly constructed ``ResNet1D_MultiTask``.

    Raises:
        ValueError: if ``model_type`` is not 'A', 'B' or 'C'.
    """
    try:
        n_block = _MODEL_N_BLOCKS[model_type]
    except (KeyError, TypeError):  # TypeError: unhashable model_type
        raise ValueError(
            "Invalid model type. Choose 'A' for ResNet18, 'B' for ResNet34, or 'C' for ResNet50"
        ) from None

    return ResNet1D_MultiTask(
        in_channels=1,
        base_filters=32,  # kept small to reduce GPU memory usage
        kernel_size=3,    # 1-D convolution kernel of width 3
        stride=2,
        groups=1,
        n_block=n_block,
        n_classes=8,
    )
def print_model_info():
    """Print a short architecture and parameter summary for models A, B and C.

    Each configuration is built on the CPU. For each one the function
    prints:

    * the depth, base channel count and kernel size;
    * the total and trainable parameter counts;
    * every top-level child module holding more than 5% of the parameters.

    The former ``torchsummary`` gate has been removed. The imported
    ``summary`` was never called, so the gate only stopped this function
    from running when that package was missing.
    """
    device = torch.device("cpu")
    model_names = {
        'A': 'ResNet18',
        'B': 'ResNet34',
        'C': 'ResNet50',
    }
    # Display-only copy of the hyper-parameters; keep in sync with get_model().
    model_configs = {
        'A': {'n_block': 8, 'base_filters': 32, 'kernel_size': 3},
        'B': {'n_block': 16, 'base_filters': 32, 'kernel_size': 3},
        'C': {'n_block': 24, 'base_filters': 32, 'kernel_size': 3},
    }

    print("\n" + "=" * 50)
    print(f"{'LUCAS土壤光谱分析模型架构':^48}")
    print("=" * 50)
    print(f"{'输入: (batch_size=15228, channels=1, length=130)':^48}")
    print(f"{'输出: 8个土壤属性预测值':^48}")
    print("-" * 50)

    for model_type, model_name in model_names.items():
        model = get_model(model_type).to(device)
        config = model_configs[model_type]
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(
            p.numel() for p in model.parameters() if p.requires_grad
        )

        print(f"\n[Model {model_type}: {model_name}]")
        print(f"网络深度: {config['n_block']} blocks")
        print(f"基础通道数: {config['base_filters']}")
        print(f"卷积核大小: {config['kernel_size']}")
        print(f"总参数量: {total_params:,}")
        print(f"可训练参数: {trainable_params:,}")

        # Only report top-level children holding more than 5% of all params.
        main_layers = {}
        for name, module in model.named_children():
            params = sum(p.numel() for p in module.parameters())
            if params > 0 and params / total_params > 0.05:
                main_layers[name] = params

        if main_layers:
            print("\n主要层结构:")
            for name, params in main_layers.items():
                print(f"  {name:15}: {params:,} ({params/total_params*100:.1f}%)")
        print("-" * 50)

        # Release this model before building the next (larger) one.
        del model
# Script entry point: print the architecture summary for all model types.
if __name__ == '__main__':
    print_model_info()