File size: 4,919 Bytes
bf41494
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
# test_mm_dls.py
# =========================================================
# 🔍 Minimal test for MM-DLS pipeline
# - CUDA
# - forward / loss
# - pandas / lifelines (GLIBCXX check)
# =========================================================

import os
import sys
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Subset

import pandas as pd
from lifelines import KaplanMeierFitter
from lifelines.utils import concordance_index

# ---------------------------------------------------------
# Project path
# ---------------------------------------------------------
# Make the local ``mm_dls`` package importable by putting the absolute path
# of the current working directory at the front of the module search path
# (only if it is not already listed). This assumes the script is launched
# from the repository root.
PROJECT_ROOT = os.path.abspath(".")
if PROJECT_ROOT not in sys.path:
    sys.path[:0] = [PROJECT_ROOT]

# ---------------------------------------------------------
# Imports from mm_dls
# ---------------------------------------------------------
from mm_dls.HierMM_DLS import HierMM_DLS
from mm_dls.CoxphLoss import CoxPHLoss
from mm_dls.PatientDataset import PatientDataset


# =========================================================
# Basic config (VERY SMALL)
# =========================================================
# Prefer the first CUDA device when one is available; otherwise run on CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda:0")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)

# Sizes are kept deliberately tiny so the smoke test runs quickly.
BATCH_SIZE = 2      # samples per DataLoader batch in get_test_loader()
NUM_SUBTYPES = 2    # first positional argument to HierMM_DLS
NUM_TNM = 3         # second positional argument to HierMM_DLS
N_SLICES = 30       # forwarded to PatientDataset(n_slices=...)
IMG_SIZE = 224      # forwarded to PatientDataset(img_size=...)


# =========================================================
# Test Dataset Loader
# =========================================================
def get_test_loader(
    *,
    data_root="/path/to/DATA_ROOT",
    clinical_csv="/path/to/clinical.csv",
    radiomics_npy="/path/to/radiomics.npy",
    pet_npy="/path/to/pet.npy",
    max_samples=8,
    num_workers=2,
):
    """Build a small, deterministic DataLoader over the first few samples.

    The default paths are placeholders. Pass real locations (or edit the
    defaults) before running the test.

    Args:
        data_root: Forwarded to ``PatientDataset(data_root=...)``.
        clinical_csv: Forwarded to ``PatientDataset(clinical_csv=...)``.
        radiomics_npy: Forwarded to ``PatientDataset(radiomics_npy=...)``.
        pet_npy: Forwarded to ``PatientDataset(pet_npy=...)``.
        max_samples: Maximum number of samples to expose, taken from the
            start of the dataset. Fewer are used if the dataset is smaller.
        num_workers: Number of worker processes used by the ``DataLoader``.

    Returns:
        A ``DataLoader`` over a ``Subset`` of at most ``max_samples`` items,
        batched by ``BATCH_SIZE`` in dataset order (no shuffling).
    """
    dataset = PatientDataset(
        data_root=data_root,
        clinical_csv=clinical_csv,
        radiomics_npy=radiomics_npy,
        pet_npy=pet_npy,
        n_slices=N_SLICES,
        img_size=IMG_SIZE,
    )

    # Keep only the first `max_samples` samples so the smoke test stays fast.
    idx = list(range(min(max_samples, len(dataset))))
    subset = Subset(dataset, idx)

    return DataLoader(
        subset,
        batch_size=BATCH_SIZE,
        shuffle=False,  # deterministic order: the same samples every run
        num_workers=num_workers,
    )


# =========================================================
# One forward + loss
# =========================================================
def test_forward_and_loss():
    """Smoke-test one forward pass of ``HierMM_DLS`` plus the combined loss.

    Takes the first batch from ``get_test_loader()`` and moves its tensors to
    ``DEVICE``. It then runs the model in eval mode under ``torch.no_grad()``
    and sums six loss terms:

    - ``CrossEntropyLoss`` on ``sub_l`` / ``tnm_l``;
    - ``CoxPHLoss`` on ``dfs_r`` / ``os_r``;
    - ``BCEWithLogitsLoss`` on ``dfs_log`` / ``os_log`` against the stacked
      ``dfs1/3/5`` and ``os1/3/5`` labels.

    The loss value is printed, not checked.

    Raises:
        RuntimeError: If the loader yields no batch. Without this check the
            test could report success without ever running the model.
        AssertionError: If a batch does not contain exactly 19 fields.
    """
    print("\n[TEST] Forward + Loss")

    loader = get_test_loader()
    model = HierMM_DLS(NUM_SUBTYPES, NUM_TNM).to(DEVICE)

    ce = nn.CrossEntropyLoss()
    bce = nn.BCEWithLogitsLoss()
    cox = CoxPHLoss()

    model.eval()

    # Only the first batch is needed. Fetch it explicitly and fail loudly if
    # there is none: looping over an empty loader would silently skip the
    # forward pass, and the script would still report success.
    batch = next(iter(loader), None)
    if batch is None:
        raise RuntimeError(
            "Test loader yielded no batches; check the dataset paths and contents."
        )

    assert len(batch) == 19, f"Dataset must return 19 items, got {len(batch)}"

    (
        pid, lesion, space, rad, pet, cli,
        y_sub, y_tnm,
        dfs_t, dfs_e,
        os_t, os_e,
        dfs1, dfs3, dfs5,
        os1, os3, os5,
        treatment,
    ) = batch  # pid and treatment are unpacked but not used by this test

    lesion, space = lesion.to(DEVICE), space.to(DEVICE)
    rad, pet, cli = rad.to(DEVICE), pet.to(DEVICE), cli.to(DEVICE)
    y_sub, y_tnm = y_sub.to(DEVICE), y_tnm.to(DEVICE)
    dfs_t, dfs_e = dfs_t.to(DEVICE), dfs_e.to(DEVICE)
    os_t, os_e = os_t.to(DEVICE), os_e.to(DEVICE)

    # Stack the three label tensors column-wise for the BCE heads.
    # Presumably each label is shape (B,), which gives (B, 3) targets —
    # TODO confirm against PatientDataset.
    dfs_y = torch.stack([dfs1, dfs3, dfs5], dim=1).to(DEVICE)
    os_y = torch.stack([os1, os3, os5], dim=1).to(DEVICE)

    with torch.no_grad():
        sub_l, tnm_l, dfs_r, os_r, dfs_log, os_log = model(
            lesion, space, rad, pet, cli
        )

        loss = (
            ce(sub_l, y_sub) +
            ce(tnm_l, y_tnm) +
            cox(dfs_r, dfs_t, dfs_e) +
            cox(os_r, os_t, os_e) +
            bce(dfs_log, dfs_y) +
            bce(os_log, os_y)
        )

    print("  ✓ Forward OK | Loss =", float(loss))


# =========================================================
# Test pandas + lifelines (GLIBCXX killer)
# =========================================================
def test_pandas_lifelines():
    """Exercise pandas and lifelines on a tiny synthetic survival table.

    Builds a 6-row DataFrame, then prints:

    - its shape;
    - Harrell's C-index of the negated risk score against the event times;
    - the Kaplan-Meier survival estimate at t = 10.

    Nothing is asserted. Reaching the final message shows that both libraries
    imported and ran.
    """
    print("\n[TEST] pandas + lifelines")

    # Synthetic records: follow-up time, event indicator and a risk score.
    records = pd.DataFrame(
        {
            "time": np.array([10, 12, 8, 20, 15, 25]),
            "event": np.array([1, 1, 0, 1, 0, 0]),
            "risk": np.array([0.9, 0.8, 0.2, 1.2, 0.3, 0.4]),
        }
    )
    print("  pandas OK:", records.shape)

    # lifelines treats a larger predicted score as longer survival, so the
    # risk score is negated before scoring concordance.
    c_index = concordance_index(
        event_times=records["time"],
        predicted_scores=-records["risk"],
        event_observed=records["event"],
    )
    print("  C-index =", round(c_index, 3))

    km = KaplanMeierFitter()
    km.fit(durations=records["time"], event_observed=records["event"])
    print("  KM survival@10 =", float(km.predict(10)))
    print("  ✓ lifelines OK")


# =========================================================
# Main
# =========================================================
if __name__ == "__main__":
    # Run banner.
    for banner_line in (
        "\n==============================",
        " MM-DLS TEST START ",
        "==============================",
    ):
        print(banner_line)

    # Run the smoke tests in order. Any exception aborts the run before the
    # success message below is printed.
    for smoke_test in (test_forward_and_loss, test_pandas_lifelines):
        smoke_test()

    print("\n✅ ALL TESTS PASSED")