|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import numpy as np |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
|
|
|
from sklearn.preprocessing import label_binarize |
|
|
from sklearn.metrics import ( |
|
|
roc_curve, auc, |
|
|
precision_recall_curve, average_precision_score, |
|
|
confusion_matrix, |
|
|
brier_score_loss |
|
|
) |
|
|
from sklearn.calibration import calibration_curve |
|
|
|
|
|
from lifelines import KaplanMeierFitter, CoxPHFitter |
|
|
from lifelines.statistics import multivariate_logrank_test |
|
|
from lifelines.utils import concordance_index |
|
|
from scipy.stats import norm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _ensure_dir(path: str): |
|
|
os.makedirs(path, exist_ok=True) |
|
|
|
|
|
|
|
|
def _exists(path: str) -> bool: |
|
|
return os.path.exists(path) and os.path.isfile(path) |
|
|
|
|
|
|
|
|
def _load_npy(path: str): |
|
|
if not _exists(path): |
|
|
return None |
|
|
return np.load(path, allow_pickle=True) |
|
|
|
|
|
|
|
|
def _maybe_sim_ext(labels, scores, noise=0.03, seed=42): |
|
|
""" |
|
|
Simulate an external test split when not provided. |
|
|
Keeps labels same; adds small noise to scores then clips to [0,1]. |
|
|
""" |
|
|
rng = np.random.RandomState(seed) |
|
|
if scores is None: |
|
|
return None, None |
|
|
s = scores.copy() |
|
|
s = np.clip(s + rng.normal(0, noise, s.shape), 0.0, 1.0) |
|
|
return labels.copy(), s |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _calc_binary_roc(y_true, y_score):
    """Return (fpr, tpr, AUC, Brier score) for a binary classification task."""
    fpr, tpr, _thresholds = roc_curve(y_true, y_score)
    area = auc(fpr, tpr)
    return fpr, tpr, area, brier_score_loss(y_true, y_score)
|
|
|
|
|
|
|
|
def _calc_binary_pr(y_true, y_score):
    """Return (precision, recall, average precision) for a binary task."""
    precision, recall, _thresholds = precision_recall_curve(y_true, y_score)
    return precision, recall, average_precision_score(y_true, y_score)
|
|
|
|
|
|
|
|
def _spec_npv_binary(y_true, y_score, thresh=0.5):
    """Specificity and negative predictive value at a fixed threshold.

    Scores >= *thresh* are predicted positive.  Zero denominators fall
    back to 0.0.

    Bug fix: without ``labels=[0, 1]`` the confusion matrix collapses to
    1x1 when y_true/y_pred contain a single class, and ``.ravel()`` then
    yields fewer than four values, crashing the unpack.
    """
    y_pred = (y_score >= thresh).astype(int)
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    specificity = tn / (tn + fp) if (tn + fp) else 0.0
    npv = tn / (tn + fn) if (tn + fn) else 0.0
    return specificity, npv
|
|
|
|
|
|
|
|
def _ece(y_true, y_score, n_bins=10): |
|
|
bins = np.linspace(0.0, 1.0, n_bins + 1) |
|
|
binids = np.digitize(y_score, bins) - 1 |
|
|
ece = 0.0 |
|
|
for i in range(n_bins): |
|
|
m = binids == i |
|
|
if m.sum() > 0: |
|
|
prob_true = np.mean(y_true[m]) |
|
|
prob_pred = np.mean(y_score[m]) |
|
|
ece += (m.sum() / len(y_score)) * abs(prob_pred - prob_true) |
|
|
return float(ece) |
|
|
|
|
|
|
|
|
def _calc_ovr_auc(y_bin, y_score):
    """One-vs-rest ROC for multiclass. Returns dict: {class_i: (fpr,tpr,auc)}"""
    result = {}
    n_classes = y_bin.shape[1]
    for cls in range(n_classes):
        fpr, tpr, _ = roc_curve(y_bin[:, cls], y_score[:, cls])
        result[cls] = (fpr, tpr, auc(fpr, tpr))
    return result
|
|
|
|
|
|
|
|
def _calc_ovr_pr(y_bin, y_score):
    """One-vs-rest PR for multiclass. Returns dict: {class_i: (p,r,ap)}"""
    result = {}
    for cls in range(y_bin.shape[1]):
        truth, pred = y_bin[:, cls], y_score[:, cls]
        precision, recall, _ = precision_recall_curve(truth, pred)
        result[cls] = (precision, recall, average_precision_score(truth, pred))
    return result
|
|
|
|
|
|
|
|
def _acc_ovr(y_true_bin, y_score, thresh=0.5): |
|
|
y_pred = (y_score >= thresh).astype(int) |
|
|
return float((y_pred == y_true_bin).mean()) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _auto_col_widths(col_labels, bbox_w): |
|
|
lens = np.array([max(4, len(c)) for c in col_labels], dtype=float) |
|
|
ratio = lens / lens.sum() |
|
|
return bbox_w * ratio |
|
|
|
|
|
|
|
|
def _add_table(ax, table_data, row_labels, col_labels, colors=None,
               bbox=(0.05, -0.50, 0.95, 0.30),
               fontsize=13, rowlabel_width=0.18):
    """Attach a borderless metrics table below *ax*.

    Parameters
    ----------
    ax : matplotlib Axes the table is attached to.
    table_data : list of rows, each a list of cell strings.
    row_labels, col_labels : header strings.
    colors : optional list[str], length == len(row_labels); colour applied
        to each data row's cells and its row label.
    bbox : (x, y, w, h) in axes coordinates; a negative y places the table
        beneath the plot area.
    fontsize : text size inside the table.
    rowlabel_width : width reserved for the row-label column.

    Returns the created matplotlib Table.
    """
    # Bug fix: the table was created via plt.table(), which targets the
    # *current* axes and silently ignored the `ax` argument; use ax.table().
    tbl = ax.table(
        cellText=table_data,
        rowLabels=row_labels,
        colLabels=col_labels,
        cellLoc='center',
        rowLoc='left',
        colLoc='center',
        bbox=list(bbox),
    )
    tbl.auto_set_font_size(False)
    tbl.set_fontsize(fontsize)

    cells = tbl.get_celld()

    # Column widths proportional to header-label length; row 0 is the header.
    col_widths = _auto_col_widths(col_labels, bbox[2])
    for col in range(len(col_labels)):
        for row in range(len(row_labels) + 1):
            cells[(row, col)].set_width(col_widths[col])

    # Row-label cells live at column -1 in matplotlib's cell dict.
    for row in range(1, len(row_labels) + 1):
        if (row, -1) in cells:
            cells[(row, -1)].set_width(rowlabel_width)

    # Borderless look: zero linewidth on every cell.
    for cell in cells.values():
        cell.set_linewidth(0)

    # Optional per-row text colouring (data cells and the row label).
    if colors is not None:
        for r in range(1, len(row_labels) + 1):
            for c in range(len(col_labels)):
                if (r, c) in cells:
                    cells[(r, c)].get_text().set_color(colors[r - 1])
            if (r, -1) in cells:
                cells[(r, -1)].get_text().set_color(colors[r - 1])

    return tbl
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_subtype_binary(result_dir="./results", fig_dir="./figures",
                        title_suffix="(LUAD vs LUSC)"):
    """Generate the binary-subtype figures (Figure 4a/4b/4c).

    Reads per-split label/score ``.npy`` arrays from *result_dir* and writes
    ROC, precision-recall and calibration figures (PNG + PDF) to *fig_dir*.
    If the core train/val/test arrays are missing the function prints a
    message and returns without writing anything.  A missing external test
    split is simulated by jittering the internal test scores.
    """
    _ensure_dir(fig_dir)

    # Core splits: each entry maps split name -> (labels path, scores path).
    paths = {
        "Train": (os.path.join(result_dir, "subtype_train_labels.npy"),
                  os.path.join(result_dir, "subtype_train_scores.npy")),
        "Int.Valid": (os.path.join(result_dir, "subtype_val_labels.npy"),
                      os.path.join(result_dir, "subtype_val_scores.npy")),
        "Int.Test": (os.path.join(result_dir, "subtype_test_labels.npy"),
                     os.path.join(result_dir, "subtype_test_scores.npy")),
    }

    # Load all core splits; abort entirely if any file is missing.
    data = {}
    missing_core = False
    for k, (lp, sp) in paths.items():
        y = _load_npy(lp)
        s = _load_npy(sp)
        if y is None or s is None:
            print(f"[plot_subtype_binary] Skip: missing {lp} or {sp}")
            missing_core = True
            break
        data[k] = (y.astype(int), s.astype(float))

    if missing_core:
        return

    # External test split: use real files when present, otherwise simulate
    # one by adding noise to the internal test scores (labels unchanged).
    ext_lp = os.path.join(result_dir, "subtype_test2_labels.npy")
    ext_sp = os.path.join(result_dir, "subtype_test2_scores.npy")
    ext_y = _load_npy(ext_lp)
    ext_s = _load_npy(ext_sp)
    if ext_y is None or ext_s is None:
        ext_y, ext_s = _maybe_sim_ext(data["Int.Test"][0], data["Int.Test"][1], noise=0.04, seed=7)
    data["Ext.Test"] = (ext_y.astype(int), ext_s.astype(float))

    # Per-split colours for the ROC figure and its table rows.
    colors = {
        "Train": "#0074B7",
        "Int.Valid": "#60A3D9",
        "Int.Test": "#6CC4DC",
        "Ext.Test": "#61649f",
    }
    row_colors = [colors["Train"], colors["Int.Valid"], colors["Int.Test"], colors["Ext.Test"]]

    # ---------- Figure 4a: ROC curves + summary table ----------
    roc_items = {}
    for k, (y, s) in data.items():
        fpr, tpr, auc_k, brier_k = _calc_binary_roc(y, s)
        roc_items[k] = dict(fpr=fpr, tpr=tpr, auc=auc_k, brier=brier_k, y=y, s=s)

    # Coefficient of variation of the AUCs across splits (stability measure).
    auc_list = np.array([roc_items[k]["auc"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
    auc_cv = float(np.std(auc_list) / np.mean(auc_list)) if np.mean(auc_list) > 0 else 0.0

    fig, ax = plt.subplots(figsize=(5, 7), facecolor="white")
    ax.set_facecolor("white")

    for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
        ax.plot(roc_items[k]["fpr"], roc_items[k]["tpr"],
                label=f"{k} (AUC = {roc_items[k]['auc']:.2f})",
                color=colors[k], linewidth=3)

    # Chance diagonal.
    ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
    ax.set_xlim([-0.01, 1.0])
    ax.set_ylim([0.0, 1.05])
    ax.set_xticks(np.linspace(0, 1, 6))
    ax.set_yticks(np.linspace(0, 1, 6))
    ax.set_xlabel("False Positive Rate", fontsize=14)
    ax.set_ylabel("True Positive Rate", fontsize=14)
    ax.set_title(f"Pathological Subtype Classification ROC Curves\n{title_suffix}", fontsize=14)
    ax.legend(loc="lower right", fontsize=12)
    ax.grid(alpha=0.3)

    def _posneg(y):
        # Class balance string: "<negatives> vs <positives>".
        neg = int((y == 0).sum())
        pos = int((y == 1).sum())
        return f"{neg} vs {pos}"

    # NOTE(review): the same pooled auc_cv value is repeated in every row of
    # the "AUC CV" column — it is a cross-split statistic, not per-split.
    row_labels = ["Train", "Int.Valid", "Int.Test", "Ext.Test"]
    col_labels = ["Number", "AUC CV", "Brier Score"]
    table_data = [
        [_posneg(roc_items["Train"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Train']['brier']:.3f}"],
        [_posneg(roc_items["Int.Valid"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Int.Valid']['brier']:.3f}"],
        [_posneg(roc_items["Int.Test"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Int.Test']['brier']:.3f}"],
        [_posneg(roc_items["Ext.Test"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Ext.Test']['brier']:.3f}"],
    ]
    _add_table(ax, table_data, row_labels, col_labels, colors=row_colors,
               bbox=(0.05, -0.52, 0.98, 0.30), fontsize=12, rowlabel_width=0.20)

    plt.subplots_adjust(bottom=0.42)
    plt.savefig(os.path.join(fig_dir, "Figure4a_subtype_ROC.png"), dpi=600, bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, "Figure4a_subtype_ROC.pdf"), dpi=600, bbox_inches="tight")
    plt.close()

    # ---------- Figure 4b: precision-recall curves + summary table ----------
    pr_items = {}
    for k, (y, s) in data.items():
        p, r, ap = _calc_binary_pr(y, s)
        spec, npv = _spec_npv_binary(y, s, thresh=0.5)
        pr_items[k] = dict(p=p, r=r, ap=ap, spec=spec, npv=npv, y=y, s=s)

    # Coefficient of variation of average precision across splits.
    ap_vals = np.array([pr_items[k]["ap"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
    ap_cv = float(np.std(ap_vals) / np.mean(ap_vals)) if np.mean(ap_vals) > 0 else 0.0

    fig, ax = plt.subplots(figsize=(7, 5.3), facecolor="white")
    ax.set_facecolor("white")

    # The PR figure uses its own palette, distinct from the ROC colours.
    for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
        ax.plot(pr_items[k]["r"], pr_items[k]["p"],
                label=f"{k} (AP={pr_items[k]['ap']:.2f})",
                color={
                    "Train": "#7F8FA3",
                    "Int.Valid": "#FFA0A3",
                    "Int.Test": "#77DDF9",
                    "Ext.Test": "#61649f",
                }[k],
                linewidth=3)
        # Light area fill under each PR curve.
        ax.fill_between(pr_items[k]["r"], pr_items[k]["p"], step='post', alpha=0.1,
                        color={
                            "Train": "#7F8FA3",
                            "Int.Valid": "#FFA0A3",
                            "Int.Test": "#77DDF9",
                            "Ext.Test": "#61649f",
                        }[k])

    ax.set_xlim(-0.01, 1.01)
    ax.set_ylim(-0.01, 1.01)
    ax.set_xlabel("Recall", fontsize=14)
    ax.set_ylabel("Precision", fontsize=14)
    ax.set_title(f"Pathological Subtype Classification Precision-Recall Curves\n{title_suffix}", fontsize=14)
    ax.legend(loc="lower left", fontsize=12)
    ax.grid(alpha=0.3)

    row_labels = [
        f"Train (n={len(pr_items['Train']['y'])})",
        f"Int.Valid (n={len(pr_items['Int.Valid']['y'])})",
        f"Int.Test (n={len(pr_items['Int.Test']['y'])})",
        f"Ext.Test (n={len(pr_items['Ext.Test']['y'])})",
    ]
    col_labels = ["AP CV", "Specificity", "NPV", "Average Precision"]
    table_data = [
        [f"{ap_cv:.2f}", f"{pr_items['Train']['spec']:.2f}", f"{pr_items['Train']['npv']:.2f}", f"{pr_items['Train']['ap']:.2f}"],
        [f"{ap_cv:.2f}", f"{pr_items['Int.Valid']['spec']:.2f}", f"{pr_items['Int.Valid']['npv']:.2f}", f"{pr_items['Int.Valid']['ap']:.2f}"],
        [f"{ap_cv:.2f}", f"{pr_items['Int.Test']['spec']:.2f}", f"{pr_items['Int.Test']['npv']:.2f}", f"{pr_items['Int.Test']['ap']:.2f}"],
        [f"{ap_cv:.2f}", f"{pr_items['Ext.Test']['spec']:.2f}", f"{pr_items['Ext.Test']['npv']:.2f}", f"{pr_items['Ext.Test']['ap']:.2f}"],
    ]
    pr_row_colors = ["#7F8FA3", "#FFA0A3", "#77DDF9", "#61649f"]
    _add_table(ax, table_data, row_labels, col_labels, colors=pr_row_colors,
               bbox=(0.10, -0.55, 0.90, 0.30), fontsize=12, rowlabel_width=0.28)

    plt.subplots_adjust(bottom=0.45)
    plt.savefig(os.path.join(fig_dir, "Figure4b_subtype_PR.png"), dpi=600, bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, "Figure4b_subtype_PR.pdf"), dpi=600, bbox_inches="tight")
    plt.close()

    # ---------- Figure 4c: calibration curves + ECE table ----------
    fig, ax = plt.subplots(figsize=(5, 5.4), facecolor="white")
    ax.set_facecolor("white")

    calib_colors = {
        "Train": "#7F8FA3",
        "Int.Valid": "#FFA0A3",
        "Int.Test": "#77DDF9",
        "Ext.Test": "#61649f",
    }
    eces = {}
    for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
        y, s = data[k]
        prob_true, prob_pred = calibration_curve(y, s, n_bins=10)
        ax.plot(prob_pred, prob_true, marker='o', label=k, color=calib_colors[k])
        eces[k] = _ece(y, s, n_bins=10)

    # Identity line = perfect calibration.
    ax.plot([0, 1], [0, 1], 'k--', label='Perfect')
    ax.set_xlim(-0.01, 1.01)
    ax.set_ylim(-0.01, 1.01)
    ax.set_xlabel("Mean Predicted Probability", fontsize=14)
    ax.set_ylabel("Fraction of Positives", fontsize=14)
    ax.set_title(f"Pathological Subtype Classification Calibration Curves\n{title_suffix}", fontsize=14)
    ax.legend(loc="lower right", fontsize=12)
    ax.grid(alpha=0.3)

    row_labels = [
        f"Train (n={len(data['Train'][0])})",
        f"Int.Valid (n={len(data['Int.Valid'][0])})",
        f"Int.Test (n={len(data['Int.Test'][0])})",
        f"Ext.Test (n={len(data['Ext.Test'][0])})",
    ]
    col_labels = ["ECE"]
    table_data = [
        [f"{eces['Train']:.3f}"],
        [f"{eces['Int.Valid']:.3f}"],
        [f"{eces['Int.Test']:.3f}"],
        [f"{eces['Ext.Test']:.3f}"],
    ]
    # Reuses the PR palette defined above for the table text colours.
    _add_table(ax, table_data, row_labels, col_labels, colors=pr_row_colors,
               bbox=(0.30, -0.55, 0.65, 0.30), fontsize=12, rowlabel_width=0.40)

    plt.subplots_adjust(bottom=0.42)
    plt.savefig(os.path.join(fig_dir, "Figure4c_subtype_Calibration.png"), dpi=600, bbox_inches="tight")
    plt.savefig(os.path.join(fig_dir, "Figure4c_subtype_Calibration.pdf"), dpi=600, bbox_inches="tight")
    plt.close()

    print("✔ Subtype (binary) figures generated.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_tnm_multiclass(result_dir="./results", fig_dir="./figures"):
    """Generate one-vs-rest TNM-stage figures (Figure 5a1/5a2/5a3).

    For each of the three stage groups this writes a ROC, a precision-recall
    and a calibration figure (PNG + PDF) comparing the four splits.  Skips
    everything (with a message) when any required core ``.npy`` file is
    missing; a missing external test split is simulated from the internal
    test split.

    Bug fix: ``label_binarize`` takes ``classes`` as a keyword-only argument
    in modern scikit-learn, so the previous positional calls raised a
    TypeError at runtime.
    """
    _ensure_dir(fig_dir)

    # All six core arrays must be present, otherwise bail out.
    req = [
        "tnm_train_labels.npy", "tnm_train_scores.npy",
        "tnm_val_labels.npy", "tnm_val_scores.npy",
        "tnm_test_labels.npy", "tnm_test_scores.npy",
    ]
    for f in req:
        if not _exists(os.path.join(result_dir, f)):
            print(f"[plot_tnm_multiclass] Skip: missing {os.path.join(result_dir, f)}")
            return

    train_y = np.load(os.path.join(result_dir, "tnm_train_labels.npy")).astype(int)
    train_s = np.load(os.path.join(result_dir, "tnm_train_scores.npy")).astype(float)

    val_y = np.load(os.path.join(result_dir, "tnm_val_labels.npy")).astype(int)
    val_s = np.load(os.path.join(result_dir, "tnm_val_scores.npy")).astype(float)

    test_y = np.load(os.path.join(result_dir, "tnm_test_labels.npy")).astype(int)
    test_s = np.load(os.path.join(result_dir, "tnm_test_scores.npy")).astype(float)

    # External test split: real files when available, simulated otherwise.
    test2_lp = os.path.join(result_dir, "tnm_test2_labels.npy")
    test2_sp = os.path.join(result_dir, "tnm_test2_scores.npy")
    test2_y = _load_npy(test2_lp)
    test2_s = _load_npy(test2_sp)
    if test2_y is None or test2_s is None:
        test2_y, test2_s = _maybe_sim_ext(test_y, test_s, noise=0.05, seed=9)
    test2_y = test2_y.astype(int)
    test2_s = test2_s.astype(float)

    classes = [0, 1, 2]
    names = ['Stage I-II', 'Stage III', 'Stage IV']
    colors = ['#0074B7', '#60A3D9', '#6CC4DC']

    # split name -> (one-hot labels, score matrix, integer labels).
    # `classes` must be passed by keyword (keyword-only in sklearn >= 1.0).
    bins = {
        "Train": (label_binarize(train_y, classes=classes), train_s, train_y),
        "Int.Valid": (label_binarize(val_y, classes=classes), val_s, val_y),
        "Int.Test": (label_binarize(test_y, classes=classes), test_s, test_y),
        "Ext.Test": (label_binarize(test2_y, classes=classes), test2_s, test2_y),
    }
    row_labels_base = ["Train", "Int.Valid", "Int.Test", "Ext.Test"]

    # ---------- Figure 5a1: one-vs-rest ROC per stage ----------
    for i, cname in enumerate(names):
        fig, ax = plt.subplots(figsize=(5, 6), facecolor="white")
        ax.set_facecolor("white")

        aucs = {}
        fprs = {}
        tprs = {}
        sample_counts = {}
        accs = {}

        for key, (yb, ys, ylab) in bins.items():
            ovr = _calc_ovr_auc(yb, ys)
            fpr, tpr, auc_i = ovr[i]
            fprs[key], tprs[key], aucs[key] = fpr, tpr, float(auc_i)

            # Positive-class sample count and thresholded OvR accuracy.
            sample_counts[key] = str(int((ylab == i).sum()))
            accs[key] = _acc_ovr(yb[:, i], ys[:, i], thresh=0.5)

        # One colour per stage; line style distinguishes splits.
        styles = {"Train": "-", "Int.Valid": "--", "Int.Test": ":", "Ext.Test": "-."}
        for key in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
            ax.plot(fprs[key], tprs[key], linestyle=styles[key],
                    label=f"{key} (AUC = {aucs[key]:.2f})",
                    color=colors[i], linewidth=2.5)

        ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
        ax.set_xlim([-0.01, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xticks(np.linspace(0, 1, 6))
        ax.set_yticks(np.linspace(0, 1, 6))
        ax.set_xlabel('False Positive Rate', fontsize=13)
        ax.set_ylabel('True Positive Rate', fontsize=13)
        ax.set_title(f'TNM stage Classification ROC Curve \nfor {cname}', fontsize=14)
        ax.legend(loc="lower right", fontsize=11)
        ax.grid(alpha=0.3)

        col_labels = ["Sample Count", "AUC", "Accuracy"]
        table_data = [
            [sample_counts["Train"], f"{aucs['Train']:.2f}", f"{accs['Train']:.3f}"],
            [sample_counts["Int.Valid"], f"{aucs['Int.Valid']:.2f}", f"{accs['Int.Valid']:.3f}"],
            [sample_counts["Int.Test"], f"{aucs['Int.Test']:.2f}", f"{accs['Int.Test']:.3f}"],
            [sample_counts["Ext.Test"], f"{aucs['Ext.Test']:.2f}", f"{accs['Ext.Test']:.3f}"],
        ]
        _add_table(ax, table_data, row_labels_base, col_labels, colors=[colors[i]]*4,
                   bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)

        plt.subplots_adjust(bottom=0.38)
        safe_name = cname.replace(" ", "_").replace("-", "_")
        plt.savefig(os.path.join(fig_dir, f"Figure5a1_{safe_name}.png"), dpi=600, bbox_inches="tight")
        plt.savefig(os.path.join(fig_dir, f"Figure5a1_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
        plt.close()

    # ---------- Figure 5a2: one-vs-rest precision-recall per stage ----------
    for i, cname in enumerate(names):
        fig, ax = plt.subplots(figsize=(5, 6.5), facecolor="white")
        ax.set_facecolor("white")

        pr = {}
        for key, (yb, ys, ylab) in bins.items():
            p, r, ap = _calc_ovr_pr(yb, ys)[i]
            spec, npv = _spec_npv_binary(yb[:, i], ys[:, i], thresh=0.5)
            pr[key] = dict(p=p, r=r, ap=float(ap), spec=spec, npv=npv)

        # Coefficient of variation of AP across splits (repeated per row below).
        ap_vals = np.array([pr[k]["ap"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
        ap_cv = float(np.std(ap_vals) / np.mean(ap_vals)) if np.mean(ap_vals) > 0 else 0.0

        styles = {"Train": "-", "Int.Valid": "--", "Int.Test": ":", "Ext.Test": "-."}
        colors_pr = ['#7F8FA3', '#FFA0A3', '#77DDF9']
        c_use = colors_pr[i]

        for key in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
            ax.plot(pr[key]["r"], pr[key]["p"], linestyle=styles[key],
                    label=f"{key} (AP={pr[key]['ap']:.2f})",
                    color=c_use, linewidth=2.5)

        ax.set_xlim([-0.01, 1.0])
        ax.set_ylim([0.0, 1.05])
        ax.set_xticks(np.linspace(0, 1, 6))
        ax.set_yticks(np.linspace(0, 1, 6))
        ax.set_xlabel('Recall', fontsize=14)
        ax.set_ylabel('Precision', fontsize=14)
        ax.set_title(f'TNM stage Classification Precision-Recall Curve \nfor {cname}', fontsize=14)
        ax.legend(loc="lower left", fontsize=12)
        ax.grid(alpha=0.3)

        col_labels = ["AP CV", "Specificity", "NPV", "Average Precision"]
        table_data = [
            [f"{ap_cv:.2f}", f"{pr['Train']['spec']:.2f}", f"{pr['Train']['npv']:.2f}", f"{pr['Train']['ap']:.2f}"],
            [f"{ap_cv:.2f}", f"{pr['Int.Valid']['spec']:.2f}", f"{pr['Int.Valid']['npv']:.2f}", f"{pr['Int.Valid']['ap']:.2f}"],
            [f"{ap_cv:.2f}", f"{pr['Int.Test']['spec']:.2f}", f"{pr['Int.Test']['npv']:.2f}", f"{pr['Int.Test']['ap']:.2f}"],
            [f"{ap_cv:.2f}", f"{pr['Ext.Test']['spec']:.2f}", f"{pr['Ext.Test']['npv']:.2f}", f"{pr['Ext.Test']['ap']:.2f}"],
        ]
        _add_table(ax, table_data, row_labels_base, col_labels, colors=[c_use]*4,
                   bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)

        plt.subplots_adjust(bottom=0.40)
        safe_name = cname.replace(" ", "_").replace("-", "_")
        plt.savefig(os.path.join(fig_dir, f"Figure5a2_{safe_name}.png"), dpi=600, bbox_inches="tight")
        plt.savefig(os.path.join(fig_dir, f"Figure5a2_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
        plt.close()

    # ---------- Figure 5a3: one-vs-rest calibration per stage ----------
    for i, cname in enumerate(names):
        fig, ax = plt.subplots(figsize=(5, 6.3), facecolor="white")
        ax.set_facecolor("white")

        calib_cols = ["#0074B7", "#60A3D9", "#6CC4DC", "#22a2c3"]
        eces = {}

        for (key, (yb, ys, _)), c in zip(bins.items(), calib_cols):
            pt, pp = calibration_curve(yb[:, i], ys[:, i], n_bins=10, strategy="uniform")
            ax.plot(pp, pt, marker='o', label=key, color=c)
            eces[key] = _ece(yb[:, i], ys[:, i], n_bins=10)

        ax.plot([0, 1], [0, 1], 'k--', label='Perfectly Calibrated')
        ax.set_xlim(-0.01, 1.01)
        ax.set_ylim(-0.01, 1.01)
        ax.set_xlabel('Mean Predicted Probability', fontsize=13)
        ax.set_ylabel('Fraction of Positives', fontsize=13)
        ax.set_title(f'TNM stage Classification Calibration Curve \nfor {cname}', fontsize=14)
        ax.legend(loc='upper left', fontsize=11)
        ax.grid(alpha=0.3)

        col_labels = ["ECE"]
        table_data = [
            [f"{eces['Train']:.3f}"],
            [f"{eces['Int.Valid']:.3f}"],
            [f"{eces['Int.Test']:.3f}"],
            [f"{eces['Ext.Test']:.3f}"],
        ]
        _add_table(ax, table_data, row_labels_base, col_labels, colors=calib_cols,
                   bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)

        plt.subplots_adjust(bottom=0.38)
        safe_name = cname.replace(" ", "_").replace("-", "_")
        plt.savefig(os.path.join(fig_dir, f"Figure5a3_{safe_name}.png"), dpi=600, bbox_inches="tight")
        plt.savefig(os.path.join(fig_dir, f"Figure5a3_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
        plt.close()

    print("✔ TNM multiclass figures generated.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _evaluate_survival(df):
    """C-index and a 30-month Brier score for a risk-grouped survival frame.

    Expects columns ``time``, ``event`` and ``group`` (Low/Mediate/High).
    The ordinal group code serves as the risk score.
    """
    work = df.copy()
    work["risk_score"] = work["group"].map({"Low": 0, "Mediate": 1, "High": 2})
    # Higher risk should correlate with shorter survival, hence the negation.
    c_index = concordance_index(work["time"], -work["risk_score"], work["event"])
    horizon = 30
    survived = (work["time"] > horizon).astype(int)
    # Map risk 0/1/2 -> survival probability 1.0/0.5/0.0.
    survival_prob = 1 - work["risk_score"] / 2.0
    return float(c_index), float(brier_score_loss(survived, survival_prob))
|
|
|
|
|
|
|
|
def _plot_km_with_hr_and_atrisk(df, title, save_path, n_total=None):
    """Kaplan-Meier plot for three risk groups with hazard ratios and an
    at-risk table, saved as ``save_path + ".png"`` and ``".pdf"``.

    *df* must contain columns ``time``, ``event`` and ``group`` with group
    values "Low"/"Mediate"/"High".  Time units are months (x-axis 0-60 in
    steps of 10).
    """
    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(8, 6), facecolor="white")
    ax.set_facecolor("white")

    colors = {"Low": "#91c7ae", "Mediate": "#f7b977", "High": "#d87c7c"}
    groups = ["Low", "Mediate", "High"]

    # KM curve per group, remembering each curve's Line2D for the legend
    # and collecting the number-at-risk counts per time tick.
    lines = {}
    at_risk_table = []
    times = np.arange(0, 70, 10)

    for g in groups:
        m = (df["group"] == g)
        if m.sum() == 0:
            # NOTE(review): an empty group leaves lines[g] unset, so the
            # legend handle below becomes None — confirm inputs always
            # contain all three groups.
            at_risk_table.append([0 for _ in times])
            continue
        kmf.fit(df.loc[m, "time"], event_observed=df.loc[m, "event"], label=g)
        kmf.plot_survival_function(ci_show=True, linewidth=2, color=colors[g], ax=ax)
        # Relies on the KM step curve being the most recently added Line2D
        # on the axes — fragile if lifelines changes its drawing order.
        lines[g] = ax.get_lines()[-1]
        at_risk_table.append([int(np.sum(df.loc[m, "time"] >= t)) for t in times])

    # Legend relabels "Mediate" as "Medium" for display.
    handles = [lines.get("Low"), lines.get("Mediate"), lines.get("High")]
    labels = ["Low", "Medium", "High"]
    ax.legend(handles, labels, title="Groups", loc="upper right", framealpha=0.5, fontsize=12, title_fontsize=12)

    # Number-at-risk rows drawn in data coordinates below the axes
    # (negative y); colours match each group's curve family.
    for i, t in enumerate(times):
        l, m, h = at_risk_table[0][i], at_risk_table[1][i], at_risk_table[2][i]
        ax.text(t, -0.38, str(l), color="#207f4c", fontsize=13, ha='center')
        ax.text(t, -0.48, str(m), color="#fca106", fontsize=13, ha='center')
        ax.text(t, -0.58, str(h), color="#cc163a", fontsize=13, ha='center')

    ax.text(-1, -0.28, 'Number at risk', color='black', ha='center', fontsize=13)
    ax.text(-10, -0.38, "Low", color="#207f4c", fontsize=13)
    ax.text(-10, -0.48, "Medium", color="#fca106", fontsize=13)
    ax.text(-10, -0.58, "High", color="#cc163a", fontsize=13)

    # Cox model with a single ordinal covariate (0/1/2): this assumes a
    # log-linear trend across groups, so HR(High vs Low) = exp(2*coef).
    dfx = df.copy()
    dfx["group_code"] = dfx["group"].map({"Low": 0, "Mediate": 1, "High": 2})
    cph = CoxPHFitter()
    cph.fit(dfx[["time", "event", "group_code"]], duration_col="time", event_col="event")
    coef = float(cph.params_["group_code"])
    se = float(cph.standard_errors_["group_code"])

    hr_med_vs_low = float(np.exp(coef))
    hr_high_vs_low = float(np.exp(2 * coef))

    # Wald z-tests on the per-step and doubled coefficient.
    z_med = (coef) / se
    p_med = float(2 * (1 - norm.cdf(abs(z_med))))
    z_high = (2 * coef) / se
    p_high = float(2 * (1 - norm.cdf(abs(z_high))))

    # Global statistics: C-index + Brier (helper) and multivariate log-rank p.
    c_index, brier = _evaluate_survival(df)
    logrank_p = float(multivariate_logrank_test(df["time"], df["group"], df["event"]).p_value)

    # Annotation block inside the plot area (x=25 months).
    ax.text(25, 0.46, f"P={logrank_p:.3f}", fontsize=12)
    ax.text(25, 0.36, f"C-index={c_index:.3f}", fontsize=12)
    ax.text(25, 0.26, f"Brier Score={brier:.3f}", fontsize=12)
    ax.text(25, 0.16, f"HR Intermediate vs Low = {hr_med_vs_low:.2f}, P={p_med:.3f}", fontsize=12)
    ax.text(25, 0.06, f"HR High vs Low = {hr_high_vs_low:.2f}, P={p_high:.3f}", fontsize=12)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

    if n_total is None:
        n_total = len(df)

    ax.set_title(f"{title}\n(n={n_total})", fontsize=14)
    ax.set_xlabel("Time since treatment start (months)", fontsize=13)
    ax.set_ylabel("Survival probability", fontsize=13)
    ax.set_ylim(0, 1.05)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_path + ".png", dpi=600, bbox_inches="tight")
    plt.savefig(save_path + ".pdf", dpi=600, bbox_inches="tight")
    plt.close()
|
|
|
|
|
|
|
|
def plot_survival(result_dir="./results", fig_dir="./figures"):
    """Render DFS and OS Kaplan-Meier figures for every split whose CSV exists.

    Looks for ``dfs_<split>.csv`` / ``os_<split>.csv`` in *result_dir* for
    the train/val/test splits; missing files are skipped with a message.
    """
    _ensure_dir(fig_dir)

    # (file stem, human-readable title prefix, output-file prefix)
    endpoints = [
        ("dfs", "Disease-Free Survival (DFS)", "DFS"),
        ("os", "Overall Survival (OS)", "OS"),
    ]

    for split in ["train", "val", "test"]:
        for stem, long_name, prefix in endpoints:
            csv_path = os.path.join(result_dir, f"{stem}_{split}.csv")
            if not _exists(csv_path):
                print(f"[plot_survival] Skip {prefix} {split}: missing {csv_path}")
                continue
            df = pd.read_csv(csv_path)
            _plot_km_with_hr_and_atrisk(
                df,
                title=f"{long_name} — Kaplan-Meier Curves ({split})",
                save_path=os.path.join(fig_dir, f"{prefix}_{split}"),
                n_total=len(df))

    print("✔ DFS / OS KM figures generated (where available).")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_all(result_dir="./results", fig_dir="./figures",
             do_subtype=True, do_tnm=True, do_survival=True):
    """Run every enabled figure generator against *result_dir*.

    The three flags toggle the subtype, TNM and survival figure groups;
    generators run in that fixed order.
    """
    _ensure_dir(fig_dir)

    pipeline = (
        (do_subtype, plot_subtype_binary),
        (do_tnm, plot_tnm_multiclass),
        (do_survival, plot_survival),
    )
    for enabled, generate in pipeline:
        if enabled:
            generate(result_dir=result_dir, fig_dir=fig_dir)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: read arrays/CSVs from ./results and write all
    # figures to ./figures.
    plot_all("./results", "./figures")
|
|
|