| import pandas as pd |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from sklearn.metrics import roc_curve, precision_recall_curve, auc |
|
|
# Results table consumed by the plotting loop: the 'gt' column holds the
# ground-truth labels and every method name listed in `cols` is a column of
# that method's scores.
df = pd.read_csv(filepath_or_buffer='auroc_curves.csv')
|
|
# Subplot titles, one per panel. Must line up index-for-index with `cols`.
titles = [
    'Encoder Types',
    'Encoder Widths',
    'LLM Collaboration',
]

# Score columns of the results CSV drawn in each panel (same order as `titles`).
cols = [
    ['LSTM 256 MLP', 'LSTM 256 ATT', 'GRU 256 MLP', 'GRU 256 ATT',
     'MHA 256 MLP', 'MHA 256 ATT', 'Mamba 256 MLP', 'Mamba 256 ATT'],
    ['Mamba 128 ATT', 'Mamba 256 ATT', 'Mamba 512 ATT'],
    ['Mamba 256 ATT', 'DS R1', 'DS R1 Mamba Fusion'],
]

# Fixed colour per method, so a method shared by several panels
# (e.g. 'Mamba 256 ATT', which appears in all three) keeps the same colour.
# Every method listed in `cols` needs an entry here. The 'MLA 256 *' entries
# are not currently referenced by `cols`.
# NOTE(review): several baseline colours are very close shades of blue and may
# be hard to tell apart in the 8-curve 'Encoder Types' panel.
color_map = {
    "LSTM 256 MLP": "#1f77b4",
    "LSTM 256 ATT": "#665f88",
    "GRU 256 MLP": "#1f4494",
    "GRU 256 ATT": "#1f55a4",
    "MHA 256 MLP": "#1f6684",
    "MHA 256 ATT": "#1f88b4",
    "MLA 256 MLP": "#1f99c4",
    "MLA 256 ATT": "#1f8F74",
    "Mamba 256 MLP": "#2ca02c",
    "Mamba 256 ATT": "#FF5733",
    "Mamba 128 ATT": "#9467bd",
    "Mamba 512 ATT": "#8c564b",
    "DS R1": "#e377c2",
    "DS R1 Mamba Fusion": "#FF2222",
}
|
|
# zip() below would silently drop unmatched panels, so fail loudly instead.
if len(titles) != len(cols):
    raise ValueError(
        f"titles ({len(titles)}) and cols ({len(cols)}) must have the same length"
    )

# One panel per comparison group. squeeze=False keeps `axes` as an array even
# for a single panel; take its only row so `axes` is a 1-D row of Axes.
# NOTE: figsize is hard-coded and was tuned for three panels.
fig, axes = plt.subplots(nrows=1, ncols=len(titles), figsize=(16, 5), squeeze=False)
axes = axes[0]

# Ground-truth labels are shared by every method; read the column once.
y_true = df['gt']

for ax, title, methods in zip(axes, titles, cols):
    for method in methods:
        # drop_intermediate=False keeps every threshold instead of pruning
        # suboptimal ones, so the plotted curve has full resolution.
        fpr, tpr, _ = roc_curve(y_true, df[method], drop_intermediate=False)
        auroc = auc(fpr, tpr)

        # Methods missing from color_map get matplotlib's default colour
        # cycle (color=None) instead of raising KeyError.
        ax.plot(fpr, tpr, label=f"{method} AUC: {auroc:.2f}",
                color=color_map.get(method), lw=2, alpha=0.7)

    ax.set_title(title, fontsize=18, weight='bold')

    ax.set_xlabel("False Positive Rate", fontsize=14)
    ax.set_ylabel("True Positive Rate", fontsize=14)

    # Draw the y axis (label and ticks) on the right-hand side and hide the
    # top and left spines.
    ax.yaxis.set_label_position("right")
    ax.spines['top'].set_visible(False)
    ax.spines['left'].set_visible(False)
    ax.xaxis.set_ticks_position('bottom')
    ax.yaxis.set_ticks_position('right')

    # The `alignment` keyword requires matplotlib >= 3.6.
    ax.legend(loc='lower right', fontsize=10, frameon=False,
              alignment='right', markerfirst=False)
|
|
# Lay out the current figure with extra horizontal padding between panels,
# write it to disk as SVG, then open the interactive window.
plt.gcf().tight_layout(w_pad=5)
plt.savefig('auroc.svg')
plt.show()