File size: 4,574 Bytes
3a85408
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import torch
import numpy as np
import shap
import matplotlib.pyplot as plt
import seaborn as sns
from data_load import load_soil_data
from data_processing import preprocess_with_downsampling, process_spectra
from resnet1d_multitask import get_model

# Configure matplotlib to use a CJK-capable sans-serif font (SimHei), and draw
# minus signs with an ASCII hyphen instead of U+2212, a glyph CJK fonts such as
# SimHei often lack.
plt.rcParams.update({
    'font.sans-serif': ['SimHei'],
    'axes.unicode_minus': False,
})

def load_model_and_data(
    csv_path='LUCAS.2009_abs.csv',
    model_path='code/models/set.pth',
    target_columns=None,
    spectra_method='Abs-SG1-SNV',
    downsample_factor=15,
    model_name='C',
):
    """Load and preprocess the test spectra and restore the trained model.

    Only the test split is preprocessed; the training split returned by
    ``load_soil_data`` is not needed for the SHAP analysis and is skipped.

    Args:
        csv_path: CSV file handed to ``load_soil_data``.
        model_path: ``state_dict`` checkpoint loaded into the model built by
            ``get_model(model_name)``.
        target_columns: target column names handed to ``load_soil_data``;
            ``None`` selects the eight columns used by default
            (pH.in.CaCl2, pH.in.H2O, OC, CaCO3, N, P, K, CEC).
        spectra_method: method string handed to ``process_spectra``.
        downsample_factor: third argument to ``preprocess_with_downsampling``
            (presumably the downsampling step -- confirm in data_processing).
        model_name: identifier handed to ``get_model``.

    Returns:
        ``(model, X_test, wavelengths)``: the model in eval mode on the CPU,
        the test spectra reshaped to ``(n_samples, 1, n_wavelengths)``, and
        the wavelengths returned by ``preprocess_with_downsampling``.
    """
    if target_columns is None:
        target_columns = ['pH.in.CaCl2', 'pH.in.H2O', 'OC', 'CaCO3', 'N', 'P', 'K', 'CEC']

    # The training split and the targets are not used by the SHAP analysis.
    _, X_test, _, _, wavelengths = load_soil_data(csv_path, target_columns)

    # Drop singleton axes so preprocessing sees (n_samples, n_wavelengths).
    X_test = X_test.squeeze()
    X_test = process_spectra(X_test, spectra_method)
    X_test, test_wavelengths = preprocess_with_downsampling(X_test, wavelengths, downsample_factor)

    # Add the channel axis the model expects: (n_samples, 1, n_wavelengths).
    X_test = X_test.reshape(X_test.shape[0], 1, X_test.shape[1])

    device = torch.device('cpu')
    model = get_model(model_name)
    # NOTE(review): torch.load unpickles the file -- only load trusted
    # checkpoints (or pass weights_only=True on torch versions supporting it).
    model.load_state_dict(torch.load(model_path, map_location=device))
    # Keep parameters on the same device the explainer builds its inputs on.
    model.to(device)
    model.eval()

    return model, X_test, test_wavelengths

def explain_predictions(model, X_test, wavelengths, n_background=50, n_samples=100,
                        kernel_nsamples=100):
    """Compute absolute SHAP values for the first test rows with KernelExplainer.

    KernelExplainer only needs forward passes of the model, which keeps this
    usable in a CPU-only environment.

    Args:
        model: PyTorch module. It is called on float32 tensors shaped
            ``(batch, 1, n_features)``; if it returns a tuple, the first
            element is used as the prediction.
        X_test: array whose first axis indexes samples; all remaining axes
            are flattened into a single feature axis.
        wavelengths: unused; kept for call compatibility.
        n_background: number of leading rows of ``X_test`` used as the
            explainer's background dataset.
        n_samples: number of leading rows of ``X_test`` to explain.
        kernel_nsamples: ``nsamples`` passed to ``explainer.shap_values``
            (model re-evaluations per explained row).

    Returns:
        ``np.ndarray`` of shape ``(n_explained, n_features)`` holding absolute
        SHAP values, averaged across outputs for multi-output models.
    """
    def model_predict(x):
        # shap hands over 2-D numpy batches; the model wants a channel axis.
        with torch.no_grad():
            x = torch.as_tensor(x, dtype=torch.float32)
            if x.dim() == 2:
                x = x.reshape(x.shape[0], 1, -1)
            output = model(x)
            # Multitask models may return a tuple; the first item is the prediction.
            if isinstance(output, tuple):
                output = output[0]
            return output.numpy()

    # Flatten to (n_rows, n_features) without squeeze(), which would collapse
    # a single selected row to 1-D.
    flat = np.asarray(X_test)
    flat = flat.reshape(flat.shape[0], -1)
    # NOTE(review): the background rows are a subset of the explained rows;
    # a sample drawn from the training split would be more conventional.
    background_data = flat[:n_background]
    test_data = flat[:n_samples]

    print("创建SHAP解释器...")
    explainer = shap.KernelExplainer(model_predict, background_data)

    print("计算SHAP值(这可能需要几分钟时间)...")
    shap_values = explainer.shap_values(test_data, nsamples=kernel_nsamples)

    if isinstance(shap_values, list):
        # Older shap: one (n_samples, n_features) array per model output.
        abs_values = np.mean([np.abs(sv) for sv in shap_values], axis=0)
    else:
        abs_values = np.abs(np.asarray(shap_values))
        if abs_values.ndim == 3:
            # Newer shap stacks outputs on the last axis:
            # (n_samples, n_features, n_outputs) -> average over outputs.
            abs_values = abs_values.mean(axis=-1)

    return abs_values

def plot_top10_wavelengths(shap_values, wavelengths, top_k=10,
                           output_path='shap_top10_wavelengths.png'):
    """Save a horizontal bar chart of the wavelengths with the largest mean |SHAP|.

    Args:
        shap_values: array of shape ``(n_samples, n_features)``; absolute
            values are taken before averaging over samples.
        wavelengths: sequence of length ``n_features`` labelling each column.
        top_k: number of bars to draw; clipped to the number of features so
            short feature sets do not crash the plot.
        output_path: image file the figure is written to.
    """
    # Mean absolute contribution of every wavelength across samples.
    mean_shap = np.abs(np.asarray(shap_values)).mean(axis=0)
    wavelengths = np.asarray(wavelengths)

    k = max(0, min(top_k, mean_shap.shape[0]))
    # Indices of the k largest values, largest first.
    top_idx = np.argsort(mean_shap)[::-1][:k]
    top_wavelengths = wavelengths[top_idx]
    top_values = mean_shap[top_idx]

    plt.figure(figsize=(8, 6))
    try:
        plt.barh(range(k), top_values, color='skyblue')
        plt.yticks(range(k), [f'{w:.1f}' for w in top_wavelengths])
        plt.xlabel('mean(|SHAP value|) (average impact on model output magnitude)')
        plt.title(f'Top {k} Wavelengths by SHAP Value')
        # Put the most important wavelength at the top of the chart.
        plt.gca().invert_yaxis()
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if drawing or saving fails.
        plt.close()


def plot_shap_summary(shap_values, wavelengths, features=None,
                      output_path='shap_summary_plot.png'):
    """Save a SHAP beeswarm ("dot") summary plot.

    Args:
        shap_values: array of shape ``(n_samples, n_features)``.
        wavelengths: sequence of length ``n_features``; used only to build
            the feature labels.
        features: optional matrix of feature values for the same rows as
            ``shap_values``, used by shap to colour the dots. When ``None``
            the dots are drawn without a value colour scale.
        output_path: image file the figure is written to.
    """
    feature_names = [f'{w:.1f}' for w in wavelengths]

    # Passing the 1-D wavelength vector as `features` makes shap index
    # features.shape[1] and raise; only pass a real 2-D data matrix.
    if features is not None:
        features = np.asarray(features)
        features = features.reshape(features.shape[0], -1)

    plt.figure(figsize=(10, 8))
    try:
        shap.summary_plot(shap_values, features=features, feature_names=feature_names,
                          plot_type='dot', show=False)
        plt.tight_layout()
        plt.savefig(output_path, dpi=300, bbox_inches='tight')
    finally:
        # Release the figure even if plotting or saving fails.
        plt.close()

def main():
    """Run the full pipeline: load model and data, compute SHAP values, save both plots."""
    print("加载模型和数据...")
    net, spectra, band_centres = load_model_and_data()

    print("正在计算SHAP值...")
    importance = explain_predictions(net, spectra, band_centres)

    print("正在生成图表...")
    # Summary plot first, then the top-10 bar chart.
    for draw in (plot_shap_summary, plot_top10_wavelengths):
        draw(importance, band_centres)
    print("分析完成!图表已保存为 'shap_summary_plot.png' 和 'shap_top10_wavelengths.png'")


if __name__ == "__main__":
    main()