# File size: 2,423 Bytes
# Revision: acbef3a
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, kendalltau
import seaborn as sns
def get_label_dict(path='../dataset/r2_case.xlsx'):
    """Return log2 MIC ratios of every variant relative to the template.

    Reads the spreadsheet at ``path``, discards rows that have no value in
    the ``'MIC GM (ug/ml)'`` column and maps each entry of the
    ``'SEQUENCE - D-type amino acid substitution'`` column to its MIC value.
    The template is the upper-cased form of the first remaining sequence and
    must itself be present in the sequence column.

    Args:
        path: Location of the Excel workbook, forwarded to
            ``pandas.read_excel``. Defaults to the dataset path that was
            previously hard-coded.

    Returns:
        A dict mapping each non-template sequence to
        ``log2(MIC_sequence / MIC_template)``.

    Raises:
        IndexError: If no row carries a MIC value.
        KeyError: If a required column is absent, or the template sequence
            has no row of its own.
    """
    seq_col = 'SEQUENCE - D-type amino acid substitution'
    mic_col = 'MIC GM (ug/ml)'

    df = pd.read_excel(path)
    df = df[df[mic_col].notna()]

    # A repeated sequence keeps the MIC of its last occurrence, exactly as a
    # row-by-row assignment into the dict would.
    values = dict(zip(df[seq_col], df[mic_col]))
    if not values:
        raise IndexError(f"no rows with a '{mic_col}' value in {path!r}")

    # The template is looked up by the upper-cased first sequence, so the
    # first row does not have to be the template itself.
    template = next(iter(values)).upper()
    if template not in values:
        raise KeyError(
            f"template sequence {template!r} has no MIC entry in {path!r}"
        )
    t_value = values[template]

    return {
        mutation: np.log2(value / t_value)
        for mutation, value in values.items()
        if mutation != template
    }
# Use a clean, publication-style theme.
sns.set_theme(style="whitegrid")

predicted_csv_files = ["vanilla.csv", "sfda.csv", "uda.csv"]  # prediction CSVs, one per panel
titles = ["Vanilla", "SFDA", "UDA"]  # panel titles, parallel to predicted_csv_files
if len(titles) != len(predicted_csv_files):
    raise ValueError("titles and predicted_csv_files must have the same length")

# Ground-truth labels; they do not depend on the prediction file, so the
# sequence order and the true values are built once, outside the loop.
true_data = get_label_dict()
seqs = list(true_data)
true_values = [true_data[seq] for seq in seqs]

# Tick positions shared by both axes of every panel.
axis_ticks = np.arange(-1, 2.6, 0.5)

# One panel per prediction file, laid out in a single row.
fig, axes = plt.subplots(
    1,
    len(predicted_csv_files),
    figsize=(3.9 * len(predicted_csv_files), 4),
    sharex=True,
    sharey=True,
)
axes = np.atleast_1d(axes)  # a single panel comes back as a bare Axes

for ax, csv_file, title in zip(axes, predicted_csv_files, titles):
    # Predictions for this model, indexed by sequence.
    pred_pd = pd.read_csv(f'./kcc/{csv_file}', index_col='seq')
    missing = [seq for seq in seqs if seq not in pred_pd.index]
    if missing:
        raise KeyError(f"{csv_file} has no prediction for: {missing}")
    predicted_values = pred_pd.loc[seqs, 'pred'].tolist()

    # Pearson (PCC) and Kendall (KCC) correlation between truth and prediction.
    corr_p, _ = pearsonr(true_values, predicted_values)
    corr_k, _ = kendalltau(true_values, predicted_values)

    # Dashed y = x reference line, then the scatter of the points.
    ax.plot([-1, 2.5], [-1, 2.5], color='gray', linestyle='--', linewidth=1)
    ax.scatter(true_values, predicted_values, alpha=0.7)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("True Log2 MIC Ratio", fontsize=12)
    ax.set_xticks(axis_ticks)
    ax.set_ylabel("Predicted Log2 MIC Ratio", fontsize=12)
    ax.set_yticks(axis_ticks)
    # sharey hides the y tick labels on inner panels; turn them back on.
    ax.tick_params(labelleft=True)
    ax.text(0.40, 0.03, f'PCC = {corr_p:.3f}\nKCC = {corr_k:.3f}', transform=ax.transAxes, fontsize=12, color='darkred')
    ax.grid(alpha=0.3)

# tight_layout recomputes the spacing between panels, so no manual
# horizontal-spacing adjustment is made before it.
plt.tight_layout()
plt.savefig('1_kcc_vis.svg')
plt.show()