|
|
import sys

# Hard-coded absolute path to the project checkout on the training machine.
# NOTE(review): this duplicates the PROJECT_ROOT logic further down — consider
# removing one of the two once the deployment layout is settled.
sys.path.insert(0, "/export/home/daifang/lunghospital/MM-DLS-master/MM-DLS-master")
|
|
|
|
|
import os |
|
|
import sys |
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.utils.data import DataLoader, random_split |
|
|
|
|
|
from sklearn.metrics import roc_auc_score, accuracy_score |
|
|
from sklearn.preprocessing import label_binarize |
|
|
|
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
from lifelines import KaplanMeierFitter, CoxPHFitter |
|
|
from lifelines.statistics import multivariate_logrank_test |
|
|
from lifelines.utils import concordance_index |
|
|
from sklearn.metrics import brier_score_loss |
|
|
from scipy.stats import norm |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Make the current working directory importable so `mm_dls.*` resolves when
# the script is launched from the project root.
PROJECT_ROOT = os.path.abspath(".")
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _import_modules():
    """
    Import the project-local model, dataset, and loss classes.

    Deferred into a function so that the sys.path patching above has already
    run by the time these imports are attempted.
    """
    from mm_dls.HierMM_DLS import HierMM_DLS
    from mm_dls.FakePatientDataset import FakePatientDataset
    from mm_dls.CoxphLoss import CoxPHLoss
    return HierMM_DLS, FakePatientDataset, CoxPHLoss


# Resolve the project classes once at module load time.
HierMM_DLS, FakePatientDataset, CoxPHLoss = _import_modules()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ---------------- training hyper-parameters ----------------
EPOCHS = 300          # maximum number of training epochs
PATIENCE = 8          # early-stopping patience (epochs without val improvement)
BATCH_SIZE = 4
LR = 1e-4
WEIGHT_DECAY = 1e-5

# ---------------- task head sizes ----------------
NUM_SUBTYPES = 2      # binary histologic-subtype classification head
NUM_TNM = 3           # 3-class TNM staging head

# ---------------- input geometry ----------------
N_SLICES = 30         # slices per patient volume
IMG_SIZE = 224        # square slice resolution

# ---------------- output locations ----------------
SAVE_DIR = "./results"
FIG_DIR = "./figures"
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(FIG_DIR, exist_ok=True)

# FIX: `assert` statements are stripped when Python runs with `-O`, which
# would let the script continue without CUDA; raise explicitly instead.
if not torch.cuda.is_available():
    raise RuntimeError("CUDA not available")
DEVICE = torch.device("cuda:1")
torch.cuda.set_device(DEVICE)
print("Using device:", DEVICE)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _sigmoid(x): |
|
|
return 1 / (1 + np.exp(-x)) |
|
|
|
|
|
def _ensure_numpy(x): |
|
|
if isinstance(x, torch.Tensor): |
|
|
return x.detach().cpu().numpy() |
|
|
return x |
|
|
|
|
|
def _risk_to_groups(risk, q=(1/3, 2/3), labels=("Low", "Mediate", "High")): |
|
|
""" |
|
|
Convert continuous risk into 3 groups by tertiles. |
|
|
""" |
|
|
r = np.asarray(risk).reshape(-1) |
|
|
t1, t2 = np.quantile(r, q[0]), np.quantile(r, q[1]) |
|
|
out = np.full(len(r), labels[1], dtype=object) |
|
|
out[r <= t1] = labels[0] |
|
|
out[r >= t2] = labels[2] |
|
|
return out |
|
|
|
|
|
def _evaluate_survival_metrics(time, event, risk, time_point=30):
    """
    Compute a C-index plus a Brier score at a single time horizon.

    `risk` follows the convention "higher value => earlier expected event",
    hence the sign flip when calling concordance_index (which ranks by
    predicted survival time).
    """
    time = np.ravel(np.asarray(time))
    event = np.ravel(np.asarray(event)).astype(int)
    risk = np.ravel(np.asarray(risk))

    cidx = concordance_index(time, -risk, event)

    # Target = "survived past time_point"; the inverted min-max-scaled risk
    # serves as a pseudo survival probability.
    # NOTE(review): this Brier computation ignores censoring — an
    # IPCW-weighted variant would be more rigorous; confirm acceptability.
    survived_past = (time > time_point).astype(int)
    span = risk.max() - risk.min() + 1e-8
    pseudo_surv = 1 - (risk - risk.min()) / span
    brier = brier_score_loss(survived_past, pseudo_surv)

    return float(cidx), float(brier)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_epoch_verbose(model, loader, optimizer, device, train=True):
    """
    Run one full pass over `loader`, optionally training, and gather outputs.

    Each batch must be a 19-tuple:
        (pid, lesion, space, rad, pet, cli,
         y_sub, y_tnm,
         dfs_time, dfs_event, os_time, os_event,
         dfs_1y, dfs_3y, dfs_5y, os_1y, os_3y, os_5y,
         treatment)

    Parameters
    ----------
    model : network whose forward(lesion, space, rad, pet, cli) returns
        (subtype_logits, tnm_logits, dfs_risk, os_risk,
         dfs_logits_1y3y5y, os_logits_1y3y5y).
    loader : DataLoader yielding the 19-item batches above.
    optimizer : stepped only when train=True (still required for eval calls).
    device : torch device all tensors are moved to.
    train : toggles model mode, gradient tracking, and optimizer stepping.

    Returns
    -------
    tuple of (mean_loss, then numpy arrays):
        subtype labels/scores, TNM labels/scores, treatment codes,
        DFS risk/time/event, OS risk/time/event,
        DFS 1/3/5-year logits, OS 1/3/5-year logits.
    """
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss(reduction="none")
    cox = CoxPHLoss()

    # FIX: explicit branch instead of the conditional-expression trick
    # `model.train() if train else model.eval()`, which hides a side effect
    # inside a discarded expression.
    if train:
        model.train()
    else:
        model.eval()

    losses = []

    # Classification accumulators (labels / scores / treatment codes).
    sub_y_all, sub_s_all = [], []
    tnm_y_all, tnm_s_all = [], []
    treat_all = []

    # Survival accumulators: risk / time / event for DFS and OS.
    dfs_r_all, dfs_t_all, dfs_e_all = [], [], []
    os_r_all, os_t_all, os_e_all = [], [], []

    # Fixed-horizon (1/3/5-year) logit accumulators.
    dfs_log_all, os_log_all = [], []

    for batch in loader:
        # Fail fast on datasets that predate the 19th (treatment) field.
        if len(batch) != 19:
            raise ValueError(f"Batch length mismatch: expected 19, got {len(batch)}. "
                             f"Please ensure Dataset __getitem__ returns treatment as the 19th item.")

        (
            pid, lesion, space, rad, pet, cli,
            y_sub, y_tnm,
            dfs_t, dfs_e,
            os_t, os_e,
            dfs1, dfs3, dfs5,
            os1, os3, os5,
            treatment
        ) = batch

        lesion, space = lesion.to(device), space.to(device)
        rad, pet, cli = rad.to(device), pet.to(device), cli.to(device)
        y_sub, y_tnm = y_sub.to(device), y_tnm.to(device)
        dfs_t, dfs_e = dfs_t.to(device), dfs_e.to(device)
        os_t, os_e = os_t.to(device), os_e.to(device)
        treatment = treatment.to(device)

        # Stack the per-horizon binary targets into shape (batch, 3).
        dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(device)
        os_y = torch.stack([os1, os3, os5], dim=1).to(device)

        with torch.set_grad_enabled(train):
            sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(
                lesion, space, rad, pet, cli
            )

            # Multi-task objective: two CE heads + two Cox heads + two
            # per-horizon BCE heads, summed with equal weights.
            loss = (
                ce(sub_l, y_sub) +
                ce(tnm_l, y_tnm) +
                cox(dfs_r, dfs_t, dfs_e) +
                cox(os_r, os_t, os_e) +
                bce(dfs_log, dfs_y).mean() +
                bce(os_log, os_y).mean()
            )

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        losses.append(loss.item())

        # Positive-class probability for the binary subtype head; the full
        # softmax distribution for the 3-class TNM head.
        sub_prob = torch.softmax(sub_l, dim=1)[:, 1]
        tnm_prob = torch.softmax(tnm_l, dim=1)

        sub_s_all.append(_ensure_numpy(sub_prob))
        sub_y_all.append(_ensure_numpy(y_sub))

        tnm_s_all.append(_ensure_numpy(tnm_prob))
        tnm_y_all.append(_ensure_numpy(y_tnm))

        treat_all.append(_ensure_numpy(treatment))

        dfs_r_all.append(_ensure_numpy(dfs_r))
        dfs_t_all.append(_ensure_numpy(dfs_t))
        dfs_e_all.append(_ensure_numpy(dfs_e))

        os_r_all.append(_ensure_numpy(os_r))
        os_t_all.append(_ensure_numpy(os_t))
        os_e_all.append(_ensure_numpy(os_e))

        dfs_log_all.append(_ensure_numpy(dfs_log))
        os_log_all.append(_ensure_numpy(os_log))

    return (
        float(np.mean(losses)),

        np.concatenate(sub_y_all),
        np.concatenate(sub_s_all),

        np.concatenate(tnm_y_all),
        np.concatenate(tnm_s_all),

        np.concatenate(treat_all),

        np.concatenate(dfs_r_all),
        np.concatenate(dfs_t_all),
        np.concatenate(dfs_e_all),

        np.concatenate(os_r_all),
        np.concatenate(os_t_all),
        np.concatenate(os_e_all),

        np.concatenate(dfs_log_all, axis=0),
        np.concatenate(os_log_all, axis=0),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def evaluate_by_treatment(sub_y, sub_s, tnm_y, tnm_s, treat,
                          dfs_r, dfs_t, dfs_e, os_r, os_t, os_e):
    """
    Compute classification and survival metrics overall and per treatment arm.

    Parameters
    ----------
    sub_y, sub_s : binary subtype labels and positive-class scores.
    tnm_y, tnm_s : 3-class TNM labels and (n, 3) softmax score matrix.
    treat : treatment codes — 0 is treated as "Immune", 1 as "Chemo".
    dfs_* / os_* : risk, time, and event arrays for DFS / OS.

    Returns
    -------
    dict mapping cohort name ("All"/"Immune"/"Chemo") -> metric dict.
    Cohorts with fewer than 10 samples are skipped entirely. AUCs that are
    undefined for a cohort (a class missing from the subset) are reported
    as NaN instead of letting sklearn raise mid-training.
    """
    results = {}

    cohorts = {
        "All": np.ones_like(treat, dtype=bool),
        "Immune": treat == 0,
        "Chemo": treat == 1,
    }

    for name, mask in cohorts.items():
        # Too few samples for meaningful metrics.
        if mask.sum() < 10:
            continue

        res = {}

        # Subtype head — FIX: roc_auc_score raises when the cohort contains a
        # single class (plausible for the small Immune/Chemo subsets); report
        # NaN in that case instead of crashing the epoch loop.
        if np.unique(sub_y[mask]).size < 2:
            res["Subtype_AUC"] = float("nan")
        else:
            res["Subtype_AUC"] = roc_auc_score(sub_y[mask], sub_s[mask])
        res["Subtype_ACC"] = accuracy_score(sub_y[mask], (sub_s[mask] > 0.5).astype(int))

        # TNM head — macro one-vs-rest AUC needs all three classes present.
        if np.unique(tnm_y[mask]).size < 3:
            res["TNM_AUC_macro"] = float("nan")
        else:
            tnm_bin = label_binarize(tnm_y[mask], classes=[0, 1, 2])
            res["TNM_AUC_macro"] = roc_auc_score(
                tnm_bin, tnm_s[mask], average="macro", multi_class="ovr"
            )
        res["TNM_ACC"] = accuracy_score(
            tnm_y[mask], np.argmax(tnm_s[mask], axis=1)
        )

        # Survival heads: C-index plus a Brier score at 30 months.
        dfs_c, dfs_b = _evaluate_survival_metrics(dfs_t[mask], dfs_e[mask], dfs_r[mask], time_point=30)
        os_c, os_b = _evaluate_survival_metrics(os_t[mask], os_e[mask], os_r[mask], time_point=30)

        res["DFS_C_index"] = dfs_c
        res["DFS_Brier_30m"] = dfs_b
        res["OS_C_index"] = os_c
        res["OS_Brier_30m"] = os_b

        results[name] = res

    return results
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def plot_km_curve_with_hr(df, title, save_prefix):
    """
    Plot three-group Kaplan-Meier curves with an at-risk table and Cox stats.

    Parameters
    ----------
    df : DataFrame with columns `time`, `event`, `group`, where `group`
        takes values "Low" / "Mediate" / "High".
    title : figure title.
    save_prefix : output path without extension; both .png and .pdf are
        written at 600 dpi.

    Returns
    -------
    save_prefix, after the figure files are written.
    """
    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(8, 6), facecolor="white")
    ax.set_facecolor("white")

    colors = {"Low": "#91c7ae", "Mediate": "#f7b977", "High": "#d87c7c"}
    groups = ["Low", "Mediate", "High"]
    # Display names differ from the internal codes ("Mediate" -> "Medium").
    display = {"Low": "Low", "Mediate": "Medium", "High": "High"}

    lines = {}
    at_risk_table = []
    times = np.arange(0, 70, 10)

    for g in groups:
        m = df["group"] == g
        if m.sum() == 0:
            continue

        kmf.fit(df.loc[m, "time"], event_observed=df.loc[m, "event"], label=g)
        kmf.plot_survival_function(
            ax=ax, ci_show=True, linewidth=2, color=colors[g], marker="+"
        )
        # The survival curve is the line most recently added to the axes.
        lines[g] = ax.get_lines()[-1]

        at_risk_table.append([np.sum(df.loc[m, "time"] >= t) for t in times])

    # FIX: labels used to be a positional prefix of ["Low", "Medium", "High"],
    # which mislabels the curves whenever an earlier group is empty (e.g. a
    # missing "Low" group left the "Mediate" curve labelled "Low"). Build the
    # labels from the groups actually plotted instead.
    handles = [lines[g] for g in groups if g in lines]
    labels = [display[g] for g in groups if g in lines]
    ax.legend(handles, labels, title="Groups", loc="upper right",
              frameon=True, framealpha=0.5, fontsize=12, title_fontsize=12)

    # At-risk counts below the axis — drawn only when all three groups exist.
    if len(at_risk_table) == 3:
        low, mid, high = at_risk_table
        for i, t in enumerate(times):
            ax.text(t, -0.38, str(low[i]), color="#207f4c", fontsize=14, ha="center")
            ax.text(t, -0.48, str(mid[i]), color="#fca106", fontsize=14, ha="center")
            ax.text(t, -0.58, str(high[i]), color="#cc163a", fontsize=14, ha="center")

        ax.text(-1, -0.28, "Number at risk", color="black", ha="center", fontsize=14)
        ax.text(-10, -0.38, "Low", color="#207f4c", fontsize=14)
        ax.text(-10, -0.48, "Medium", color="#fca106", fontsize=14)
        ax.text(-10, -0.58, "High", color="#cc163a", fontsize=14)

    # Cox PH on the ordinal group code (0/1/2): each step multiplies the
    # hazard by exp(coef), so High-vs-Low is exp(2 * coef).
    df2 = df.copy()
    df2["group_code"] = df2["group"].map({"Low": 0, "Mediate": 1, "High": 2})
    cph = CoxPHFitter()
    cph.fit(df2[["time", "event", "group_code"]], duration_col="time", event_col="event")

    coef = float(cph.params_["group_code"])
    se = float(cph.standard_errors_["group_code"])

    hr_med_vs_low = np.exp(coef * 1)
    hr_high_vs_low = np.exp(coef * 2)

    # Wald tests for each contrast against the Low group.
    z_med = (coef * 1) / se
    p_med = 2 * (1 - norm.cdf(abs(z_med)))

    # NOTE(review): Var(2*beta) = 4*Var(beta), so the z for the 2-step
    # contrast should arguably equal z_med (2*coef / 2*se); dividing by the
    # one-step SE inflates it. Kept as-is to preserve reported output —
    # confirm the intended statistic.
    z_high = (coef * 2) / se
    p_high = 2 * (1 - norm.cdf(abs(z_high)))

    res_lr = multivariate_logrank_test(df2["time"], df2["group"], df2["event"])

    c_index, brier = _evaluate_survival_metrics(df2["time"].values, df2["event"].values,
                                                df2["group_code"].values, time_point=30)

    ax.text(25, 0.46, f"P(log-rank)={res_lr.p_value:.3f}", fontsize=12)
    ax.text(25, 0.36, f"C-index={c_index:.3f}", fontsize=12)
    ax.text(25, 0.26, f"Brier(30m)={brier:.3f}", fontsize=12)
    ax.text(25, 0.16, f"HR Intermediate vs Low = {hr_med_vs_low:.2f}, P={p_med:.3f}", fontsize=12)
    ax.text(25, 0.06, f"HR High vs Low = {hr_high_vs_low:.2f}, P={p_high:.3f}", fontsize=12)

    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Time since treatment start (months)", fontsize=14)
    ax.set_ylabel("Survival probability", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(alpha=0.3)

    plt.tight_layout()
    plt.savefig(save_prefix + ".png", dpi=600, bbox_inches="tight")
    plt.savefig(save_prefix + ".pdf", dpi=600, bbox_inches="tight")
    plt.close()
    return save_prefix
|
|
|
|
|
|
|
|
def generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=("val", "test")):
    """
    Rebuild the KM + HR figures from the arrays saved during inference,
    separately for the Immune and Chemo cohorts of each requested split.
    """
    os.makedirs(fig_dir, exist_ok=True)

    def _load(stem):
        # All inference arrays live flat in result_dir as "<stem>.npy".
        return np.load(os.path.join(result_dir, stem + ".npy"))

    for split in which_split:
        trt = _load(f"treatment_{split}")

        dfs_r = _load(f"dfs_{split}_risk")
        dfs_t = _load(f"dfs_{split}_time")
        dfs_e = _load(f"dfs_{split}_event")

        os_r = _load(f"os_{split}_risk")
        os_t = _load(f"os_{split}_time")
        os_e = _load(f"os_{split}_event")

        for cohort_name, mask in {"Immune": trt == 0, "Chemo": trt == 1}.items():
            n = mask.sum()
            if n < 20:
                print(f"[Figure7] Skip {split}-{cohort_name}: too few samples ({n})")
                continue

            # Tertile risk groups for each endpoint within this cohort.
            df_dfs = pd.DataFrame({
                "time": dfs_t[mask],
                "event": dfs_e[mask].astype(int),
                "group": _risk_to_groups(dfs_r[mask]),
            })
            df_os = pd.DataFrame({
                "time": os_t[mask],
                "event": os_e[mask].astype(int),
                "group": _risk_to_groups(os_r[mask]),
            })

            # Persist the grouped tables alongside the raw arrays.
            df_dfs.to_csv(os.path.join(result_dir, f"dfs_{split}_{cohort_name}.csv"), index=False)
            df_os.to_csv(os.path.join(result_dir, f"os_{split}_{cohort_name}.csv"), index=False)

            plot_km_curve_with_hr(
                df_dfs,
                title=f"Disease-Free Survival (DFS) — Kaplan-Meier Curves\n{cohort_name} {split} set (n={n})",
                save_prefix=os.path.join(fig_dir, f"Figure7_DFS_{cohort_name}_{split}")
            )
            plot_km_curve_with_hr(
                df_os,
                title=f"Overall Survival (OS) — Kaplan-Meier Curves\n{cohort_name} {split} set (n={n})",
                save_prefix=os.path.join(fig_dir, f"Figure7_OS_{cohort_name}_{split}")
            )

    print("✔ Figure 7 generated (DFS/OS KM + HR) for Immune/Chemo.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """
    Full pipeline: build dataset and splits, train HierMM_DLS with early
    stopping on validation loss, run inference on every split with the best
    checkpoint, save per-sample arrays, then regenerate the KM + HR figures.
    """

    # Local import: the real dataset class is only needed once training runs.
    from mm_dls.PatientDataset import PatientDataset

    dataset = PatientDataset(
        data_root="/path/to/DATA_ROOT",
        clinical_csv="/path/to/clinical.csv",
        radiomics_npy="/path/to/radiomics.npy",
        pet_npy="/path/to/pet.npy",
        n_slices=N_SLICES,
        img_size=IMG_SIZE,
    )

    # 60/20/20 split; the remainder goes to test so sizes always sum up.
    # NOTE(review): random_split is unseeded here — splits differ per run.
    n_train = int(0.6 * len(dataset))
    n_val = int(0.2 * len(dataset))
    n_test = len(dataset) - n_train - n_val

    train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n_test])

    loaders = {
        "train": DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=4),
        "val": DataLoader(val_set, BATCH_SIZE, shuffle=False, num_workers=4),
        "test": DataLoader(test_set, BATCH_SIZE, shuffle=False, num_workers=4),
    }

    model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)

    best_val_loss = 1e9  # sentinel: any real validation loss beats this
    wait = 0             # epochs since the last validation improvement

    # ---------------- training loop with early stopping ----------------
    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch_verbose(model, loaders["train"], optimizer, DEVICE, train=True)
        va = run_epoch_verbose(model, loaders["val"], optimizer, DEVICE, train=False)

        tr_loss = tr[0]
        va_loss = va[0]

        # Unpack validation outputs (1/3/5-year logits are unused here).
        _, sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e, _, _ = va
        metrics = evaluate_by_treatment(sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e)

        print(f"\n[Epoch {epoch:03d}] Train Loss={tr_loss:.3f} | Val Loss={va_loss:.3f}")
        for k, v in metrics.items():
            print(
                f" {k:7s} | "
                f"Subtype AUC={v['Subtype_AUC']:.3f} | "
                f"TNM AUC={v['TNM_AUC_macro']:.3f} | "
                f"DFS C-index={v['DFS_C_index']:.3f} | "
                f"OS C-index={v['OS_C_index']:.3f}"
            )

        # Checkpoint on best validation loss; stop after PATIENCE stale epochs.
        if va_loss < best_val_loss:
            best_val_loss = va_loss
            wait = 0
            torch.save(model.state_dict(), os.path.join(SAVE_DIR, "best_model.pt"))
            print(" ✓ Best model updated")
        else:
            wait += 1
            if wait >= PATIENCE:
                print("\n⏹ Early stopping triggered")
                break

    # ------------- inference with the best checkpoint, all splits -------------
    print("\nRunning inference with best model...")
    model.load_state_dict(torch.load(os.path.join(SAVE_DIR, "best_model.pt"), map_location=DEVICE))

    for split in ["train", "val", "test"]:
        out = run_epoch_verbose(model, loaders[split], optimizer, DEVICE, train=False)
        (
            loss,
            sy, ss,
            ty, ts,
            trt,
            dfs_r, dfs_t, dfs_e,
            os_r, os_t, os_e,
            dfs_log, os_log
        ) = out

        # Classification heads and treatment codes.
        np.save(os.path.join(SAVE_DIR, f"subtype_{split}_labels.npy"), sy)
        np.save(os.path.join(SAVE_DIR, f"subtype_{split}_scores.npy"), ss)
        np.save(os.path.join(SAVE_DIR, f"tnm_{split}_labels.npy"), ty)
        np.save(os.path.join(SAVE_DIR, f"tnm_{split}_scores.npy"), ts)
        np.save(os.path.join(SAVE_DIR, f"treatment_{split}.npy"), trt)

        # DFS survival head.
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_risk.npy"), dfs_r)
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_time.npy"), dfs_t)
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_event.npy"), dfs_e)

        # OS survival head.
        np.save(os.path.join(SAVE_DIR, f"os_{split}_risk.npy"), os_r)
        np.save(os.path.join(SAVE_DIR, f"os_{split}_time.npy"), os_t)
        np.save(os.path.join(SAVE_DIR, f"os_{split}_event.npy"), os_e)

        # Fixed-horizon (1/3/5-year) logits.
        np.save(os.path.join(SAVE_DIR, f"dfs_{split}_logits_1y3y5y.npy"), dfs_log)
        np.save(os.path.join(SAVE_DIR, f"os_{split}_logits_1y3y5y.npy"), os_log)

        print(f"{split:5s} | loss={loss:.3f} | Immune={np.sum(trt==0)} Chemo={np.sum(trt==1)}")

    print("\n✓ Inference completed. Results saved.")

    # ------------- figures from the saved arrays -------------
    print("\nGenerating Figure (KM + HR) ...")
    generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=("val", "test"))
    print("✓ Figure done. Files saved under ./figures")
|
|
|
|
|
|
|
|
# Script entry point.
if __name__ == "__main__":
    main()
|
|
|
|
|
|