File size: 4,574 Bytes
3a85408 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import torch
import numpy as np
import shap
import matplotlib.pyplot as plt
import seaborn as sns
from data_load import load_soil_data
from data_processing import preprocess_with_downsampling, process_spectra
from resnet1d_multitask import get_model
# Configure matplotlib for Chinese text: use the SimHei font for labels and
# keep the minus sign renderable (SimHei lacks the Unicode minus glyph).
plt.rcParams['font.sans-serif'] = ['SimHei']
plt.rcParams['axes.unicode_minus'] = False
def load_model_and_data():
    """
    Load the LUCAS soil spectra, preprocess the test set, and restore the
    trained multitask model on CPU.

    Returns
    -------
    tuple
        ``(model, X_test, X_test_nwavelengths)``: the model in eval mode,
        the preprocessed test spectra shaped (n_samples, 1, n_wavelengths),
        and the downsampled wavelength axis.
    """
    # Target soil properties predicted by the multitask model.
    target_columns = ['pH.in.CaCl2', 'pH.in.H2O', 'OC', 'CaCO3', 'N', 'P', 'K', 'CEC']
    # Load train/test splits plus the original wavelength axis.
    X_train, X_test, y_train, y_test, wavelengths = load_soil_data('LUCAS.2009_abs.csv', target_columns)
    # Ensure spectra are shaped (n_samples, n_wavelengths).
    X_train, X_test = X_train.squeeze(), X_test.squeeze()
    # Apply the 'Abs-SG1-SNV' pipeline (presumably absorbance -> SG 1st
    # derivative -> SNV; confirm against process_spectra), then downsample
    # the wavelength axis with parameter 15 (factor/window — see
    # preprocess_with_downsampling).
    # NOTE(review): the processed X_train / X_train_nwavelengths (and
    # y_train / y_test) are never used or returned — looks like dead work;
    # confirm the preprocessing helpers are stateless before removing.
    X_train = process_spectra(X_train, 'Abs-SG1-SNV')
    X_test = process_spectra(X_test, 'Abs-SG1-SNV')
    X_train, X_train_nwavelengths = preprocess_with_downsampling(X_train, wavelengths, 15)
    X_test, X_test_nwavelengths = preprocess_with_downsampling(X_test, wavelengths, 15)
    # Reshape to (n_samples, 1, n_wavelengths) for the 1-D CNN input layout.
    X_train = X_train.reshape(X_train.shape[0], 1, X_train.shape[1])
    X_test = X_test.reshape(X_test.shape[0], 1, X_test.shape[1])
    # Restore trained weights on CPU and switch to inference mode.
    device = torch.device('cpu')
    model = get_model('C')
    model.load_state_dict(torch.load('code/models/set.pth', map_location=device))
    model.eval()
    return model, X_test, X_test_nwavelengths
def explain_predictions(model, X_test, wavelengths, n_background=50, n_samples=100, nsamples=100):
    """
    Compute mean absolute SHAP values for the test spectra using
    shap.KernelExplainer (model-agnostic, CPU-friendly).

    Parameters
    ----------
    model : torch.nn.Module
        Trained model in eval mode; may return a tensor or a tuple of tensors.
    X_test : np.ndarray
        Test spectra, shape (n_total, 1, n_wavelengths).
    wavelengths : array-like
        Wavelength axis. Unused here; kept for interface compatibility.
    n_background : int
        Number of samples forming the SHAP background distribution.
    n_samples : int
        Number of test samples to explain.
    nsamples : int
        Model re-evaluations per explained sample, forwarded to
        ``KernelExplainer.shap_values`` (was hard-coded to 100).

    Returns
    -------
    np.ndarray
        Mean absolute SHAP values, shape (n_samples, n_wavelengths).
    """
    def model_predict(x):
        # KernelExplainer supplies 2-D numpy input; restore the
        # (batch, channel, length) layout the 1-D CNN expects.
        with torch.no_grad():
            x = torch.FloatTensor(x)
            if len(x.shape) == 2:
                x = x.reshape(x.shape[0], 1, -1)
            output = model(x)
            # Multitask models may return a tuple; use the first element.
            if isinstance(output, tuple):
                output = output[0]
            return output.numpy()

    # Background distribution and the samples to explain (2-D for SHAP).
    background_data = X_test[:n_background].squeeze()
    test_data = X_test[:n_samples].squeeze()
    print("创建SHAP解释器...")
    explainer = shap.KernelExplainer(model_predict, background_data)
    print("计算SHAP值(这可能需要几分钟时间)...")
    shap_values = explainer.shap_values(test_data, nsamples=nsamples)
    return _aggregate_shap_values(shap_values)


def _aggregate_shap_values(shap_values):
    """Collapse (possibly multi-output) SHAP values to one |SHAP| matrix
    of shape (n_samples, n_features) by averaging |SHAP| over outputs."""
    if isinstance(shap_values, list):
        # Older SHAP API: one (n_samples, n_features) array per output.
        return np.mean([np.abs(sv) for sv in shap_values], axis=0)
    shap_values = np.asarray(shap_values)
    if shap_values.ndim == 3:
        # Newer SHAP API: single (n_samples, n_features, n_outputs) array.
        return np.mean(np.abs(shap_values), axis=-1)
    return np.abs(shap_values)
def plot_top10_wavelengths(shap_values, wavelengths):
    """
    Save a horizontal bar chart of the (up to) 10 wavelengths with the
    highest mean |SHAP| contribution to 'shap_top10_wavelengths.png'.

    Parameters
    ----------
    shap_values : np.ndarray
        SHAP values, shape (n_samples, n_wavelengths).
    wavelengths : array-like
        Wavelength axis, length n_wavelengths.
    """
    # Accept plain Python sequences: fancy indexing below requires ndarray.
    wavelengths = np.asarray(wavelengths)
    # Mean absolute SHAP value per wavelength across the explained samples.
    mean_shap = np.mean(np.abs(shap_values), axis=0)
    # Indices of the largest contributions, descending; guard against
    # fewer than 10 features so range()/argsort stay consistent.
    n_top = min(10, mean_shap.size)
    top_idx = np.argsort(mean_shap)[-n_top:][::-1]
    top_wavelengths = wavelengths[top_idx]
    top_values = mean_shap[top_idx]
    plt.figure(figsize=(8, 6))
    plt.barh(range(n_top), top_values, color='skyblue')
    plt.yticks(range(n_top), [f'{w:.1f}' for w in top_wavelengths])
    plt.xlabel('mean(|SHAP value|) (average impact on model output magnitude)')
    plt.title('Top 10 Wavelengths by SHAP Value')
    plt.gca().invert_yaxis()  # largest contribution at the top
    plt.tight_layout()
    plt.savefig('shap_top10_wavelengths.png', dpi=300, bbox_inches='tight')
    plt.close()
def plot_shap_summary(shap_values, wavelengths, features=None):
    """
    Save a SHAP beeswarm summary plot to 'shap_summary_plot.png'.

    Parameters
    ----------
    shap_values : np.ndarray
        SHAP values, shape (n_samples, n_wavelengths).
    wavelengths : array-like
        Wavelength axis; used only to label the features.
    features : np.ndarray, optional
        The (n_samples, n_wavelengths) spectra that were explained, used to
        color the dots. With None the dots are plotted uncolored.
    """
    plt.figure(figsize=(10, 8))
    # BUG FIX: the 1-D wavelength axis was previously passed as `features`;
    # shap.summary_plot expects the explained data matrix there and rejects
    # a shape that does not match shap_values.
    shap.summary_plot(shap_values, features=features,
                      feature_names=[f'{w:.1f}' for w in wavelengths],
                      plot_type='dot', show=False)
    plt.tight_layout()
    plt.savefig('shap_summary_plot.png', dpi=300, bbox_inches='tight')
    plt.close()
def main():
    """Run the end-to-end SHAP analysis: load model and data, compute SHAP
    values on the test set, and save the summary and top-10 figures."""
    # Load the trained model and the preprocessed test spectra.
    print("加载模型和数据...")
    model, X_test, wavelengths = load_model_and_data()
    # Compute mean |SHAP| values (slow: KernelExplainer on CPU).
    print("正在计算SHAP值...")
    shap_values = explain_predictions(model, X_test, wavelengths)
    # Render both figures to PNG files in the working directory.
    print("正在生成图表...")
    plot_shap_summary(shap_values, wavelengths)
    plot_top10_wavelengths(shap_values, wavelengths)
    print("分析完成!图表已保存为 'shap_summary_plot.png' 和 'shap_top10_wavelengths.png'")
if __name__ == "__main__":
    main()
|