import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import pearsonr, kendalltau
import seaborn as sns


def get_label_dict():
    """Return the ground-truth log2 MIC ratio of each sequence vs. the template.

    Reads ``../dataset/r2_case.xlsx``, drops rows whose ``MIC GM (ug/ml)``
    cell is missing, and maps every value of the
    ``SEQUENCE - D-type amino acid substitution`` column to its MIC.  The
    template is the upper-cased form of the first remaining sequence; every
    other sequence is scored as ``log2(MIC / template MIC)``.

    Returns:
        dict: sequence -> log2 of (its MIC / the template's MIC).  The
        template itself is not included.

    Raises:
        IndexError: if no row has an MIC value.
        KeyError: if the upper-cased first sequence is not present in the
            sequence column.
    """
    mic_col = 'MIC GM (ug/ml)'
    seq_col = 'SEQUENCE - D-type amino acid substitution'

    df = pd.read_excel('../dataset/r2_case.xlsx')
    df = df[df[mic_col].notna()]

    # Later rows overwrite earlier ones when a sequence appears twice.
    values = dict(zip(df[seq_col], df[mic_col]))

    # The reference entry is looked up by the upper-cased first sequence.
    template = list(values)[0].upper()
    t_value = values[template]

    return {
        mutation: np.log2(value / t_value)
        for mutation, value in values.items()
        if mutation != template
    }


# Academic-looking plot theme.
sns.set_theme(style="whitegrid")

predicted_csv_files = ["vanilla.csv", "sfda.csv", "uda.csv"]  # prediction CSVs
titles = ["Vanilla", "SFDA", "UDA"]  # one subplot title per CSV, same order

# Ground truth: sequence -> log2 MIC ratio.
true_data = get_label_dict()

# One subplot per prediction file, all sharing both axes.
fig, axes = plt.subplots(
    1,
    len(predicted_csv_files),
    figsize=(3.9 * len(predicted_csv_files), 4),
    sharex=True,
    sharey=True,
)
if len(predicted_csv_files) == 1:
    axes = [axes]  # a single subplot is returned bare; make it iterable

# The ground truth does not depend on the prediction file, so build it once.
seqs = list(true_data)
true_values = list(true_data.values())

for ax, csv_file, title in zip(axes, predicted_csv_files, titles):
    # Predictions for this model, indexed by sequence.
    pred_pd = pd.read_csv(f'./kcc/{csv_file}', index_col='seq')
    predicted_values = [pred_pd.loc[seq, 'pred'] for seq in seqs]

    # Pearson and Kendall correlation between truth and prediction.
    corr_p, _ = pearsonr(true_values, predicted_values)
    corr_k, _ = kendalltau(true_values, predicted_values)

    # Dashed y = x reference line under the scatter points.
    ax.plot([-1, 2.5], [-1, 2.5], color='gray', linestyle='--', linewidth=1)
    ax.scatter(true_values, predicted_values, alpha=0.7)

    ax.set_title(title, fontsize=14)
    ax.set_xlabel("True Log2 MIC Ratio", fontsize=12)
    ax.set_xticks(np.arange(-1, 2.6, 0.5))
    ax.set_ylabel("Predicted Log2 MIC Ratio", fontsize=12)
    ax.set_yticks(np.arange(-1, 2.6, 0.5))
    # Presumably needed because sharey=True hides the y tick labels on the
    # inner subplots; this turns them back on for every panel.
    ax.tick_params(labelleft=True)
    ax.text(
        0.40,
        0.03,
        f'PCC = {corr_p:.3f}\nKCC = {corr_k:.3f}',
        transform=ax.transAxes,
        fontsize=12,
        color='darkred',
    )
    ax.grid(alpha=0.3)

# Adjust the layout, save, and display.
# NOTE(review): tight_layout() recomputes the subplot spacing, so the wspace
# set on the line before it is likely overridden -- confirm the intended order.
plt.subplots_adjust(wspace=0.5)
plt.tight_layout()
plt.savefig('1_kcc_vis.svg')
plt.show()