File size: 5,010 Bytes
3a85408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import torch
import torch.nn as nn
import torchvision.models as models
import resnet1d

# Public API of this module (print_model_info is not exported).
__all__ = [
    "ResNet1D_MultiTask",
    "get_model",
]

class ResNet1D_MultiTask(resnet1d.ResNet1D):
    """ResNet1D backbone whose final linear layer is replaced by an MLP head.

    All constructor arguments are forwarded unchanged to ``resnet1d.ResNet1D``.
    After the parent is built, its ``dense`` layer is removed and replaced by
    ``prediction_head``:

        Linear(F, F//2) -> BatchNorm1d -> ReLU -> Dropout(0.3)
        Linear(F//2, F//4) -> BatchNorm1d -> ReLU -> Dropout(0.3)
        Linear(F//4, n_outputs)

    where ``F`` is ``dense.in_features`` and ``n_outputs`` is
    ``dense.out_features`` of the parent's removed layer (i.e. the
    ``n_classes`` given to the parent, instead of a hard-coded 8).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Input (feature) and output widths of the parent's final linear layer.
        in_features = self.dense.in_features
        # Honour the n_classes handed to the parent rather than hard-coding 8.
        n_outputs = self.dense.out_features

        # Remove the parent's prediction layer; prediction_head replaces it.
        delattr(self, 'dense')

        hidden1 = in_features // 2
        hidden2 = in_features // 4
        # Layer order (and therefore Sequential indices / state_dict keys)
        # is identical to the original head.
        self.prediction_head = nn.Sequential(
            # first hidden layer: F -> F//2
            nn.Linear(in_features, hidden1),
            nn.BatchNorm1d(hidden1),
            nn.ReLU(),
            nn.Dropout(p=0.3),

            # second hidden layer: F//2 -> F//4
            nn.Linear(hidden1, hidden2),
            nn.BatchNorm1d(hidden2),
            nn.ReLU(),
            nn.Dropout(p=0.3),

            # output layer: F//4 -> n_outputs
            nn.Linear(hidden2, n_outputs),
        )

    def forward(self, x):
        """Run the backbone, average-pool over the last axis, then apply the head.

        Args:
            x: input tensor fed to ``first_block_conv``; presumably shaped
                (batch, in_channels, length) — TODO confirm against the loader.

        Returns:
            Output of ``prediction_head``: one row of predictions per sample.
        """
        # stem
        out = self.first_block_conv(x)
        if self.use_bn:
            out = self.first_block_bn(out)
        out = self.first_block_relu(out)

        # residual blocks
        for i_block in range(self.n_block):
            out = self.basicblock_list[i_block](out)

        # final normalisation/activation, then global average pooling
        if self.use_bn:
            out = self.final_bn(out)
        out = self.final_relu(out)
        out = out.mean(-1)  # average over the last axis

        return self.prediction_head(out)


def get_model(model_type):
    """Build a ResNet1D_MultiTask variant selected by a one-letter code.

    Every variant uses in_channels=1, base_filters=32, kernel_size=3,
    stride=2, groups=1 and n_classes=8; only the number of residual blocks
    differs: 'A' -> 8, 'B' -> 16, 'C' -> 24. The 'ResNet18/34/50' names used
    in the error message are labels for these depths.

    Raises:
        ValueError: if ``model_type`` is not 'A', 'B' or 'C'.
    """
    # (code, n_block) pairs, checked in order with ==.
    variants = (('A', 8), ('B', 16), ('C', 24))
    for code, depth in variants:
        if model_type == code:
            return ResNet1D_MultiTask(
                in_channels=1,
                base_filters=32,
                kernel_size=3,
                stride=2,
                groups=1,
                n_block=depth,
                n_classes=8,
            )
    raise ValueError(
        "Invalid model type. Choose 'A' for ResNet18, 'B' for ResNet34, or 'C' for ResNet50"
    )

def print_model_info(min_share=0.05):
    """Print a short architecture and parameter summary for model types 'A', 'B', 'C'.

    For each variant the model is built on CPU via ``get_model``. The summary
    shows its configured depth, base filter count and kernel size, plus total
    and trainable parameter counts. It then lists every direct child module
    holding more than ``min_share`` of the model's parameters.

    The previous version refused to run unless ``torchsummary`` was installed,
    even though it never used it; that gate has been removed.

    Args:
        min_share: fraction of total parameters above which a direct child
            module is listed (default 0.05, i.e. 5%).
    """
    device = torch.device("cpu")

    model_names = {
        'A': 'ResNet18',
        'B': 'ResNet34',
        'C': 'ResNet50',
    }

    # Display-only copy of the hyper-parameters hard-coded in get_model;
    # keep the two in sync.
    model_configs = {
        'A': {'n_block': 8, 'base_filters': 32, 'kernel_size': 3},
        'B': {'n_block': 16, 'base_filters': 32, 'kernel_size': 3},
        'C': {'n_block': 24, 'base_filters': 32, 'kernel_size': 3},
    }

    print("\n" + "="*50)
    print(f"{'LUCAS土壤光谱分析模型架构':^48}")
    print("="*50)
    print(f"{'输入: (batch_size=15228, channels=1, length=130)':^48}")
    print(f"{'输出: 8个土壤属性预测值':^48}")
    print("-"*50)

    for model_type, model_name in model_names.items():
        model = get_model(model_type).to(device)
        config = model_configs[model_type]
        total_params = sum(p.numel() for p in model.parameters())
        trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

        print(f"\n[Model {model_type}: {model_name}]")
        print(f"网络深度: {config['n_block']} blocks")
        print(f"基础通道数: {config['base_filters']}")
        print(f"卷积核大小: {config['kernel_size']}")
        print(f"总参数量: {total_params:,}")
        print(f"可训练参数: {trainable_params:,}")

        # Only list direct children that hold a significant share of parameters;
        # the total_params guard avoids a ZeroDivisionError on an empty model.
        main_layers = {}
        if total_params:
            for name, module in model.named_children():
                params = sum(p.numel() for p in module.parameters())
                if params > 0 and params / total_params > min_share:
                    main_layers[name] = params

        if main_layers:
            print("\n主要层结构:")
            for name, params in main_layers.items():
                print(f"  {name:15}: {params:,} ({params/total_params*100:.1f}%)")
        print("-"*50)

if __name__ == "__main__":
    # When run as a script, print the parameter summary for every model variant.
    print_model_info()