DAminoMuta / vis /auroc.py
auralray's picture
Upload folder using huggingface_hub
acbef3a verified
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, precision_recall_curve, auc
# 读取数据,数据包含预测概率(0-1)和真实标签(0或1),真实标签列名为 'gt'
df = pd.read_csv('auroc_curves.csv')
# 定义子图标题及每个子图需要绘制的列名
titles = [
'Encoder Types',
'Encoder Widths',
'LLM Colabration'
]
cols = [
['LSTM 256 MLP', 'LSTM 256 ATT', 'GRU 256 MLP', 'GRU 256 ATT', 'MHA 256 MLP', 'MHA 256 ATT', 'Mamba 256 MLP', 'Mamba 256 ATT'],
['Mamba 128 ATT', 'Mamba 256 ATT', 'Mamba 512 ATT'],
['Mamba 256 ATT', 'DS R1', 'DS R1 Mamba Fusion']
]
color_map = {
"LSTM 256 MLP": "#1f77b4",
"LSTM 256 ATT": "#665f88",
"GRU 256 MLP": "#1f4494",
"GRU 256 ATT": "#1f55a4",
"MHA 256 MLP": "#1f6684",
"MHA 256 ATT": "#1f88b4",
"MLA 256 MLP": "#1f99c4",
"MLA 256 ATT": "#1f8F74",
"Mamba 256 MLP": "#2ca02c",
"Mamba 256 ATT": "#FF5733", # 突显:鲜艳的橙红色
"Mamba 128 ATT": "#9467bd",
"Mamba 512 ATT": "#8c564b",
"DS R1": "#e377c2",
"DS R1 Mamba Fusion": "#FF2222" # 突显:醒目的深玫红色
}
# 创建三个子图,调整图形尺寸适合展示
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(16, 5))
# 遍历每个子图与对应的标题、方法列数据
for ax, title, methods in zip(axes, titles, cols):
for method in methods:
# 获取预测概率及真实标签
y_true = df['gt']
y_score = df[method]
# 计算 ROC 曲线以及 AUROC 值
fpr, tpr, thresholds = roc_curve(y_true, y_score, drop_intermediate=False)
auroc = auc(fpr, tpr)
# 绘制 ROC 曲线,使用已分配的颜色,学术展示的曲线线宽合适
ax.plot(fpr, tpr, label=f"{method} AUC: {auroc:.2f}",
color=color_map[method], lw=2, alpha=0.7)
# 设置子图标题
ax.set_title(title, fontsize=18, weight='bold')
# 设置横纵坐标的标签
ax.set_xlabel("False Positive Rate", fontsize=14)
ax.set_ylabel("True Positive Rate", fontsize=14)
ax.yaxis.set_label_position("right")
# 仅保留下边框和右边框, 去掉上边框和左边框
ax.spines['top'].set_visible(False)
ax.spines['left'].set_visible(False)
# 仅在底部和右侧显示刻度
ax.xaxis.set_ticks_position('bottom')
ax.yaxis.set_ticks_position('right')
# 设置图例,放置在右下角
ax.legend(loc='lower right', fontsize=10, frameon=False, alignment='right', markerfirst=False)
# 调整整个图形的布局,避免子图之间重叠
plt.tight_layout(w_pad=5)
plt.savefig('auroc.svg')
plt.show()