# Source: MM-DLS / mm-dls / plot_results.py
# Author: FangDai
# Commit: a19a7aa ("Upload 11 files", verified)
# code/plot_results.py
# ============================================================
# End-to-end paper-style plotting (curves + tables)
# - Subtype (binary): ROC + PR + Calibration (with tables)
# - TNM (multiclass OVR): ROC + PR + Calibration (with tables, per class)
# - DFS/OS survival: KM + Cox HR + log-rank + C-index/Brier (with at-risk text)
#
# IMPORTANT:
# - Safe to import (NO plotting on import)
# - Call plot_all(result_dir, fig_dir) after main.py saves outputs
# ============================================================
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import label_binarize
from sklearn.metrics import (
roc_curve, auc,
precision_recall_curve, average_precision_score,
confusion_matrix,
brier_score_loss
)
from sklearn.calibration import calibration_curve
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import multivariate_logrank_test
from lifelines.utils import concordance_index
from scipy.stats import norm
# ============================================================
# Basic I/O helpers
# ============================================================
def _ensure_dir(path: str):
os.makedirs(path, exist_ok=True)
def _exists(path: str) -> bool:
return os.path.exists(path) and os.path.isfile(path)
def _load_npy(path: str):
    """Load a ``.npy`` array from *path*, or return None when the file is absent.

    NOTE: ``allow_pickle=True`` is passed so object arrays load, which also
    unpickles arbitrary payloads. Only point this at result files produced by
    this pipeline.
    """
    return np.load(path, allow_pickle=True) if _exists(path) else None
def _maybe_sim_ext(labels, scores, noise=0.03, seed=42):
"""
Simulate an external test split when not provided.
Keeps labels same; adds small noise to scores then clips to [0,1].
"""
rng = np.random.RandomState(seed)
if scores is None:
return None, None
s = scores.copy()
s = np.clip(s + rng.normal(0, noise, s.shape), 0.0, 1.0)
return labels.copy(), s
# ============================================================
# Metrics helpers
# ============================================================
def _calc_binary_roc(y_true, y_score):
    """Binary ROC summary.

    Returns:
        ``(fpr, tpr, roc_auc, brier)``: the ROC curve points, the area under
        that curve, and the Brier score of *y_score* against *y_true*.
    """
    curve = roc_curve(y_true, y_score)
    fpr, tpr = curve[0], curve[1]
    return fpr, tpr, auc(fpr, tpr), brier_score_loss(y_true, y_score)
def _calc_binary_pr(y_true, y_score):
    """Binary precision-recall summary: ``(precision, recall, average_precision)``."""
    precision, recall, _thresholds = precision_recall_curve(y_true, y_score)
    return precision, recall, average_precision_score(y_true, y_score)
def _spec_npv_binary(y_true, y_score, thresh=0.5):
y_pred = (y_score >= thresh).astype(int)
tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
specificity = tn / (tn + fp) if (tn + fp) else 0.0
npv = tn / (tn + fn) if (tn + fn) else 0.0
return specificity, npv
def _ece(y_true, y_score, n_bins=10):
bins = np.linspace(0.0, 1.0, n_bins + 1)
binids = np.digitize(y_score, bins) - 1
ece = 0.0
for i in range(n_bins):
m = binids == i
if m.sum() > 0:
prob_true = np.mean(y_true[m])
prob_pred = np.mean(y_score[m])
ece += (m.sum() / len(y_score)) * abs(prob_pred - prob_true)
return float(ece)
def _calc_ovr_auc(y_bin, y_score):
    """Per-class one-vs-rest ROC curves.

    Column ``i`` of the binary indicator matrix *y_bin* is scored against
    column ``i`` of *y_score*.

    Returns:
        ``{i: (fpr, tpr, auc)}`` for every column of *y_bin*.
    """
    def _one(col):
        fpr, tpr, _thr = roc_curve(y_bin[:, col], y_score[:, col])
        return fpr, tpr, auc(fpr, tpr)

    return {col: _one(col) for col in range(y_bin.shape[1])}
def _calc_ovr_pr(y_bin, y_score):
    """Per-class one-vs-rest precision-recall curves.

    Returns:
        ``{i: (precision, recall, average_precision)}`` for every column of
        the binary indicator matrix *y_bin*, scored with column ``i`` of
        *y_score*.
    """
    result = {}
    n_classes = y_bin.shape[1]
    for cls in range(n_classes):
        truth, prob = y_bin[:, cls], y_score[:, cls]
        precision, recall, _unused = precision_recall_curve(truth, prob)
        result[cls] = (precision, recall, average_precision_score(truth, prob))
    return result
def _acc_ovr(y_true_bin, y_score, thresh=0.5):
y_pred = (y_score >= thresh).astype(int)
return float((y_pred == y_true_bin).mean())
# ============================================================
# Table helpers (paper-style)
# ============================================================
def _auto_col_widths(col_labels, bbox_w):
lens = np.array([max(4, len(c)) for c in col_labels], dtype=float)
ratio = lens / lens.sum()
return bbox_w * ratio
def _add_table(ax, table_data, row_labels, col_labels, colors=None,
               bbox=(0.05, -0.50, 0.95, 0.30),
               fontsize=13, rowlabel_width=0.18):
    """Attach a borderless, paper-style summary table to *ax*.

    Args:
        ax: Matplotlib Axes that receives the table.
        table_data: one list of cell strings per row label.
        row_labels: labels shown in the left-hand row-label column.
        col_labels: column headers. Column widths are proportional to header
            length (see ``_auto_col_widths``).
        colors: optional list[str] with one entry per row. It colors that
            row's value cells and its row label; the header row keeps the
            default color.
        bbox: ``(x, y, width, height)`` in axes coordinates. A negative ``y``
            places the table below the plot area.
        fontsize: fixed font size (automatic font sizing is disabled).
        rowlabel_width: width of the row-label column, in axes units.

    Returns:
        The created ``matplotlib.table.Table``.
    """
    # Draw on the axes we were given. The original used plt.table, which
    # targets pyplot's *current* axes and ignored `ax` entirely.
    tbl = ax.table(
        cellText=table_data,
        rowLabels=row_labels,
        colLabels=col_labels,
        cellLoc='center',
        rowLoc='left',
        colLoc='center',
        bbox=list(bbox),
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(fontsize)
    cells = tbl.get_celld()

    # Value/header column widths. Row 0 is the header; column -1 holds the
    # row labels and is sized separately below.
    col_widths = _auto_col_widths(col_labels, bbox[2])
    for col in range(len(col_labels)):
        for row in range(len(row_labels) + 1):
            cells[(row, col)].set_width(col_widths[col])

    for row in range(1, len(row_labels) + 1):
        if (row, -1) in cells:
            cells[(row, -1)].set_width(rowlabel_width)

    # Paper style: no grid lines.
    for cell in cells.values():
        cell.set_linewidth(0)

    # Optional per-row text color for the value cells and the row label.
    if colors is not None:
        for r in range(1, len(row_labels) + 1):
            for c in range(len(col_labels)):
                if (r, c) in cells:
                    cells[(r, c)].get_text().set_color(colors[r - 1])
            if (r, -1) in cells:
                cells[(r, -1)].get_text().set_color(colors[r - 1])
    return tbl
# ============================================================
# Subtype (binary) plots: ROC / PR / Calibration
# ============================================================
def plot_subtype_binary(result_dir="./results", fig_dir="./figures",
                        title_suffix="(LUAD vs LUSC)"):
    """Render the subtype (binary) ROC, PR and calibration figures with tables.

    Reads ``subtype_{train,val,test}_{labels,scores}.npy`` from *result_dir*.
    If any of them is missing, prints a skip message and returns without
    plotting. ``subtype_test2_{labels,scores}.npy`` is used as the "Ext.Test"
    split when both files exist.

    Saves to *fig_dir*, as PNG and PDF at 600 dpi:
      - Figure4a_subtype_ROC: ROC curves plus a Number / AUC CV / Brier table.
      - Figure4b_subtype_PR: PR curves plus an AP CV / Specificity / NPV /
        Average Precision table.
      - Figure4c_subtype_Calibration: reliability curves plus an ECE table.

    Args:
        result_dir: directory holding the saved label/score arrays.
        fig_dir: output directory (created if missing).
        title_suffix: second line appended to every figure title.

    NOTE(review): labels are cast to int and scores to float. Scores are
    presumably positive-class probabilities in [0, 1]; confirm against the
    code that saves them.
    NOTE(review): when ``subtype_test2_*`` is absent, "Ext.Test" is
    *simulated* from Int.Test via _maybe_sim_ext (noise=0.04, seed=7) but is
    labelled like a real external cohort, and no message is printed.
    """
    _ensure_dir(fig_dir)
    # Required: train/val/test (display name -> (labels path, scores path))
    paths = {
        "Train": (os.path.join(result_dir, "subtype_train_labels.npy"),
                  os.path.join(result_dir, "subtype_train_scores.npy")),
        "Int.Valid": (os.path.join(result_dir, "subtype_val_labels.npy"),
                      os.path.join(result_dir, "subtype_val_scores.npy")),
        "Int.Test": (os.path.join(result_dir, "subtype_test_labels.npy"),
                     os.path.join(result_dir, "subtype_test_scores.npy")),
    }
    data = {}
    missing_core = False
    for k, (lp, sp) in paths.items():
        y = _load_npy(lp)
        s = _load_npy(sp)
        if y is None or s is None:
            print(f"[plot_subtype_binary] Skip: missing {lp} or {sp}")
            missing_core = True
            break
        data[k] = (y.astype(int), s.astype(float))
    if missing_core:
        return
    # External (simulated) if not present
    ext_lp = os.path.join(result_dir, "subtype_test2_labels.npy")
    ext_sp = os.path.join(result_dir, "subtype_test2_scores.npy")
    ext_y = _load_npy(ext_lp)
    ext_s = _load_npy(ext_sp)
    if ext_y is None or ext_s is None:
        # Fallback: jittered copy of Int.Test (see NOTE(review) in docstring).
        ext_y, ext_s = _maybe_sim_ext(data["Int.Test"][0], data["Int.Test"][1], noise=0.04, seed=7)
    data["Ext.Test"] = (ext_y.astype(int), ext_s.astype(float))
    # Colors (match your style); used by the ROC figure and its table.
    colors = {
        "Train": "#0074B7",
        "Int.Valid": "#60A3D9",
        "Int.Test": "#6CC4DC",
        "Ext.Test": "#61649f",
    }
    row_colors = [colors["Train"], colors["Int.Valid"], colors["Int.Test"], colors["Ext.Test"]]
    # ---------- ROC (Figure 4a-like) ----------
    roc_items = {}
    for k, (y, s) in data.items():
        fpr, tpr, auc_k, brier_k = _calc_binary_roc(y, s)
        roc_items[k] = dict(fpr=fpr, tpr=tpr, auc=auc_k, brier=brier_k, y=y, s=s)
    # "AUC CV" = coefficient of variation (std / mean) of the four split AUCs.
    # It is a single number repeated in every table row.
    auc_list = np.array([roc_items[k]["auc"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
    auc_cv = float(np.std(auc_list) / np.mean(auc_list)) if np.mean(auc_list) > 0 else 0.0
    fig, ax = plt.subplots(figsize=(5, 7), facecolor="white")
    ax.set_facecolor("white")
    for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
        ax.plot(roc_items[k]["fpr"], roc_items[k]["tpr"],
                label=f"{k} (AUC = {roc_items[k]['auc']:.2f})",
                color=colors[k], linewidth=3)
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)  # chance diagonal
    ax.set_xlim([-0.01, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xticks(np.linspace(0, 1, 6))
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_xlabel("False Positive Rate", fontsize=14)
    ax.set_ylabel("True Positive Rate", fontsize=14)
    ax.set_title(f"Pathological Subtype Classification ROC Curves\n{title_suffix}", fontsize=14)
    ax.legend(loc="lower right", fontsize=12)
    ax.grid(alpha=0.3)

    # Table: Number / AUC CV / Brier Score
    def _posneg(y):
        # "<count of label 0> vs <count of label 1>"
        neg = int((y == 0).sum())
        pos = int((y == 1).sum())
        return f"{neg} vs {pos}"
    row_labels = ["Train", "Int.Valid", "Int.Test", "Ext.Test"]
    col_labels = ["Number", "AUC CV", "Brier Score"]
    table_data = [
        [_posneg(roc_items["Train"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Train']['brier']:.3f}"],
        [_posneg(roc_items["Int.Valid"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Int.Valid']['brier']:.3f}"],
        [_posneg(roc_items["Int.Test"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Int.Test']['brier']:.3f}"],
        [_posneg(roc_items["Ext.Test"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Ext.Test']['brier']:.3f}"],
    ]
    _add_table(ax, table_data, row_labels, col_labels, colors=row_colors,
               bbox=(0.05, -0.52, 0.98, 0.30), fontsize=12, rowlabel_width=0.20)
    # Leave room below the axes for the table; plt.* calls act on the
    # current figure, i.e. the one created just above.
    plt.subplots_adjust(bottom=0.42)
    plt.savefig(os.path.join(fig_dir, "Figure4a_subtype_ROC.png"), dpi=600, bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, "Figure4a_subtype_ROC.pdf"), dpi=600, bbox_inches="tight")
    plt.close()
    # ---------- PR (Figure 4b-like) ----------
    pr_items = {}
    for k, (y, s) in data.items():
        p, r, ap = _calc_binary_pr(y, s)
        # Specificity / NPV use a fixed 0.5 decision threshold.
        spec, npv = _spec_npv_binary(y, s, thresh=0.5)
        pr_items[k] = dict(p=p, r=r, ap=ap, spec=spec, npv=npv, y=y, s=s)
    # "AP CV" = std / mean of the four split APs, repeated in every row.
    ap_vals = np.array([pr_items[k]["ap"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
    ap_cv = float(np.std(ap_vals) / np.mean(ap_vals)) if np.mean(ap_vals) > 0 else 0.0
    fig, ax = plt.subplots(figsize=(7, 5.3), facecolor="white")
    ax.set_facecolor("white")
    for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
        ax.plot(pr_items[k]["r"], pr_items[k]["p"],
                label=f"{k} (AP={pr_items[k]['ap']:.2f})",
                color={
                    "Train": "#7F8FA3",
                    "Int.Valid": "#FFA0A3",
                    "Int.Test": "#77DDF9",
                    "Ext.Test": "#61649f",
                }[k],
                linewidth=3)
        # Light shaded area under each PR curve.
        ax.fill_between(pr_items[k]["r"], pr_items[k]["p"], step='post', alpha=0.1,
                        color={
                            "Train": "#7F8FA3",
                            "Int.Valid": "#FFA0A3",
                            "Int.Test": "#77DDF9",
                            "Ext.Test": "#61649f",
                        }[k])
    ax.set_xlim(-0.01, 1.01)
    ax.set_ylim(-0.01, 1.01)
    ax.set_xlabel("Recall", fontsize=14)
    ax.set_ylabel("Precision", fontsize=14)
    ax.set_title(f"Pathological Subtype Classification Precision-Recall Curves\n{title_suffix}", fontsize=14)
    ax.legend(loc="lower left", fontsize=12)
    ax.grid(alpha=0.3)
    row_labels = [
        f"Train (n={len(pr_items['Train']['y'])})",
        f"Int.Valid (n={len(pr_items['Int.Valid']['y'])})",
        f"Int.Test (n={len(pr_items['Int.Test']['y'])})",
        f"Ext.Test (n={len(pr_items['Ext.Test']['y'])})",
    ]
    col_labels = ["AP CV", "Specificity", "NPV", "Average Precision"]
    table_data = [
        [f"{ap_cv:.2f}", f"{pr_items['Train']['spec']:.2f}", f"{pr_items['Train']['npv']:.2f}", f"{pr_items['Train']['ap']:.2f}"],
        [f"{ap_cv:.2f}", f"{pr_items['Int.Valid']['spec']:.2f}", f"{pr_items['Int.Valid']['npv']:.2f}", f"{pr_items['Int.Valid']['ap']:.2f}"],
        [f"{ap_cv:.2f}", f"{pr_items['Int.Test']['spec']:.2f}", f"{pr_items['Int.Test']['npv']:.2f}", f"{pr_items['Int.Test']['ap']:.2f}"],
        [f"{ap_cv:.2f}", f"{pr_items['Ext.Test']['spec']:.2f}", f"{pr_items['Ext.Test']['npv']:.2f}", f"{pr_items['Ext.Test']['ap']:.2f}"],
    ]
    # Same palette as the PR curves; also reused by the calibration table below.
    pr_row_colors = ["#7F8FA3", "#FFA0A3", "#77DDF9", "#61649f"]
    _add_table(ax, table_data, row_labels, col_labels, colors=pr_row_colors,
               bbox=(0.10, -0.55, 0.90, 0.30), fontsize=12, rowlabel_width=0.28)
    plt.subplots_adjust(bottom=0.45)
    plt.savefig(os.path.join(fig_dir, "Figure4b_subtype_PR.png"), dpi=600, bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, "Figure4b_subtype_PR.pdf"), dpi=600, bbox_inches="tight")
    plt.close()
    # ---------- Calibration (Figure 4c-like) ----------
    fig, ax = plt.subplots(figsize=(5, 5.4), facecolor="white")
    ax.set_facecolor("white")
    calib_colors = {
        "Train": "#7F8FA3",
        "Int.Valid": "#FFA0A3",
        "Int.Test": "#77DDF9",
        "Ext.Test": "#61649f",
    }
    eces = {}
    for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
        y, s = data[k]
        # Reliability curve with 10 bins, plus ECE with the same bin count.
        prob_true, prob_pred = calibration_curve(y, s, n_bins=10)
        ax.plot(prob_pred, prob_true, marker='o', label=k, color=calib_colors[k])
        eces[k] = _ece(y, s, n_bins=10)
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect')
    ax.set_xlim(-0.01, 1.01)
    ax.set_ylim(-0.01, 1.01)
    ax.set_xlabel("Mean Predicted Probability", fontsize=14)
    ax.set_ylabel("Fraction of Positives", fontsize=14)
    ax.set_title(f"Pathological Subtype Classification Calibration Curves\n{title_suffix}", fontsize=14)
    ax.legend(loc="lower right", fontsize=12)
    ax.grid(alpha=0.3)
    row_labels = [
        f"Train (n={len(data['Train'][0])})",
        f"Int.Valid (n={len(data['Int.Valid'][0])})",
        f"Int.Test (n={len(data['Int.Test'][0])})",
        f"Ext.Test (n={len(data['Ext.Test'][0])})",
    ]
    col_labels = ["ECE"]
    table_data = [
        [f"{eces['Train']:.3f}"],
        [f"{eces['Int.Valid']:.3f}"],
        [f"{eces['Int.Test']:.3f}"],
        [f"{eces['Ext.Test']:.3f}"],
    ]
    _add_table(ax, table_data, row_labels, col_labels, colors=pr_row_colors,
               bbox=(0.30, -0.55, 0.65, 0.30), fontsize=12, rowlabel_width=0.40)
    plt.subplots_adjust(bottom=0.42)
    plt.savefig(os.path.join(fig_dir, "Figure4c_subtype_Calibration.png"), dpi=600, bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, "Figure4c_subtype_Calibration.pdf"), dpi=600, bbox_inches="tight")
    plt.close()
    print("✔ Subtype (binary) figures generated.")
# ============================================================
# TNM (multiclass OVR) plots: ROC / PR / Calibration + tables
# ============================================================
def plot_tnm_multiclass(result_dir="./results", fig_dir="./figures"):
    """Render per-class one-vs-rest ROC, PR and calibration figures for TNM stage.

    Reads ``tnm_{train,val,test}_{labels,scores}.npy`` from *result_dir*. If
    any of them is missing, prints a skip message and returns. ``tnm_test2_*``
    is used as "Ext.Test" when both files exist. Otherwise Ext.Test is
    simulated from Int.Test via _maybe_sim_ext (noise=0.05, seed=9).

    For each class ('Stage I-II', 'Stage III', 'Stage IV') it saves to
    *fig_dir*, as PNG and PDF at 600 dpi:
      - Figure5a1_<class>: ROC per split plus a Sample Count / AUC /
        Accuracy table.
      - Figure5a2_<class>: PR per split plus an AP CV / Specificity / NPV /
        Average Precision table.
      - Figure5a3_<class>: calibration per split plus an ECE table.

    NOTE(review): labels are assumed to be integer codes 0/1/2, and scores an
    (n_samples, 3) array whose column i is the score for class i. Confirm
    against the code that saves them.
    NOTE(review): a simulated Ext.Test is labelled like a real external
    cohort, and no message is printed.
    """
    _ensure_dir(fig_dir)
    req = [
        "tnm_train_labels.npy", "tnm_train_scores.npy",
        "tnm_val_labels.npy", "tnm_val_scores.npy",
        "tnm_test_labels.npy", "tnm_test_scores.npy",
    ]
    for f in req:
        if not _exists(os.path.join(result_dir, f)):
            print(f"[plot_tnm_multiclass] Skip: missing {os.path.join(result_dir, f)}")
            return
    train_y = np.load(os.path.join(result_dir, "tnm_train_labels.npy")).astype(int)
    train_s = np.load(os.path.join(result_dir, "tnm_train_scores.npy")).astype(float)
    val_y = np.load(os.path.join(result_dir, "tnm_val_labels.npy")).astype(int)
    val_s = np.load(os.path.join(result_dir, "tnm_val_scores.npy")).astype(float)
    test_y = np.load(os.path.join(result_dir, "tnm_test_labels.npy")).astype(int)
    test_s = np.load(os.path.join(result_dir, "tnm_test_scores.npy")).astype(float)

    # External split: real if provided, otherwise simulated from Int.Test.
    test2_lp = os.path.join(result_dir, "tnm_test2_labels.npy")
    test2_sp = os.path.join(result_dir, "tnm_test2_scores.npy")
    test2_y = _load_npy(test2_lp)
    test2_s = _load_npy(test2_sp)
    if test2_y is None or test2_s is None:
        test2_y, test2_s = _maybe_sim_ext(test_y, test_s, noise=0.05, seed=9)
    test2_y = test2_y.astype(int)
    test2_s = test2_s.astype(float)

    classes = [0, 1, 2]
    names = ['Stage I-II', 'Stage III', 'Stage IV']
    colors = ['#0074B7', '#60A3D9', '#6CC4DC']
    # split -> (one-hot labels, scores, integer labels).
    # `classes` is keyword-only in sklearn's label_binarize; passing it
    # positionally raises TypeError on current scikit-learn.
    bins = {
        "Train": (label_binarize(train_y, classes=classes), train_s, train_y),
        "Int.Valid": (label_binarize(val_y, classes=classes), val_s, val_y),
        "Int.Test": (label_binarize(test_y, classes=classes), test_s, test_y),
        "Ext.Test": (label_binarize(test2_y, classes=classes), test2_s, test2_y),
    }
    row_labels_base = ["Train", "Int.Valid", "Int.Test", "Ext.Test"]
    # One line style per split; within a figure the class color is constant.
    styles = {"Train": "-", "Int.Valid": "--", "Int.Test": ":", "Ext.Test": "-."}

    # ---------- Figure 5a1: ROC per class + table ----------
    for i, cname in enumerate(names):
        fig, ax = plt.subplots(figsize=(5, 6), facecolor="white")
        ax.set_facecolor("white")
        aucs = {}
        fprs = {}
        tprs = {}
        sample_counts = {}
        accs = {}
        for key, (yb, ys, ylab) in bins.items():
            fpr, tpr, auc_i = _calc_ovr_auc(yb, ys)[i]
            fprs[key], tprs[key], aucs[key] = fpr, tpr, float(auc_i)
            # Number of samples whose true label is class i.
            sample_counts[key] = str(int((ylab == i).sum()))
            accs[key] = _acc_ovr(yb[:, i], ys[:, i], thresh=0.5)
        for key in row_labels_base:
            ax.plot(fprs[key], tprs[key], linestyle=styles[key],
                    label=f"{key} (AUC = {aucs[key]:.2f})",
                    color=colors[i], linewidth=2.5)
        ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
        ax.set_xlim([-0.01, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xticks(np.linspace(0, 1, 6))
        ax.set_yticks(np.linspace(0, 1, 6))
        ax.set_xlabel('False Positive Rate', fontsize=13)
        ax.set_ylabel('True Positive Rate', fontsize=13)
        ax.set_title(f'TNM stage Classification ROC Curve \nfor {cname}', fontsize=14)
        ax.legend(loc="lower right", fontsize=11)
        ax.grid(alpha=0.3)
        col_labels = ["Sample Count", "AUC", "Accuracy"]
        table_data = [
            [sample_counts["Train"], f"{aucs['Train']:.2f}", f"{accs['Train']:.3f}"],
            [sample_counts["Int.Valid"], f"{aucs['Int.Valid']:.2f}", f"{accs['Int.Valid']:.3f}"],
            [sample_counts["Int.Test"], f"{aucs['Int.Test']:.2f}", f"{accs['Int.Test']:.3f}"],
            [sample_counts["Ext.Test"], f"{aucs['Ext.Test']:.2f}", f"{accs['Ext.Test']:.3f}"],
        ]
        _add_table(ax, table_data, row_labels_base, col_labels, colors=[colors[i]] * 4,
                   bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)
        plt.subplots_adjust(bottom=0.38)
        safe_name = cname.replace(" ", "_").replace("-", "_")
        plt.savefig(os.path.join(fig_dir, f"Figure5a1_{safe_name}.png"), dpi=600, bbox_inches="tight")
        plt.savefig(os.path.join(fig_dir, f"Figure5a1_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
        plt.close()

    # ---------- Figure 5a2: PR per class + table ----------
    colors_pr = ['#7F8FA3', '#FFA0A3', '#77DDF9']  # one color per class
    for i, cname in enumerate(names):
        fig, ax = plt.subplots(figsize=(5, 6.5), facecolor="white")
        ax.set_facecolor("white")
        pr = {}
        for key, (yb, ys, ylab) in bins.items():
            p, r, ap = _calc_ovr_pr(yb, ys)[i]
            spec, npv = _spec_npv_binary(yb[:, i], ys[:, i], thresh=0.5)
            pr[key] = dict(p=p, r=r, ap=float(ap), spec=spec, npv=npv)
        # AP coefficient of variation (std / mean) across splits, per class.
        ap_vals = np.array([pr[k]["ap"] for k in row_labels_base], dtype=float)
        ap_cv = float(np.std(ap_vals) / np.mean(ap_vals)) if np.mean(ap_vals) > 0 else 0.0
        c_use = colors_pr[i]
        for key in row_labels_base:
            ax.plot(pr[key]["r"], pr[key]["p"], linestyle=styles[key],
                    label=f"{key} (AP={pr[key]['ap']:.2f})",
                    color=c_use, linewidth=2.5)
        ax.set_xlim([-0.01, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xticks(np.linspace(0, 1, 6))
        ax.set_yticks(np.linspace(0, 1, 6))
        ax.set_xlabel('Recall', fontsize=14)
        ax.set_ylabel('Precision', fontsize=14)
        ax.set_title(f'TNM stage Classification Precision-Recall Curve \nfor {cname}', fontsize=14)
        ax.legend(loc="lower left", fontsize=12)
        ax.grid(alpha=0.3)
        col_labels = ["AP CV", "Specificity", "NPV", "Average Precision"]
        table_data = [
            [f"{ap_cv:.2f}", f"{pr['Train']['spec']:.2f}", f"{pr['Train']['npv']:.2f}", f"{pr['Train']['ap']:.2f}"],
            [f"{ap_cv:.2f}", f"{pr['Int.Valid']['spec']:.2f}", f"{pr['Int.Valid']['npv']:.2f}", f"{pr['Int.Valid']['ap']:.2f}"],
            [f"{ap_cv:.2f}", f"{pr['Int.Test']['spec']:.2f}", f"{pr['Int.Test']['npv']:.2f}", f"{pr['Int.Test']['ap']:.2f}"],
            [f"{ap_cv:.2f}", f"{pr['Ext.Test']['spec']:.2f}", f"{pr['Ext.Test']['npv']:.2f}", f"{pr['Ext.Test']['ap']:.2f}"],
        ]
        _add_table(ax, table_data, row_labels_base, col_labels, colors=[c_use] * 4,
                   bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)
        plt.subplots_adjust(bottom=0.40)
        safe_name = cname.replace(" ", "_").replace("-", "_")
        plt.savefig(os.path.join(fig_dir, f"Figure5a2_{safe_name}.png"), dpi=600, bbox_inches="tight")
        plt.savefig(os.path.join(fig_dir, f"Figure5a2_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
        plt.close()

    # ---------- Figure 5a3: Calibration per class + table (ECE) ----------
    calib_cols = ["#0074B7", "#60A3D9", "#6CC4DC", "#22a2c3"]  # one color per split
    for i, cname in enumerate(names):
        fig, ax = plt.subplots(figsize=(5, 6.3), facecolor="white")
        ax.set_facecolor("white")
        eces = {}
        for (key, (yb, ys, _)), c in zip(bins.items(), calib_cols):
            pt, pp = calibration_curve(yb[:, i], ys[:, i], n_bins=10, strategy="uniform")
            ax.plot(pp, pt, marker='o', label=key, color=c)
            eces[key] = _ece(yb[:, i], ys[:, i], n_bins=10)
        ax.plot([0, 1], [0, 1], 'k--', label='Perfectly Calibrated')
        ax.set_xlim(-0.01, 1.01)
        ax.set_ylim(-0.01, 1.01)
        ax.set_xlabel('Mean Predicted Probability', fontsize=13)
        ax.set_ylabel('Fraction of Positives', fontsize=13)
        ax.set_title(f'TNM stage Classification Calibration Curve \nfor {cname}', fontsize=14)
        ax.legend(loc='upper left', fontsize=11)
        ax.grid(alpha=0.3)
        col_labels = ["ECE"]
        table_data = [
            [f"{eces['Train']:.3f}"],
            [f"{eces['Int.Valid']:.3f}"],
            [f"{eces['Int.Test']:.3f}"],
            [f"{eces['Ext.Test']:.3f}"],
        ]
        _add_table(ax, table_data, row_labels_base, col_labels, colors=calib_cols,
                   bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)
        plt.subplots_adjust(bottom=0.38)
        safe_name = cname.replace(" ", "_").replace("-", "_")
        plt.savefig(os.path.join(fig_dir, f"Figure5a3_{safe_name}.png"), dpi=600, bbox_inches="tight")
        plt.savefig(os.path.join(fig_dir, f"Figure5a3_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
        plt.close()
    print("✔ TNM multiclass figures generated.")
# ============================================================
# Survival plots (DFS/OS): KM + Cox HR + log-rank + at-risk text
# ============================================================
def _evaluate_survival(df, time_point=30):
    """Score how well the Low/Mediate/High grouping predicts survival.

    Args:
        df: DataFrame with columns ``time``, ``event`` and ``group``. ``group``
            takes the values "Low", "Mediate" or "High". ``event`` is assumed
            to follow the lifelines convention (1 = event observed,
            0 = censored); TODO confirm.
        time_point: horizon, in the same units as ``time``, for the Brier
            score (default 30).

    Returns:
        ``(c_index, brier)`` as floats:
        - c_index: concordance of the ordinal risk (Low=0, Mediate=1, High=2)
          with survival time. The risk is negated because lifelines'
          ``concordance_index`` expects higher scores to mean longer survival.
        - brier: Brier score of P(survive beyond time_point), taken as
          ``1 - risk/2`` (Low=1.0, Mediate=0.5, High=0.0). Subjects censored
          at or before the horizon have unknown status there and are
          excluded, instead of being counted as "did not survive". The value
          is NaN when no subject is evaluable.
    """
    df = df.copy()
    df["risk_score"] = df["group"].map({"Low": 0, "Mediate": 1, "High": 2})
    c_index = concordance_index(df["time"], -df["risk_score"], df["event"])

    # Censored before the horizon => true status at time_point is unknown.
    known = ~((df["time"] <= time_point) & (df["event"] == 0))
    evaluable = df.loc[known]
    if evaluable.empty:
        return float(c_index), float("nan")
    y_true = (evaluable["time"] > time_point).astype(int)
    y_prob = 1 - evaluable["risk_score"] / 2.0
    brier = brier_score_loss(y_true, y_prob)
    return float(c_index), float(brier)
def _plot_km_with_hr_and_atrisk(df, title, save_path, n_total=None):
    """Draw and save a three-group Kaplan-Meier figure with survival statistics.

    Args:
        df: DataFrame with columns ``time``, ``event`` and ``group``
            ("Low" / "Mediate" / "High").
        title: first title line; "(n=<n_total>)" is appended below it.
        save_path: output path without extension; ``.png`` and ``.pdf`` are
            written at 600 dpi.
        n_total: sample count shown in the title (defaults to ``len(df)``).

    The figure contains, in this order:
      - A KM curve with confidence band for each group present in *df*.
      - A "Number at risk" grid at t = 0, 10, ..., 60 (time units of ``time``).
      - The multivariate log-rank p-value, and the C-index and Brier score
        from ``_evaluate_survival``.
      - Hazard ratios from a Cox model with ``group_code`` (0/1/2) as a
        single linear covariate.
    """
    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(8, 6), facecolor="white")
    ax.set_facecolor("white")
    colors = {"Low": "#91c7ae", "Mediate": "#f7b977", "High": "#d87c7c"}
    groups = ["Low", "Mediate", "High"]
    # Legend text per data label ("Mediate" is displayed as "Medium").
    legend_names = {"Low": "Low", "Mediate": "Medium", "High": "High"}

    lines = {}
    at_risk_table = []
    times = np.arange(0, 70, 10)
    for g in groups:
        mask = (df["group"] == g)
        if mask.sum() == 0:
            at_risk_table.append([0 for _ in times])
            continue
        kmf.fit(df.loc[mask, "time"], event_observed=df.loc[mask, "event"], label=g)
        kmf.plot_survival_function(ci_show=True, linewidth=2, color=colors[g], ax=ax)
        lines[g] = ax.get_lines()[-1]
        at_risk_table.append([int(np.sum(df.loc[mask, "time"] >= t)) for t in times])

    # Only groups that were drawn get a legend entry. Absent groups would
    # otherwise pass None handles, which matplotlib warns about and drops.
    present = [g for g in groups if g in lines]
    ax.legend([lines[g] for g in present], [legend_names[g] for g in present],
              title="Groups", loc="upper right", framealpha=0.5, fontsize=12, title_fontsize=12)

    # Number-at-risk grid below the x-axis (data coordinates, y < 0).
    for i, t in enumerate(times):
        n_low, n_mid, n_high = at_risk_table[0][i], at_risk_table[1][i], at_risk_table[2][i]
        ax.text(t, -0.38, str(n_low), color="#207f4c", fontsize=13, ha='center')
        ax.text(t, -0.48, str(n_mid), color="#fca106", fontsize=13, ha='center')
        ax.text(t, -0.58, str(n_high), color="#cc163a", fontsize=13, ha='center')
    ax.text(-1, -0.28, 'Number at risk', color='black', ha='center', fontsize=13)
    ax.text(-10, -0.38, "Low", color="#207f4c", fontsize=13)
    ax.text(-10, -0.48, "Medium", color="#fca106", fontsize=13)
    ax.text(-10, -0.58, "High", color="#cc163a", fontsize=13)

    # Cox model with an ordinal-trend covariate: HR(Mediate vs Low) = exp(b)
    # and HR(High vs Low) = exp(2b).
    dfx = df.copy()
    dfx["group_code"] = dfx["group"].map({"Low": 0, "Mediate": 1, "High": 2})
    cph = CoxPHFitter()
    cph.fit(dfx[["time", "event", "group_code"]], duration_col="time", event_col="event")
    coef = float(cph.params_["group_code"])
    se = float(cph.standard_errors_["group_code"])
    hr_med_vs_low = float(np.exp(coef))
    hr_high_vs_low = float(np.exp(2 * coef))
    # Wald tests: SE(b) = se and SE(2b) = 2*se. Both contrasts therefore
    # reduce to the same test of b = 0. The previous 2b/se doubled z and
    # understated the High-vs-Low p-value.
    z_med = coef / se
    p_med = float(2 * norm.sf(abs(z_med)))
    z_high = (2 * coef) / (2 * se)
    p_high = float(2 * norm.sf(abs(z_high)))

    # Global statistics.
    c_index, brier = _evaluate_survival(df)
    logrank_p = float(multivariate_logrank_test(df["time"], df["group"], df["event"]).p_value)
    ax.text(25, 0.46, f"P={logrank_p:.3f}", fontsize=12)
    ax.text(25, 0.36, f"C-index={c_index:.3f}", fontsize=12)
    ax.text(25, 0.26, f"Brier Score={brier:.3f}", fontsize=12)
    ax.text(25, 0.16, f"HR Intermediate vs Low = {hr_med_vs_low:.2f}, P={p_med:.3f}", fontsize=12)
    ax.text(25, 0.06, f"HR High vs Low = {hr_high_vs_low:.2f}, P={p_high:.3f}", fontsize=12)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    if n_total is None:
        n_total = len(df)
    ax.set_title(f"{title}\n(n={n_total})", fontsize=14)
    ax.set_xlabel("Time since treatment start (months)", fontsize=13)
    ax.set_ylabel("Survival probability", fontsize=13)
    ax.set_ylim(0, 1.05)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig(save_path + ".png", dpi=600, bbox_inches="tight")
    fig.savefig(save_path + ".pdf", dpi=600, bbox_inches="tight")
    plt.close(fig)
def plot_survival(result_dir="./results", fig_dir="./figures"):
    """Render DFS and OS Kaplan-Meier figures for the train/val/test splits.

    For each split it looks for ``dfs_<split>.csv`` and then
    ``os_<split>.csv`` in *result_dir*. Each file found is plotted to
    ``<fig_dir>/DFS_<split>`` or ``<fig_dir>/OS_<split>`` (PNG + PDF); each
    missing file produces a skip message.
    """
    _ensure_dir(fig_dir)
    # (csv file prefix, output/message tag, long title name), in plot order.
    endpoints = (
        ("dfs", "DFS", "Disease-Free Survival (DFS)"),
        ("os", "OS", "Overall Survival (OS)"),
    )
    for split in ("train", "val", "test"):
        for file_prefix, tag, long_name in endpoints:
            csv_path = os.path.join(result_dir, f"{file_prefix}_{split}.csv")
            if not _exists(csv_path):
                print(f"[plot_survival] Skip {tag} {split}: missing {csv_path}")
                continue
            frame = pd.read_csv(csv_path)
            _plot_km_with_hr_and_atrisk(
                frame,
                title=f"{long_name} — Kaplan-Meier Curves ({split})",
                save_path=os.path.join(fig_dir, f"{tag}_{split}"),
                n_total=len(frame),
            )
    print("✔ DFS / OS KM figures generated (where available).")
# ============================================================
# Public entry: plot_all
# ============================================================
def plot_all(result_dir="./results", fig_dir="./figures",
             do_subtype=True, do_tnm=True, do_survival=True):
    """Run every enabled figure family: subtype, then TNM, then survival.

    Args:
        result_dir: directory containing the saved model outputs.
        fig_dir: directory that receives the figures (created if missing).
        do_subtype / do_tnm / do_survival: toggles for each figure family.
    """
    _ensure_dir(fig_dir)
    steps = (
        (do_subtype, plot_subtype_binary),
        (do_tnm, plot_tnm_multiclass),
        (do_survival, plot_survival),
    )
    for enabled, step in steps:
        if enabled:
            step(result_dir=result_dir, fig_dir=fig_dir)
# ============================================================
# CLI usage (optional)
# ============================================================
if __name__ == "__main__":
    # Script entry: render all available figures from ./results into ./figures.
    plot_all(result_dir="./results", fig_dir="./figures")