# MM-DLS / train_patient_model.py
# Uploaded by FangDai ("Upload 4 files", commit bf41494, verified)
import sys
sys.path.insert(0, "/export/home/daifang/lunghospital/MM-DLS-master/MM-DLS-master")
# main.py
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from sklearn.metrics import roc_auc_score, accuracy_score
from sklearn.preprocessing import label_binarize
import pandas as pd
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter, CoxPHFitter
from lifelines.statistics import multivariate_logrank_test
from lifelines.utils import concordance_index
from sklearn.metrics import brier_score_loss
from scipy.stats import norm
# =========================================================
# Project path (IMPORTANT for Jupyter / HPC)
# =========================================================
# NOTE: this is the current working directory, not the directory containing
# this file; it is put first on sys.path, presumably so the local ``mm_dls``
# package resolves when the script is launched from the project root.
PROJECT_ROOT = os.path.abspath(os.curdir)
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)
# =========================================================
# imports: mm_dls/
# =========================================================
def _import_modules():
    """Import the project-local classes and hand them back as a tuple.

    Returns:
        ``(HierMM_DLS, FakePatientDataset, CoxPHLoss)`` from the ``mm_dls``
        package. The imports live inside a function, presumably so they run
        only after ``sys.path`` has been adjusted earlier in the file.
    """
    from mm_dls.HierMM_DLS import HierMM_DLS as model_cls
    from mm_dls.FakePatientDataset import FakePatientDataset as dataset_cls
    from mm_dls.CoxphLoss import CoxPHLoss as loss_cls

    return model_cls, dataset_cls, loss_cls


HierMM_DLS, FakePatientDataset, CoxPHLoss = _import_modules()
# =========================
# Training configuration
# =========================
EPOCHS = 300          # upper bound on the number of training epochs
PATIENCE = 8          # epochs without validation-loss improvement before stopping
BATCH_SIZE = 4
LR = 1e-4             # Adam learning rate
WEIGHT_DECAY = 1e-5   # Adam weight decay

# =========================
# Task definition
# =========================
NUM_SUBTYPES = 2      # e.g., LUAD vs LUSC
NUM_TNM = 3           # Stage I–II / III / IV

# =========================
# Image settings
# =========================
N_SLICES = 30         # max slices per patient
IMG_SIZE = 224

# =========================
# Output locations (created as a side effect of importing this module)
# =========================
SAVE_DIR = "./results"
FIG_DIR = "./figures"
for _out_dir in (SAVE_DIR, FIG_DIR):
    os.makedirs(_out_dir, exist_ok=True)
# -------------------------
# GPU (prefer cuda:1; fall back to cuda:0 when only one GPU is visible)
# -------------------------
# Raised explicitly instead of via ``assert`` so the check survives ``python -O``;
# the exception type stays AssertionError for backward compatibility.
if not torch.cuda.is_available():
    raise AssertionError("CUDA not available")
# cuda:1 only exists when at least two devices are visible; selecting it on a
# single-GPU host would fail inside torch.cuda.set_device.
DEVICE = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
torch.cuda.set_device(DEVICE)
print("Using device:", DEVICE)
# =========================================================
# Core utils
# =========================================================
def _sigmoid(x):
    """Return the logistic function ``1 / (1 + exp(-x))`` of ``x``, element-wise.

    The value is computed from ``exp(-|x|)``, which always lies in (0, 1], so
    inputs of large magnitude cannot overflow ``np.exp`` (the naive formula
    emits a RuntimeWarning for very negative ``x``).

    Args:
        x: Scalar, sequence or NumPy array of real numbers.

    Returns:
        A NumPy scalar for scalar input, otherwise an array of the same shape.
    """
    values = np.asarray(x)
    if not np.issubdtype(values.dtype, np.floating):
        # Integer / boolean input: promote so the arithmetic is done in floats.
        values = values.astype(float)
    decay = np.exp(-np.abs(values))
    # x >= 0: 1 / (1 + e^-x);  x < 0: e^x / (1 + e^x). Both branches are finite.
    result = np.where(values >= 0, 1.0 / (1.0 + decay), decay / (1.0 + decay))
    # ``[()]`` turns a 0-d array back into a NumPy scalar; n-d arrays pass through.
    return result[()]
def _ensure_numpy(x):
    """Convert a torch tensor to a NumPy array; return anything else unchanged.

    Tensors are detached from the autograd graph and moved to the CPU before
    the conversion.
    """
    if not isinstance(x, torch.Tensor):
        return x
    return x.detach().cpu().numpy()
def _risk_to_groups(risk, q=(1/3, 2/3), labels=("Low", "Mediate", "High")):
    """Split continuous risk scores into three groups at two quantile cut-points.

    Args:
        risk: Array-like of risk scores; flattened to 1-D.
        q: Pair of quantile levels ``(lower, upper)``; the default gives tertiles.
        labels: Names for the (low, middle, high) groups.

    Returns:
        1-D object array of labels aligned with the flattened scores. Scores
        ``<=`` the lower cut get ``labels[0]``, scores ``>=`` the upper cut get
        ``labels[2]`` and everything else ``labels[1]``. When both cuts
        coincide, a score equal to them lands in the high group because that
        assignment runs last. An empty input yields an empty array.
    """
    scores = np.asarray(risk).reshape(-1)
    groups = np.full(scores.shape[0], labels[1], dtype=object)
    if scores.size == 0:
        # np.quantile is undefined on an empty sample; there is nothing to label.
        return groups
    lower_cut = np.quantile(scores, q[0])
    upper_cut = np.quantile(scores, q[1])
    groups[scores <= lower_cut] = labels[0]
    groups[scores >= upper_cut] = labels[2]
    return groups
def _evaluate_survival_metrics(time, event, risk, time_point=30):
    """Compute the concordance index and a proxy Brier score at ``time_point``.

    Args:
        time: Follow-up times (array-like, flattened to 1-D).
        event: Event indicators, cast to int (non-zero = event observed).
        risk: Predicted risk scores. Higher risk is taken to mean an earlier
            event, so ``-risk`` is what gets passed to ``concordance_index``.
        time_point: Horizon for the Brier score, in the same unit as ``time``.

    Returns:
        ``(c_index, brier)`` as Python floats. The C-index is NaN when the
        sample contains no comparable pairs; the Brier score is NaN when no
        subject has a known status at ``time_point``.

    Notes:
        The survival probability fed to the Brier score is only a proxy: a
        min-max rescaling of ``risk`` over the supplied sample, inverted so
        that higher risk maps to lower survival probability. Replace it with
        a proper survival model if calibrated probabilities are needed.
    """
    time = np.asarray(time).reshape(-1)
    event = np.asarray(event).reshape(-1).astype(int)
    risk = np.asarray(risk).reshape(-1)

    try:
        c_index = concordance_index(time, -risk, event)
    except ZeroDivisionError:
        # lifelines raises this when there are no admissible pairs
        # (e.g. every subject in the cohort is censored).
        c_index = float("nan")

    # 1 means the subject was still event-free after time_point.
    y_true = (time > time_point).astype(int)
    # Map risk into [0, 1]; the epsilon guards against a zero range.
    y_prob = 1 - (risk - risk.min()) / (risk.max() - risk.min() + 1e-8)

    # A subject censored at or before the horizon has an unknown status at
    # time_point; counting it as an event would bias the score, so only
    # subjects followed past the horizon or with an observed event are scored.
    status_known = (time > time_point) | (event != 0)
    if status_known.any():
        brier = brier_score_loss(y_true[status_known], y_prob[status_known])
    else:
        brier = float("nan")

    return float(c_index), float(brier)
# =========================================================
# One epoch (train / eval)
# =========================================================
def run_epoch_verbose(model, loader, optimizer, device, train=True):
    """Run one full pass over ``loader`` and collect the loss and predictions.

    Args:
        model: Network called as ``model(lesion, space, rad, pet, cli)`` that
            returns six tensors: subtype logits, TNM logits, DFS risk,
            OS risk, DFS horizon logits and OS horizon logits.
        loader: Iterable of 19-item batches whose last item is the treatment.
        optimizer: Stepped after every batch when ``train`` is true; unused
            otherwise.
        device: Device the batch tensors are moved to.
        train: Selects training mode (gradients and updates) or eval mode.

    Returns:
        A 14-tuple: the mean of the per-batch losses (float) followed by
        NumPy arrays concatenated over batches, in this order: subtype
        labels, subtype class-1 probabilities, TNM labels, TNM class
        probabilities, treatment, DFS risk / time / event, OS risk / time /
        event, DFS horizon logits, OS horizon logits.

    Raises:
        ValueError: If a batch does not contain exactly 19 items.
    """
    cross_entropy = nn.CrossEntropyLoss()
    horizon_bce = nn.BCEWithLogitsLoss(reduction="none")
    cox_loss = CoxPHLoss()

    if train:
        model.train()
    else:
        model.eval()

    # Names of the collected outputs, listed in the order they are returned.
    output_names = (
        "sub_y", "sub_s", "tnm_y", "tnm_s", "treat",
        "dfs_r", "dfs_t", "dfs_e",
        "os_r", "os_t", "os_e",
        "dfs_log", "os_log",
    )
    collected = {name: [] for name in output_names}
    batch_losses = []

    for batch in loader:
        # NOTE: dataset must return 19 items including treatment
        if len(batch) != 19:
            raise ValueError(f"Batch length mismatch: expected 19, got {len(batch)}. "
                             f"Please ensure Dataset __getitem__ returns treatment as the 19th item.")

        (_pid, lesion, space, rad, pet, cli,
         y_sub, y_tnm,
         dfs_t, dfs_e, os_t, os_e,
         dfs1, dfs3, dfs5, os1, os3, os5,
         treatment) = batch

        lesion, space, rad, pet, cli = (
            tensor.to(device) for tensor in (lesion, space, rad, pet, cli)
        )
        y_sub, y_tnm, dfs_t, dfs_e, os_t, os_e, treatment = (
            tensor.to(device)
            for tensor in (y_sub, y_tnm, dfs_t, dfs_e, os_t, os_e, treatment)
        )
        # Per-horizon targets stacked along dim 1 (three columns each).
        dfs_y = torch.stack((dfs1, dfs3, dfs5), dim=1).to(device)
        os_y = torch.stack((os1, os3, os5), dim=1).to(device)

        with torch.set_grad_enabled(train):
            sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(lesion, space, rad, pet, cli)

            # Unweighted sum: 2 classification + 2 Cox + 2 horizon-BCE terms.
            loss = cross_entropy(sub_l, y_sub)
            loss = loss + cross_entropy(tnm_l, y_tnm)
            loss = loss + cox_loss(dfs_r, dfs_t, dfs_e)
            loss = loss + cox_loss(os_r, os_t, os_e)
            loss = loss + horizon_bce(dfs_log, dfs_y).mean()
            loss = loss + horizon_bce(os_log, os_y).mean()

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

        batch_losses.append(loss.item())

        # ----- Collect predictions -----
        batch_outputs = {
            "sub_y": y_sub,
            "sub_s": torch.softmax(sub_l, dim=1)[:, 1],  # probability of class 1
            "tnm_y": y_tnm,
            "tnm_s": torch.softmax(tnm_l, dim=1),
            "treat": treatment,
            "dfs_r": dfs_r, "dfs_t": dfs_t, "dfs_e": dfs_e,
            "os_r": os_r, "os_t": os_t, "os_e": os_e,
            "dfs_log": dfs_log, "os_log": os_log,
        }
        for name in output_names:
            collected[name].append(_ensure_numpy(batch_outputs[name]))

    mean_loss = float(np.mean(batch_losses))
    return (
        mean_loss,
        *(np.concatenate(collected[name], axis=0) for name in output_names),
    )
# =========================================================
# Evaluation by cohort (classification + survival)
# =========================================================
def evaluate_by_treatment(sub_y, sub_s, tnm_y, tnm_s, treat,
                          dfs_r, dfs_t, dfs_e, os_r, os_t, os_e):
    """Compute classification and survival metrics for each treatment cohort.

    Cohorts are "All", "Immune" (``treat == 0``) and "Chemo" (``treat == 1``).
    Cohorts with fewer than 10 samples are skipped and do not appear in the
    result.

    Args:
        sub_y, sub_s: Subtype labels and class-1 probabilities.
        tnm_y, tnm_s: TNM labels (classes 0/1/2) and per-class probabilities.
        treat: Treatment code per sample.
        dfs_r, dfs_t, dfs_e: DFS risk, time and event arrays.
        os_r, os_t, os_e: OS risk, time and event arrays.

    Returns:
        Dict mapping cohort name to a dict with the keys ``Subtype_AUC``,
        ``Subtype_ACC``, ``TNM_AUC_macro``, ``TNM_ACC``, ``DFS_C_index``,
        ``DFS_Brier_30m``, ``OS_C_index`` and ``OS_Brier_30m``. An AUC is NaN
        when it is undefined for the cohort (only one subtype present, or a
        TNM class with no positives / no negatives) instead of raising.
    """
    # Flatten so the boolean masks index the 1-D per-sample arrays.
    treat = np.asarray(treat).reshape(-1)
    cohorts = {
        "All": np.ones_like(treat, dtype=bool),
        "Immune": treat == 0,
        "Chemo": treat == 1,
    }

    results = {}
    for name, mask in cohorts.items():
        if mask.sum() < 10:
            continue
        res = {}

        # Subtype (binary). ROC AUC needs both classes to be present.
        y_sub = sub_y[mask]
        if np.unique(y_sub).size < 2:
            res["Subtype_AUC"] = float("nan")
        else:
            res["Subtype_AUC"] = roc_auc_score(y_sub, sub_s[mask])
        res["Subtype_ACC"] = accuracy_score(y_sub, (sub_s[mask] > 0.5).astype(int))

        # TNM (multiclass macro AUC + ACC). Each one-vs-rest column must
        # contain both positives and negatives for its AUC to be defined.
        tnm_bin = label_binarize(tnm_y[mask], classes=[0, 1, 2])
        positives = tnm_bin.sum(axis=0)
        if np.any((positives == 0) | (positives == tnm_bin.shape[0])):
            res["TNM_AUC_macro"] = float("nan")
        else:
            res["TNM_AUC_macro"] = roc_auc_score(
                tnm_bin, tnm_s[mask], average="macro", multi_class="ovr"
            )
        res["TNM_ACC"] = accuracy_score(
            tnm_y[mask], np.argmax(tnm_s[mask], axis=1)
        )

        # Survival
        dfs_c, dfs_b = _evaluate_survival_metrics(dfs_t[mask], dfs_e[mask], dfs_r[mask], time_point=30)
        os_c, os_b = _evaluate_survival_metrics(os_t[mask], os_e[mask], os_r[mask], time_point=30)
        res["DFS_C_index"] = dfs_c
        res["DFS_Brier_30m"] = dfs_b
        res["OS_C_index"] = os_c
        res["OS_Brier_30m"] = os_b

        results[name] = res
    return results
# =========================================================
# Figure 7: KM + HR (per cohort, per endpoint)
# =========================================================
def plot_km_curve_with_hr(df, title, save_prefix):
    """Draw Kaplan-Meier curves per risk group with Cox HR / log-rank annotations.

    Args:
        df: DataFrame with the columns ``time``, ``event`` and ``group``
            (group values "Low" / "Mediate" / "High").
        title: Figure title.
        save_prefix: Output path without extension; ``<prefix>.png`` and
            ``<prefix>.pdf`` are written.

    Returns:
        ``save_prefix``, unchanged.
    """
    kmf = KaplanMeierFitter()
    fig, ax = plt.subplots(figsize=(8, 6), facecolor="white")
    ax.set_facecolor("white")
    colors = {"Low": "#91c7ae", "Mediate": "#f7b977", "High": "#d87c7c"}
    # Internal group key -> name shown in the legend.
    legend_names = {"Low": "Low", "Mediate": "Medium", "High": "High"}
    groups = ["Low", "Mediate", "High"]

    # plot KM
    handles, legend_labels = [], []
    at_risk_table = []
    times = np.arange(0, 70, 10)
    for g in groups:
        m = df["group"] == g
        if m.sum() == 0:
            continue
        kmf.fit(df.loc[m, "time"], event_observed=df.loc[m, "event"], label=g)
        kmf.plot_survival_function(
            ax=ax, ci_show=True, linewidth=2, color=colors[g], marker="+"
        )
        # The curve just drawn is taken to be the last line added to the axes.
        handles.append(ax.get_lines()[-1])
        # Pair each handle with its own label so a missing group cannot shift
        # the remaining legend entries onto the wrong curve.
        legend_labels.append(legend_names[g])
        at_risk_table.append([np.sum(df.loc[m, "time"] >= t) for t in times])

    # legend
    ax.legend(handles, legend_labels, title="Groups", loc="upper right",
              frameon=True, framealpha=0.5, fontsize=12, title_fontsize=12)

    # at risk numbers (only drawn when all three groups are present)
    if len(at_risk_table) == 3:
        low, mid, high = at_risk_table
        for i, t in enumerate(times):
            ax.text(t, -0.38, str(low[i]), color="#207f4c", fontsize=14, ha="center")
            ax.text(t, -0.48, str(mid[i]), color="#fca106", fontsize=14, ha="center")
            ax.text(t, -0.58, str(high[i]), color="#cc163a", fontsize=14, ha="center")
        ax.text(-1, -0.28, "Number at risk", color="black", ha="center", fontsize=14)
        ax.text(-10, -0.38, "Low", color="#207f4c", fontsize=14)
        ax.text(-10, -0.48, "Medium", color="#fca106", fontsize=14)
        ax.text(-10, -0.58, "High", color="#cc163a", fontsize=14)

    # Cox HR + p-values. The group enters the model as one numeric covariate
    # (0/1/2), so the hazard ratios are model-implied: exp(coef) for one step
    # and exp(2 * coef) for two steps.
    df2 = df.copy()
    df2["group_code"] = df2["group"].map({"Low": 0, "Mediate": 1, "High": 2})
    cph = CoxPHFitter()
    cph.fit(df2[["time", "event", "group_code"]], duration_col="time", event_col="event")
    coef = float(cph.params_["group_code"])
    se = float(cph.standard_errors_["group_code"])
    hr_med_vs_low = np.exp(coef * 1)
    hr_high_vs_low = np.exp(coef * 2)
    # Wald test: scaling the coefficient by k scales its standard error by k
    # as well, so both contrasts share the statistic z = coef / se (dividing
    # 2 * coef by the unscaled se would overstate significance).
    z = coef / se
    p_med = 2 * (1 - norm.cdf(abs(z)))
    p_high = p_med

    # logrank
    res_lr = multivariate_logrank_test(df2["time"], df2["group"], df2["event"])

    # C-index + brier (proxy), using the ordinal group code as the risk score
    c_index, brier = _evaluate_survival_metrics(df2["time"].values, df2["event"].values,
                                                df2["group_code"].values, time_point=30)

    ax.text(25, 0.46, f"P(log-rank)={res_lr.p_value:.3f}", fontsize=12)
    ax.text(25, 0.36, f"C-index={c_index:.3f}", fontsize=12)
    ax.text(25, 0.26, f"Brier(30m)={brier:.3f}", fontsize=12)
    ax.text(25, 0.16, f"HR Intermediate vs Low = {hr_med_vs_low:.2f}, P={p_med:.3f}", fontsize=12)
    ax.text(25, 0.06, f"HR High vs Low = {hr_high_vs_low:.2f}, P={p_high:.3f}", fontsize=12)

    # cosmetics
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)
    ax.set_title(title, fontsize=14)
    ax.set_xlabel("Time since treatment start (months)", fontsize=14)
    ax.set_ylabel("Survival probability", fontsize=14)
    ax.set_ylim(0, 1.05)
    ax.grid(alpha=0.3)

    fig.tight_layout()
    fig.savefig(save_prefix + ".png", dpi=600, bbox_inches="tight")
    fig.savefig(save_prefix + ".pdf", dpi=600, bbox_inches="tight")
    # Close this figure explicitly rather than whichever one is "current".
    plt.close(fig)
    return save_prefix
def generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=("val", "test")):
    """Build KM + HR figures per split and treatment cohort from saved arrays.

    For every split in ``which_split`` the treatment, risk, time and event
    ``.npy`` files are read from ``result_dir``. For the "Immune"
    (treatment == 0) and "Chemo" (treatment == 1) cohorts, risks are turned
    into three groups with ``_risk_to_groups``, the grouped DFS and OS tables
    are written as CSV into ``result_dir``, and one figure per endpoint is
    saved under ``fig_dir``. Cohorts with fewer than 20 samples are skipped
    with a message.
    """
    os.makedirs(fig_dir, exist_ok=True)

    def _load(filename):
        # Read one saved array from the results directory.
        return np.load(os.path.join(result_dir, filename))

    def _grouped_frame(risk, time, event, mask):
        # Table for one endpoint, restricted to one cohort.
        group = _risk_to_groups(risk[mask])
        return pd.DataFrame({
            "time": time[mask],
            "event": event[mask].astype(int),
            "group": group,
        })

    for split in which_split:
        trt = _load(f"treatment_{split}.npy")
        dfs_r = _load(f"dfs_{split}_risk.npy")
        dfs_t = _load(f"dfs_{split}_time.npy")
        dfs_e = _load(f"dfs_{split}_event.npy")
        os_r = _load(f"os_{split}_risk.npy")
        os_t = _load(f"os_{split}_time.npy")
        os_e = _load(f"os_{split}_event.npy")

        cohort_masks = {"Immune": trt == 0, "Chemo": trt == 1}
        for cohort_name, mask in cohort_masks.items():
            if mask.sum() < 20:
                print(f"[Figure7] Skip {split}-{cohort_name}: too few samples ({mask.sum()})")
                continue

            df_dfs = _grouped_frame(dfs_r, dfs_t, dfs_e, mask)
            df_os = _grouped_frame(os_r, os_t, os_e, mask)

            # Keep the grouped tables next to the raw arrays for reproducibility.
            df_dfs.to_csv(os.path.join(result_dir, f"dfs_{split}_{cohort_name}.csv"), index=False)
            df_os.to_csv(os.path.join(result_dir, f"os_{split}_{cohort_name}.csv"), index=False)

            plot_km_curve_with_hr(
                df_dfs,
                title=f"Disease-Free Survival (DFS) — Kaplan-Meier Curves\n{cohort_name} {split} set (n={mask.sum()})",
                save_prefix=os.path.join(fig_dir, f"Figure7_DFS_{cohort_name}_{split}"),
            )
            plot_km_curve_with_hr(
                df_os,
                title=f"Overall Survival (OS) — Kaplan-Meier Curves\n{cohort_name} {split} set (n={mask.sum()})",
                save_prefix=os.path.join(fig_dir, f"Figure7_OS_{cohort_name}_{split}"),
            )

    print("✔ Figure 7 generated (DFS/OS KM + HR) for Immune/Chemo.")
# =========================================================
# Main
# =========================================================
def main():
    """Train the model with early stopping, export predictions, draw figures.

    Steps:
        1. Build ``PatientDataset`` and split it 60 / 20 / 20 into
           train / val / test.
        2. Train for at most ``EPOCHS`` epochs, saving ``best_model.pt``
           whenever the validation loss improves and stopping after
           ``PATIENCE`` epochs without improvement.
        3. Reload the best checkpoint, run inference on all three splits and
           save labels, scores, risks, times, events and horizon logits as
           ``.npy`` files under ``SAVE_DIR``.
        4. Generate the KM + HR figures for the val and test splits.

    Raises:
        RuntimeError: If no checkpoint was written during this run (the
            validation loss never improved, e.g. because it was NaN). This
            prevents silently loading a stale ``best_model.pt`` left behind
            by an earlier run.
    """
    # -------------------------
    # Dataset (must return treatment as 19th item)
    # -------------------------
    from mm_dls.PatientDataset import PatientDataset

    dataset = PatientDataset(
        data_root="/path/to/DATA_ROOT",
        clinical_csv="/path/to/clinical.csv",
        radiomics_npy="/path/to/radiomics.npy",
        pet_npy="/path/to/pet.npy",
        n_slices=N_SLICES,
        img_size=IMG_SIZE,
    )
    n_train = int(0.6 * len(dataset))
    n_val = int(0.2 * len(dataset))
    n_test = len(dataset) - n_train - n_val  # remainder, so the sizes sum to len(dataset)
    train_set, val_set, test_set = random_split(dataset, [n_train, n_val, n_test])
    loaders = {
        "train": DataLoader(train_set, BATCH_SIZE, shuffle=True, num_workers=4),
        "val": DataLoader(val_set, BATCH_SIZE, shuffle=False, num_workers=4),
        "test": DataLoader(test_set, BATCH_SIZE, shuffle=False, num_workers=4),
    }

    # -------------------------
    # Model
    # -------------------------
    model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
    checkpoint_path = os.path.join(SAVE_DIR, "best_model.pt")
    best_val_loss = float("inf")
    checkpoint_saved = False  # set once this run has written checkpoint_path
    wait = 0

    # -------------------------
    # Training
    # -------------------------
    for epoch in range(1, EPOCHS + 1):
        tr = run_epoch_verbose(model, loaders["train"], optimizer, DEVICE, train=True)
        va = run_epoch_verbose(model, loaders["val"], optimizer, DEVICE, train=False)
        tr_loss = tr[0]
        va_loss = va[0]

        # unpack val for metrics
        _, sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e, _, _ = va
        metrics = evaluate_by_treatment(sy, ss, ty, ts, trt, dfs_r, dfs_t, dfs_e, os_r, os_t, os_e)

        print(f"\n[Epoch {epoch:03d}] Train Loss={tr_loss:.3f} | Val Loss={va_loss:.3f}")
        for k, v in metrics.items():
            print(
                f" {k:7s} | "
                f"Subtype AUC={v['Subtype_AUC']:.3f} | "
                f"TNM AUC={v['TNM_AUC_macro']:.3f} | "
                f"DFS C-index={v['DFS_C_index']:.3f} | "
                f"OS C-index={v['OS_C_index']:.3f}"
            )

        # early stopping (a NaN validation loss never compares as smaller,
        # so it counts as "no improvement")
        if va_loss < best_val_loss:
            best_val_loss = va_loss
            wait = 0
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_saved = True
            print(" ✓ Best model updated")
        else:
            wait += 1
            if wait >= PATIENCE:
                print("\n⏹ Early stopping triggered")
                break

    # -------------------------
    # Inference (best model)
    # -------------------------
    if not checkpoint_saved:
        raise RuntimeError(
            "No checkpoint was written during this run because the validation "
            "loss never improved (is it NaN?); refusing to load a stale "
            "best_model.pt."
        )
    print("\nRunning inference with best model...")
    model.load_state_dict(torch.load(checkpoint_path, map_location=DEVICE))

    for split in ["train", "val", "test"]:
        out = run_epoch_verbose(model, loaders[split], optimizer, DEVICE, train=False)
        (
            loss,
            sy, ss,
            ty, ts,
            trt,
            dfs_r, dfs_t, dfs_e,
            os_r, os_t, os_e,
            dfs_log, os_log
        ) = out

        # File name -> array, written in this order under SAVE_DIR.
        arrays = {
            # classification
            f"subtype_{split}_labels.npy": sy,
            f"subtype_{split}_scores.npy": ss,
            f"tnm_{split}_labels.npy": ty,
            f"tnm_{split}_scores.npy": ts,
            f"treatment_{split}.npy": trt,
            # survival (cox risk + time/event)
            f"dfs_{split}_risk.npy": dfs_r,
            f"dfs_{split}_time.npy": dfs_t,
            f"dfs_{split}_event.npy": dfs_e,
            f"os_{split}_risk.npy": os_r,
            f"os_{split}_time.npy": os_t,
            f"os_{split}_event.npy": os_e,
            # 1y/3y/5y logits (optional, for AUC at specific horizons)
            f"dfs_{split}_logits_1y3y5y.npy": dfs_log,
            f"os_{split}_logits_1y3y5y.npy": os_log,
        }
        for filename, array in arrays.items():
            np.save(os.path.join(SAVE_DIR, filename), array)

        print(f"{split:5s} | loss={loss:.3f} | Immune={np.sum(trt==0)} Chemo={np.sum(trt==1)}")

    print("\n✓ Inference completed. Results saved.")

    # -------------------------
    # Figure: Immune/Chemo KM + HR
    # -------------------------
    print("\nGenerating Figure (KM + HR) ...")
    generate_figure_from_saved(result_dir=SAVE_DIR, fig_dir=FIG_DIR, which_split=("val", "test"))
    print("✓ Figure done. Files saved under ./figures")


if __name__ == "__main__":
    main()