cs3319-project2 / code /figures /fig4_highorder_ablation.py
NLP-beginner's picture
CS3319 Project 2 final deliverable (public F1 = 0.96626)
f28d994
Raw
History Blame Contribute Delete
1.78 kB
"""Fig 4: higher-order directed citation propagation ablation (the innovation).
Source: validation_runs/dynamic_seed202/high_order_graph_stack/validation_summary.csv
(4 rows: base_highorder / rich_rw7 / rich_rw7_highorder / rich_rw7_highorder_directed).
"""
from pathlib import Path
import sys
import pandas as pd
import matplotlib.pyplot as plt
sys.path.insert(0, str(Path(__file__).resolve().parent))
from plot_style import apply, save, PALETTE_DEEP as C # noqa: E402
def _short_label(stage):
    """Return a compact, multi-line tick label for an ablation stage name.

    ``rich_rw7`` is kept together as one token (``rich-rw7``). Every other
    ``_``-separated component becomes its own ``+``-prefixed line, e.g.
    ``rich_rw7_highorder_directed`` -> ``"rich-rw7\\n+highorder\\n+directed"``.
    Unlike stripping the ``base_`` / ``rich_rw7_`` prefixes, this keeps
    ``base_highorder`` and ``rich_rw7_highorder`` distinguishable on the axis.
    """
    return str(stage).replace("rich_rw7", "rich-rw7").replace("_", "\n+")


apply()
ROOT = Path(__file__).resolve().parents[2]
FIG = ROOT / "reports" / "figures"
# Per-stage validation metrics (rows listed in the module docstring).
df = pd.read_csv(ROOT / "validation_runs/dynamic_seed202/high_order_graph_stack/validation_summary.csv")
# Order stages by feature count so bars read left-to-right as features are added.
df = df.sort_values("n_features").reset_index(drop=True)
short = [_short_label(s) for s in df["stage"]]
# Default y-windows hand-tuned for this run's metrics. They are only widened
# (never narrowed) when a value would otherwise sit outside them: with a
# non-zero axis floor, a bar whose top lies below the floor is not drawn at all.
F1_YLIM = (0.960, 0.9705)
AUC_YLIM = (0.9935, 0.9955)


def _fit_ylim(values, default, pad):
    """Return *default* widened just enough that every value lies *pad* inside it."""
    lo, hi = default
    return min(lo, float(values.min()) - pad), max(hi, float(values.max()) + pad)


fig, ax = plt.subplots(figsize=(9, 5.2))
x = list(range(len(df)))
w = 0.38  # bar width; F1 bars sit left of each tick, AUC bars to the right

# Left axis: validation F1 bars.
ax.bar([i - w / 2 for i in x], df["validation_f1"], w, color=C[0], label="Validation F1")
ax.set_ylabel("Validation F1")
f1_lo, f1_hi = _fit_ylim(df["validation_f1"], F1_YLIM, pad=0.001)
ax.set_ylim(f1_lo, f1_hi)
ax.set_xticks(x)
ax.set_xticklabels(short, fontsize=9)
for i, (f, n) in enumerate(zip(df["validation_f1"], df["n_features"])):
    # Value label just above each F1 bar.
    ax.text(i - w / 2, f + 0.0004, f"{f:.4f}", ha="center", fontsize=8.6)
    # Feature count pinned just above the actual axis floor, so it stays
    # on-axis even if the floor was widened.
    ax.text(i - w / 2, f1_lo + 0.0004, f"{n} feats", ha="center", fontsize=8, color="dimgray")

# Right axis: AUC bars on a twin y-axis with its own, much tighter scale.
ax2 = ax.twinx()
ax2.bar([i + w / 2 for i in x], df["auc"], w, color=C[3], alpha=0.75, label="AUC")
ax2.set_ylabel("AUC")
ax2.set_ylim(*_fit_ylim(df["auc"], AUC_YLIM, pad=0.0001))
ax2.grid(False)  # keep the twin axis from drawing its own grid lines

# Merge handles from both axes into a single legend.
h1, l1 = ax.get_legend_handles_labels()
h2, l2 = ax2.get_legend_handles_labels()
ax.legend(h1 + h2, l1 + l2, loc="lower right")
ax.set_title("Higher-order directed citation propagation ablation")
save(fig, "fig4_highorder_ablation", FIG)
print("saved fig4_highorder_ablation")