# MM-DLS / test.py
# FangDai's picture
# Upload 4 files
# bf41494 verified
# test_mm_dls.py
# =========================================================
# 🔍 Minimal test for MM-DLS pipeline
# - CUDA
# - forward / loss
# - pandas / lifelines (GLIBCXX check)
# =========================================================
import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset
import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines.utils import concordance_index
# ---------------------------------------------------------
# Project path
# ---------------------------------------------------------
# Resolve the current working directory and make sure it is the first entry
# on the import path, so the `mm_dls` package imported below can be found.
PROJECT_ROOT = os.path.abspath(os.curdir)
if PROJECT_ROOT not in sys.path:
    sys.path[:0] = [PROJECT_ROOT]
# ---------------------------------------------------------
# Imports from mm_dls
# ---------------------------------------------------------
from mm_dls.HierMM_DLS import HierMM_DLS
from mm_dls.CoxphLoss import CoxPHLoss
from mm_dls.PatientDataset import PatientDataset
# =========================================================
# Basic config (VERY SMALL)
# =========================================================
# Use the first CUDA device when one is available, otherwise run on the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

# Batch size and model head sizes.
BATCH_SIZE, NUM_SUBTYPES, NUM_TNM = 2, 2, 3
# Values forwarded to the dataset as `n_slices` / `img_size`.
N_SLICES, IMG_SIZE = 30, 224
# =========================================================
# Test Dataset Loader
# =========================================================
def get_test_loader(
    data_root="/path/to/DATA_ROOT",
    clinical_csv="/path/to/clinical.csv",
    radiomics_npy="/path/to/radiomics.npy",
    pet_npy="/path/to/pet.npy",
    max_samples=8,
):
    """Build a small, deterministic DataLoader over the first few patients.

    The default paths are placeholders; pass real locations (or edit the
    defaults) before running against actual data.

    Args:
        data_root: Forwarded to ``PatientDataset(data_root=...)``.
        clinical_csv: Forwarded to ``PatientDataset(clinical_csv=...)``.
        radiomics_npy: Forwarded to ``PatientDataset(radiomics_npy=...)``.
        pet_npy: Forwarded to ``PatientDataset(pet_npy=...)``.
        max_samples: Upper bound on how many leading samples are kept.

    Returns:
        A ``DataLoader`` over a ``Subset`` of at most ``max_samples`` items,
        batched by ``BATCH_SIZE``, unshuffled, using two worker processes.
    """
    dataset = PatientDataset(
        data_root=data_root,
        clinical_csv=clinical_csv,
        radiomics_npy=radiomics_npy,
        pet_npy=pet_npy,
        n_slices=N_SLICES,
        img_size=IMG_SIZE,
    )
    # Only take the first `max_samples` samples (fewer if the dataset is smaller).
    idx = list(range(min(max_samples, len(dataset))))
    subset = Subset(dataset, idx)
    loader = DataLoader(
        subset,
        batch_size=BATCH_SIZE,
        shuffle=False,
        num_workers=2,
    )
    return loader
# =========================================================
# One forward + loss
# =========================================================
def test_forward_and_loss():
    """Run one forward pass and evaluate the combined loss on a single batch.

    Builds the loader with ``get_test_loader``, moves the first batch to
    ``DEVICE``, runs ``HierMM_DLS`` in eval mode under ``torch.no_grad()`` and
    prints the sum of two cross-entropy terms, two Cox terms and two
    BCE-with-logits terms. Only the first batch is used.

    Raises:
        AssertionError: If a batch does not contain exactly 19 items.
        RuntimeError: If the loader yields no batch at all, so an empty
            dataset cannot be mistaken for a passing test.
    """
    print("\n[TEST] Forward + Loss")
    loader = get_test_loader()
    model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)
    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss()
    cox = CoxPHLoss()
    model.eval()

    for batch in loader:
        assert len(batch) == 19, f"Dataset must return 19 items, got {len(batch)}"
        (
            _pid, lesion, space, rad, pet, cli,
            y_sub, y_tnm,
            dfs_t, dfs_e,
            os_t, os_e,
            dfs1, dfs3, dfs5,
            os1, os3, os5,
            _treatment,
        ) = batch

        # Move model inputs and targets to the selected device.
        lesion, space = lesion.to(DEVICE), space.to(DEVICE)
        rad, pet, cli = rad.to(DEVICE), pet.to(DEVICE), cli.to(DEVICE)
        y_sub, y_tnm = y_sub.to(DEVICE), y_tnm.to(DEVICE)
        dfs_t, dfs_e = dfs_t.to(DEVICE), dfs_e.to(DEVICE)
        os_t, os_e = os_t.to(DEVICE), os_e.to(DEVICE)

        # Stack the three per-horizon labels into one (batch, 3) target each.
        # NOTE(review): BCEWithLogitsLoss needs floating-point targets; this
        # assumes the dataset already yields them — confirm against PatientDataset.
        dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(DEVICE)
        os_y = torch.stack([os1, os3, os5], dim=1).to(DEVICE)

        with torch.no_grad():
            sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(
                lesion, space, rad, pet, cli
            )
            loss = (
                ce(sub_l, y_sub) +
                ce(tnm_l, y_tnm) +
                cox(dfs_r, dfs_t, dfs_e) +
                cox(os_r, os_t, os_e) +
                bce(dfs_log, dfs_y) +
                bce(os_log, os_y)
            )

        print(" ✓ Forward OK | Loss =", float(loss))
        break  # a single batch is enough for this smoke test
    else:
        # Reached only when the loop never hit `break`, i.e. the loader was empty.
        raise RuntimeError(
            "Test loader produced no batches; forward/loss check did not run"
        )
# =========================================================
# Test pandas + lifelines (GLIBCXX killer)
# =========================================================
def test_pandas_lifelines():
    """Smoke-test pandas and lifelines on a tiny synthetic survival table.

    Builds a six-row DataFrame of fake durations, event flags and risk scores,
    then computes a concordance index and fits a Kaplan-Meier curve, printing
    each result. Per the file header this doubles as a GLIBCXX check: the point
    is that both libraries load and run, not the numeric values themselves.
    """
    print("\n[TEST] pandas + lifelines")

    # Fake survival data.
    durations = np.array([10, 12, 8, 20, 15, 25])
    events = np.array([1, 1, 0, 1, 0, 0])
    risks = np.array([0.9, 0.8, 0.2, 1.2, 0.3, 0.4])

    # pandas
    df = pd.DataFrame({
        "time": durations,
        "event": events,
        "risk": risks,
    })
    print(" pandas OK:", df.shape)

    # C-index: the risk is negated because concordance_index expects scores
    # that increase with survival time, whereas a higher risk means the opposite.
    cidx = concordance_index(df["time"], -df["risk"], df["event"])
    print(" C-index =", round(cidx, 3))

    # Kaplan-Meier estimate, evaluated at t = 10.
    kmf = KaplanMeierFitter()
    kmf.fit(df["time"], event_observed=df["event"])
    surv_10 = kmf.predict(10)
    print(" KM survival@10 =", float(surv_10))
    print(" ✓ lifelines OK")
# =========================================================
# Main
# =========================================================
if __name__ == "__main__":
    # Run the checks in order; any exception aborts the script before the
    # success banner is printed.
    print("\n==============================")
    print(" MM-DLS TEST START ")
    print("==============================")
    test_forward_and_loss()
    test_pandas_lifelines()
    print("\n✅ ALL TESTS PASSED")