File size: 4,919 Bytes
bf41494 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
# test_mm_dls.py
# =========================================================
# 🔍 Minimal test for MM-DLS pipeline
# - CUDA
# - forward / loss
# - pandas / lifelines (GLIBCXX check)
# =========================================================
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines.utils import concordance_index
# ---------------------------------------------------------
# Project path
# ---------------------------------------------------------
PROJECT_ROOT = os.path.abspath(".")
if PROJECT_ROOT not in sys.path:
sys.path.insert(0, PROJECT_ROOT)
# ---------------------------------------------------------
# Imports from mm_dls
# ---------------------------------------------------------
from mm_dls.HierMM_DLS import HierMM_DLS
from mm_dls.CoxphLoss import CoxPHLoss
from mm_dls.PatientDataset import PatientDataset
# =========================================================
# Basic config (VERY SMALL)
# =========================================================
DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("Using device:", DEVICE)
BATCH_SIZE = 2
NUM_SUBTYPES = 2
NUM_TNM = 3
N_SLICES = 30
IMG_SIZE = 224
# =========================================================
# Test Dataset Loader
# =========================================================
def get_test_loader():
dataset = PatientDataset(
data_root="/path/to/DATA_ROOT",
clinical_csv="/path/to/clinical.csv",
radiomics_npy="/path/to/radiomics.npy",
pet_npy="/path/to/pet.npy",
n_slices=N_SLICES,
img_size=IMG_SIZE,
)
# 🔑 只取前 8 个样本
idx = list(range(min(8, len(dataset))))
subset = Subset(dataset, idx)
loader = DataLoader(
subset,
batch_size=BATCH_SIZE,
shuffle=False,
num_workers=2,
)
return loader
# =========================================================
# One forward + loss
# =========================================================
def test_forward_and_loss():
print("\n[TEST] Forward + Loss")
loader = get_test_loader()
model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)
ce = nn.CrossEntropyLoss()
bce = nn.BCEWithLogitsLoss()
cox = CoxPHLoss()
model.eval()
for batch in loader:
assert len(batch) == 19, f"Dataset must return 19 items, got {len(batch)}"
(
pid, lesion, space, rad, pet, cli,
y_sub, y_tnm,
dfs_t, dfs_e,
os_t, os_e,
dfs1, dfs3, dfs5,
os1, os3, os5,
treatment
) = batch
lesion, space = lesion.to(DEVICE), space.to(DEVICE)
rad, pet, cli = rad.to(DEVICE), pet.to(DEVICE), cli.to(DEVICE)
y_sub, y_tnm = y_sub.to(DEVICE), y_tnm.to(DEVICE)
dfs_t, dfs_e = dfs_t.to(DEVICE), dfs_e.to(DEVICE)
os_t, os_e = os_t.to(DEVICE), os_e.to(DEVICE)
dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(DEVICE)
os_y = torch.stack([os1, os3, os5 ], dim=1).to(DEVICE)
with torch.no_grad():
sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(
lesion, space, rad, pet, cli
)
loss = (
ce(sub_l, y_sub) +
ce(tnm_l, y_tnm) +
cox(dfs_r, dfs_t, dfs_e) +
cox(os_r, os_t, os_e) +
bce(dfs_log, dfs_y) +
bce(os_log, os_y)
)
print(" ✓ Forward OK | Loss =", float(loss))
break
# =========================================================
# Test pandas + lifelines (GLIBCXX killer)
# =========================================================
def test_pandas_lifelines():
print("\n[TEST] pandas + lifelines")
# fake survival data
time = np.array([10, 12, 8, 20, 15, 25])
event = np.array([1, 1, 0, 1, 0, 0])
risk = np.array([0.9, 0.8, 0.2, 1.2, 0.3, 0.4])
# pandas
df = pd.DataFrame({
"time": time,
"event": event,
"risk": risk
})
print(" pandas OK:", df.shape)
# C-index
cidx = concordance_index(df["time"], -df["risk"], df["event"])
print(" C-index =", round(cidx, 3))
# KM
kmf = KaplanMeierFitter()
kmf.fit(df["time"], event_observed=df["event"])
surv_10 = kmf.predict(10)
print(" KM survival@10 =", float(surv_10))
print(" ✓ lifelines OK")
# =========================================================
# Main
# =========================================================
if __name__ == "__main__":
print("\n==============================")
print(" MM-DLS TEST START ")
print("==============================")
test_forward_and_loss()
test_pandas_lifelines()
print("\n✅ ALL TESTS PASSED")
|