| import pandas as pd |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from sklearn.metrics import roc_curve, precision_recall_curve, roc_auc_score, average_precision_score |
|
|
| |
# Maps a human-readable dataset title -> the alias used in the
# prediction CSV filenames (preds_<alias>.csv).
titles = {'R2': 'r2_case'}
|
|
# Method name -> hex color used for its curve in every panel.
# ADAPT (ours) is drawn in pure red to stand out from the baselines.
methods = dict(
    LSTM="#665f88",
    GRU="#658f5f",
    MHA="#5f858f",
    Mamba="#AC9E53",
    GAT="#15A9B1",
    GATv2="#0EA488",
    ADAPT="#FF0000",
)
|
|
# Method name -> run directory holding that model's per-case prediction CSVs.
# Sequence baselines live under ../run-cls; graph baselines under the
# DAminoMuta_graph sibling project; ADAPT points at its UDA sub-run.
folders = dict(
    LSTM='../run-cls/lstm-diff-256-ce-32-0.001-50',
    GRU='../run-cls/gru-diff-256-ce-32-0.001-50',
    MHA='../run-cls/mha-diff-256-ce-32-0.001-50',
    Mamba='../run-cls/mamba-diff-256-ce-32-0.001-50',
    GAT='../../DAminoMuta_graph/run-cls/gat-diff-256-ce-32-0.001-35',
    GATv2='../../DAminoMuta_graph/run-cls/gatv2-diff-256-ce-32-0.001-35',
    ADAPT='../run-cls/rn18-diff-16-mamba-pcs-768-ce-32-0.001-50/uda_r2',
)
|
|
| |
| fig, axes = plt.subplots(nrows=1, ncols=2*len(titles), figsize=(10.66*len(titles), 5)) |
|
|
| |
# ROC panels occupy the even-indexed axes; one curve per method per dataset.
for ax, (title, alias) in zip(axes[0::2], titles.items()):
    for method, color in methods.items():
        df = pd.read_csv(f'{folders[method]}/preds_{alias}.csv')
        try:
            # Cross-validated runs store one prediction column per fold
            # (model_0_test .. model_4_test); pool all folds, repeating the
            # ground truth to keep y_true aligned with y_score.
            # np.concatenate (not the NumPy>=2-only np.concat alias) for portability.
            y_score = np.concatenate([df[f'model_{i}_test'].values for i in range(5)], axis=0)
            y_true = np.concatenate([df['gt'].values for _ in range(5)], axis=0)
        except KeyError:
            # UDA run (ADAPT) stores a single teacher-model column instead.
            y_score = df['model_uda_teacher'].values
            y_true = df['gt'].values

        fpr, tpr, _ = roc_curve(y_true, y_score)
        auroc = roc_auc_score(y_true, y_score)

        ax.plot(fpr, tpr, label=f"{method} AUC: {auroc:.2f}",
                color=color, lw=2, alpha=0.7)

    ax.set_title(f'{title} ROC Curve', fontsize=18, weight='bold')

    ax.set_xlabel("False Positive Rate", fontsize=14)
    ax.set_ylabel("True Positive Rate", fontsize=14)
    # Mirror the y-axis to the right so each ROC panel faces its PR neighbor.
    ax.yaxis.set_label_position("right")

    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('right')

    ax.legend(loc='lower right', fontsize=10, frameon=False, alignment='right', markerfirst=False)
|
|
| |
# PR panels occupy the odd-indexed axes; mirrors the ROC loop above.
for ax, (title, alias) in zip(axes[1::2], titles.items()):
    for method, color in methods.items():
        df = pd.read_csv(f'{folders[method]}/preds_{alias}.csv')
        try:
            # Pool the five CV folds' predictions; ground truth is tiled so
            # y_true stays aligned with the concatenated y_score.
            # np.concatenate (not the NumPy>=2-only np.concat alias) for portability.
            y_score = np.concatenate([df[f'model_{i}_test'].values for i in range(5)], axis=0)
            y_true = np.concatenate([df['gt'].values for _ in range(5)], axis=0)
        except KeyError:
            # UDA run (ADAPT) exposes a single teacher-model column instead.
            y_score = df['model_uda_teacher'].values
            y_true = df['gt'].values

        # drop_intermediate requires scikit-learn >= 1.3.
        pr, rc, _ = precision_recall_curve(y_true, y_score, drop_intermediate=True)
        auprc = average_precision_score(y_true, y_score)

        ax.plot(rc, pr, label=f"{method} AUC: {auprc:.2f}",
                color=color, lw=2, alpha=0.7)

    ax.set_title(f'{title} PR Curve', fontsize=18, weight='bold')

    ax.set_xlabel("Recall", fontsize=14)
    ax.set_ylabel("Precision", fontsize=14)
    ax.yaxis.set_label_position("left")

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('left')

    ax.legend(loc='upper right', fontsize=10, frameon=False, alignment='right', markerfirst=False)
|
|
| |
# Extra horizontal padding keeps the mirrored ROC/PR axis labels from
# colliding; export as SVG (vector) before displaying.
out_path = 'auroc_2.svg'
plt.tight_layout(w_pad=5)
plt.savefig(out_path)
plt.show()