|
|
import pandas as pd |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from scipy.stats import pearsonr, kendalltau |
|
|
import seaborn as sns |
|
|
|
|
|
|
|
|
def get_label_dict(path='../dataset/r2_case.xlsx', reader=pd.read_excel):
    """Return log2 MIC ratios of each variant relative to the template sequence.

    Reads the measurement table and keeps only rows whose
    ``'MIC GM (ug/ml)'`` value is present. It then maps each sequence in
    ``'SEQUENCE - D-type amino acid substitution'`` to
    ``log2(MIC / template MIC)``.

    The template is the upper-cased form of the first listed sequence. That
    upper-cased sequence must itself appear in the table with a measured MIC.
    The template is excluded from the result. If a sequence appears more than
    once, its last measured MIC wins.

    Args:
        path: Location of the measurement table. The default is the original
            hard-coded dataset path.
        reader: Callable turning ``path`` into a DataFrame. The default is
            ``pd.read_excel``; pass e.g. ``pd.read_csv`` for other formats.

    Returns:
        dict mapping each non-template sequence to its log2 MIC ratio.

    Raises:
        ValueError: if no row has a measured MIC, or the template MIC is not
            positive (the log ratio would be undefined).
        KeyError: if the upper-cased template sequence has no measured MIC.
    """
    seq_col = 'SEQUENCE - D-type amino acid substitution'
    mic_col = 'MIC GM (ug/ml)'

    df = reader(path)
    # Drop rows without a measured MIC.
    df = df[df[mic_col].notna()]

    # dict(zip(...)) keeps first-occurrence key order and last-wins values,
    # exactly like assigning row by row.
    values = dict(zip(df[seq_col], df[mic_col]))
    if not values:
        raise ValueError(f'no rows with a measured MIC in {path!r}')

    # The reference is the all-uppercase form of the first listed sequence;
    # it need not be the first row itself, but it must be present.
    template = next(iter(values)).upper()
    if template not in values:
        raise KeyError(f'template sequence {template!r} has no measured MIC')
    t_value = values[template]
    if t_value <= 0:
        raise ValueError(
            f'template MIC must be positive to form a log ratio, got {t_value!r}'
        )

    return {
        mutation: np.log2(value / t_value)
        for mutation, value in values.items()
        if mutation != template
    }
|
|
|
|
|
|
|
|
# Seaborn white-grid theme for every figure drawn below.
sns.set_theme(style="whitegrid")

# Prediction files (read from ./kcc/) and the panel title shown for each.
# The two lists are parallel: titles[i] labels predicted_csv_files[i].
predicted_csv_files = ["vanilla.csv", "sfda.csv", "uda.csv"]
titles = ["Vanilla", "SFDA", "UDA"]
if len(titles) != len(predicted_csv_files):
    raise ValueError(
        f"need one title per prediction file: {len(titles)} titles for "
        f"{len(predicted_csv_files)} files"
    )

# Ground truth: sequence -> log2(MIC / template MIC), from get_label_dict().
true_data = get_label_dict()

# One panel per prediction file, all sharing x and y axes. squeeze=False
# always returns a 2-D array of Axes, so ravel() yields a flat, indexable
# sequence even when there is only a single panel.
fig, axes = plt.subplots(
    1,
    len(predicted_csv_files),
    figsize=(3.9 * len(predicted_csv_files), 4),
    sharex=True,
    sharey=True,
    squeeze=False,
)
axes = axes.ravel()
|
|
|
|
|
|
|
|
for i, csv_file in enumerate(predicted_csv_files):
    # Per-model predictions indexed by sequence; the 'pred' column is plotted
    # against the true log2 MIC ratios from true_data.
    pred_pd = pd.read_csv(f'./kcc/{csv_file}', index_col='seq')

    # Validate alignment up front. Duplicate labels would make the lookup
    # return extra rows, and missing labels would otherwise surface as a
    # bare pandas KeyError.
    if not pred_pd.index.is_unique:
        raise ValueError(f'{csv_file}: duplicate sequences in the "seq" column')
    seqs = list(true_data)
    missing = [seq for seq in seqs if seq not in pred_pd.index]
    if missing:
        raise KeyError(
            f'{csv_file}: no prediction for {len(missing)} sequence(s), '
            f'e.g. {missing[0]!r}'
        )

    # Same order as true_data. One .loc[rows, col] lookup replaces the
    # chained .loc[seq]['pred'] per sequence.
    true_values = [true_data[seq] for seq in seqs]
    predicted_values = pred_pd.loc[seqs, 'pred'].tolist()

    # Correlation coefficients only; the p-values are discarded.
    corr_p, _ = pearsonr(true_values, predicted_values)
    corr_k, _ = kendalltau(true_values, predicted_values)

    ax = axes[i]
    # y = x reference line over the fixed plotting range.
    ax.plot([-1, 2.5], [-1, 2.5], color='gray', linestyle='--', linewidth=1)
    ax.scatter(true_values, predicted_values, alpha=0.7)
    ax.set_title(titles[i], fontsize=14)
    ax.set_xlabel("True Log2 MIC Ratio", fontsize=12)
    ax.set_xticks(np.arange(-1, 2.6, 0.5))
    ax.set_ylabel("Predicted Log2 MIC Ratio", fontsize=12)
    ax.set_yticks(np.arange(-1, 2.6, 0.5))
    # sharey hides y tick labels on inner panels; show them on every panel.
    ax.tick_params(labelleft=True)
    # Position is in axes-fraction coordinates (transform=ax.transAxes).
    ax.text(0.40, 0.03, f'PCC = {corr_p:.3f}\nKCC = {corr_k:.3f}',
            transform=ax.transAxes, fontsize=12, color='darkred')
    ax.grid(alpha=0.3)
|
|
|
|
|
|
|
|
# Spacing is left entirely to tight_layout(), which overwrites any earlier
# subplots_adjust() values (including wspace). To widen the gaps between
# panels, pass w_pad= here instead.
fig.tight_layout()
fig.savefig('1_kcc_vis.svg')
plt.show()