FangDai commited on
Commit
a19a7aa
·
verified ·
1 Parent(s): af44cf9

Upload 11 files

Browse files
mm-dls/ClinicalFusionModel.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from sklearn.metrics import roc_auc_score, accuracy_score, f1_score
5
+ import numpy as np
6
+
7
+ class PatientLevelFusionModel(nn.Module):
8
+ def __init__(self, input_dim=128, pet_dim=5, clinical_dim=6):
9
+ super().__init__()
10
+ self.fc_merge = nn.Sequential(
11
+ nn.Linear(input_dim * 2 + 128, 256), # lesion_fused + space_fused + radiomics_feat
12
+ nn.ReLU(),
13
+ nn.Dropout(0.3),
14
+ nn.Linear(256, 128),
15
+ nn.ReLU()
16
+ )
17
+ total_feat = 128 + pet_dim + clinical_dim
18
+ self.fc_dfs = nn.Linear(total_feat, 1)
19
+ self.fc_os = nn.Linear(total_feat, 1)
20
+ self.fc_cls = nn.Linear(total_feat, 1)
21
+
22
+ def forward(self, lesion_feat, space_feat, radiomics_feat, pet_feat, clinical_feat):
23
+ x = torch.cat([lesion_feat, space_feat, radiomics_feat], dim=1)
24
+ fused = self.fc_merge(x) # shape [B, 128]
25
+ full_feat = torch.cat([fused, pet_feat, clinical_feat], dim=1)
26
+ dfs = self.fc_dfs(full_feat).squeeze(1)
27
+ os = self.fc_os(full_feat).squeeze(1)
28
+ cls = self.fc_cls(full_feat) # keep [B, 1] for BCEWithLogits
29
+ return dfs, os, cls
30
+
31
+ @staticmethod
32
+ def classification_metrics(logits, labels):
33
+ probs = torch.sigmoid(logits).detach().cpu().numpy()
34
+ labels = labels.detach().cpu().numpy()
35
+ try:
36
+ auc = roc_auc_score(labels, probs)
37
+ except:
38
+ auc = 0.0
39
+ preds = (probs >= 0.5).astype(int)
40
+ acc = accuracy_score(labels, preds)
41
+ f1 = f1_score(labels, preds)
42
+ return auc, acc, f1
43
+
44
+ @staticmethod
45
+ def c_index(preds, durations, events):
46
+ preds = preds.detach().cpu().numpy()
47
+ durations = durations.detach().cpu().numpy()
48
+ events = events.detach().cpu().numpy()
49
+
50
+ n = len(preds)
51
+ num = 0
52
+ den = 0
53
+ for i in range(n):
54
+ for j in range(i + 1, n):
55
+ if durations[i] == durations[j]:
56
+ continue
57
+ if events[i] == 1 and durations[i] < durations[j]:
58
+ den += 1
59
+ if preds[i] < preds[j]:
60
+ num += 1
61
+ elif preds[i] == preds[j]:
62
+ num += 0.5
63
+ elif events[j] == 1 and durations[j] < durations[i]:
64
+ den += 1
65
+ if preds[j] < preds[i]:
66
+ num += 1
67
+ elif preds[j] == preds[i]:
68
+ num += 0.5
69
+ return num / den if den > 0 else 0.0
mm-dls/CoxphLoss.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class CoxPHLoss(nn.Module):
5
+ """
6
+ 实现 Cox Proportional Hazards Loss (负对数偏似然)
7
+ """
8
+ def __init__(self):
9
+ super(CoxPHLoss, self).__init__()
10
+
11
+ def forward(self, risk_pred, durations, events):
12
+ """
13
+ risk_pred: [batch_size] 模型输出的风险评分(未经过sigmoid)
14
+ durations: [batch_size] 存活时间
15
+ events: [batch_size] 事件发生标志 (1=死亡/复发, 0=删失)
16
+ """
17
+ # 以时间降序排序(从最长生存期开始)
18
+ order = torch.argsort(durations, descending=True)
19
+ risk_pred = risk_pred[order]
20
+ events = events[order]
21
+
22
+ # 累加风险值 log-sum-exp 以稳定训练
23
+ log_cumsum = torch.logcumsumexp(risk_pred, dim=0)
24
+ diff = risk_pred - log_cumsum
25
+ loss = -torch.sum(diff * events) / torch.sum(events + 1e-8) # 防止除以 0
26
+
27
+ return loss
mm-dls/FakePatientDataset.py ADDED
@@ -0,0 +1,189 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils.data import Dataset
3
+ import numpy as np
4
+ import random
5
+
6
+
7
+ class FakePatientDataset(Dataset):
8
+ """
9
+ Controllable synthetic multimodal + survival dataset
10
+
11
+ You can explicitly control:
12
+ - Final AUC (classification)
13
+ - Final C-index (DFS / OS)
14
+ via interpretable hyperparameters.
15
+
16
+ Output: 19 items (aligned with run_epoch_verbose)
17
+ """
18
+
19
+ def __init__(
20
+ self,
21
+ n_patients=3000,
22
+ n_slices=30,
23
+ img_size=224,
24
+ num_subtypes=2,
25
+ num_tnm=3,
26
+ seed=2131,
27
+
28
+ # =========================
29
+ # ---- AUC controllers ----
30
+ # =========================
31
+ tabular_signal_dims=16, # ↑ dims → ↑ AUC
32
+ tabular_signal_strength=0.40, # ↑ strength → ↑ AUC
33
+ label_flip_rate=0.10, # ↑ noise → ↓ AUC
34
+
35
+ # =========================
36
+ # ---- C-index controllers
37
+ # =========================
38
+ risk_noise=1.0, # ↑ noise → ↓ C-index
39
+ dfs_time_noise=6.0,
40
+ os_time_noise=7.0,
41
+ event_sharpness=1.3, # ↑ → HR更明显
42
+ ):
43
+ super().__init__()
44
+ random.seed(seed)
45
+ np.random.seed(seed)
46
+
47
+ self.n = n_patients
48
+ self.n_slices = n_slices
49
+ self.img_size = img_size
50
+ self.num_subtypes = num_subtypes
51
+ self.num_tnm = num_tnm
52
+
53
+ self.tabular_signal_dims = tabular_signal_dims
54
+ self.tabular_signal_strength = tabular_signal_strength
55
+ self.label_flip_rate = label_flip_rate
56
+
57
+ self.risk_noise = risk_noise
58
+ self.dfs_time_noise = dfs_time_noise
59
+ self.os_time_noise = os_time_noise
60
+ self.event_sharpness = event_sharpness
61
+
62
+ # =========================
63
+ # Treatment cohort
64
+ # =========================
65
+ self.treatment = np.random.choice(
66
+ [0, 1],
67
+ size=self.n,
68
+ p=[2374 / (2374 + 1790), 1790 / (2374 + 1790)]
69
+ ).astype(np.int64)
70
+
71
+ # =========================
72
+ # Ground-truth labels
73
+ # =========================
74
+ self.subtype = np.random.randint(0, num_subtypes, size=self.n).astype(np.int64)
75
+ self.tnm = np.random.randint(0, num_tnm, size=self.n).astype(np.int64)
76
+
77
+ # =========================
78
+ # Latent biological risk
79
+ # =========================
80
+ base_risk = (
81
+ 0.6 * self.subtype +
82
+ 0.5 * self.tnm +
83
+ 0.4 * self.treatment +
84
+ np.random.normal(0, self.risk_noise, size=self.n)
85
+ )
86
+
87
+ # =========================
88
+ # Survival times
89
+ # =========================
90
+ self.dfs_time = np.clip(
91
+ 60 - 7.0 * base_risk + np.random.normal(0, self.dfs_time_noise, size=self.n),
92
+ 3, 96
93
+ )
94
+ self.os_time = np.clip(
95
+ 75 - 8.5 * base_risk + np.random.normal(0, self.os_time_noise, size=self.n),
96
+ 6, 120
97
+ )
98
+
99
+ # =========================
100
+ # Event indicators (soft)
101
+ # =========================
102
+ p_dfs = 1 / (1 + np.exp(-(base_risk - 0.2) * self.event_sharpness))
103
+ p_os = 1 / (1 + np.exp(-(base_risk - 0.4) * self.event_sharpness))
104
+
105
+ self.dfs_event = (np.random.rand(self.n) < p_dfs).astype(np.float32)
106
+ self.os_event = (np.random.rand(self.n) < p_os).astype(np.float32)
107
+
108
+ # =========================
109
+ # Time-point labels
110
+ # =========================
111
+ self.dfs_1y = (self.dfs_time <= 12).astype(np.float32)
112
+ self.dfs_3y = (self.dfs_time <= 36).astype(np.float32)
113
+ self.dfs_5y = (self.dfs_time <= 60).astype(np.float32)
114
+
115
+ self.os_1y = (self.os_time <= 12).astype(np.float32)
116
+ self.os_3y = (self.os_time <= 36).astype(np.float32)
117
+ self.os_5y = (self.os_time <= 60).astype(np.float32)
118
+
119
+ def __len__(self):
120
+ return self.n
121
+
122
+ def __getitem__(self, idx):
123
+ s = int(self.subtype[idx])
124
+ t = int(self.tnm[idx])
125
+ tr = int(self.treatment[idx])
126
+
127
+ # =========================
128
+ # Label noise (controls AUC ceiling)
129
+ # =========================
130
+ if np.random.rand() < self.label_flip_rate:
131
+ s = 1 - s
132
+
133
+ # =========================
134
+ # IMAGE: very weak signal
135
+ # =========================
136
+ base_img = np.random.normal(0.5, 0.30, (self.img_size, self.img_size)).astype(np.float32)
137
+ base_img += 0.03 * s + 0.02 * t + 0.02 * tr
138
+ base_img = np.clip(base_img, 0, 1)
139
+
140
+ lesion = torch.from_numpy(
141
+ np.repeat(base_img[None, None, ...], self.n_slices, axis=0)
142
+ )
143
+ space = lesion.clone()
144
+
145
+ # =========================
146
+ # TABULAR: main discriminative signal
147
+ # =========================
148
+ radiomics = np.random.normal(0, 1.0, 128).astype(np.float32)
149
+ radiomics[:self.tabular_signal_dims] += (
150
+ self.tabular_signal_strength * s +
151
+ 0.7 * self.tabular_signal_strength * t +
152
+ np.random.normal(0, 0.8, self.tabular_signal_dims)
153
+ )
154
+
155
+ pet = np.random.normal(0, 1.0, 5).astype(np.float32)
156
+ pet[:2] += 0.5 * self.tabular_signal_strength * s + np.random.normal(0, 0.7, 2)
157
+
158
+ clinical = np.random.normal(0, 1.0, 6).astype(np.float32)
159
+ clinical[:3] += 0.5 * self.tabular_signal_strength * t + np.random.normal(0, 0.7, 3)
160
+
161
+ return (
162
+ f"P{idx:04d}",
163
+
164
+ lesion.float(),
165
+ space.float(),
166
+
167
+ torch.from_numpy(radiomics),
168
+ torch.from_numpy(pet),
169
+ torch.from_numpy(clinical),
170
+
171
+ torch.tensor(s, dtype=torch.long),
172
+ torch.tensor(t, dtype=torch.long),
173
+
174
+ torch.tensor(self.dfs_time[idx], dtype=torch.float32),
175
+ torch.tensor(self.dfs_event[idx], dtype=torch.float32),
176
+
177
+ torch.tensor(self.os_time[idx], dtype=torch.float32),
178
+ torch.tensor(self.os_event[idx], dtype=torch.float32),
179
+
180
+ torch.tensor(self.dfs_1y[idx], dtype=torch.float32),
181
+ torch.tensor(self.dfs_3y[idx], dtype=torch.float32),
182
+ torch.tensor(self.dfs_5y[idx], dtype=torch.float32),
183
+
184
+ torch.tensor(self.os_1y[idx], dtype=torch.float32),
185
+ torch.tensor(self.os_3y[idx], dtype=torch.float32),
186
+ torch.tensor(self.os_5y[idx], dtype=torch.float32),
187
+
188
+ torch.tensor(tr, dtype=torch.long),
189
+ )
mm-dls/HierMM_DLS.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from mm_dls.ModelLesionEncoder import LesionEncoder
6
+ from mm_dls.ModelSpaceEncoder import SpaceEncoder
7
+ from mm_dls.LesionAttentionFusion import LesionAttentionFusion
8
+
9
+
10
+ class HierMM_DLS(nn.Module):
11
+ """
12
+ Hierarchical multi-task model:
13
+ Stage-1: subtype classification + TNM classification
14
+ Stage-2: survival Cox risks (DFS/OS) conditioned on subtype/TNM soft embeddings
15
+ Stage-3: fixed-horizon binary classification (DFS/OS at 1y/3y/5y) logits
16
+
17
+ Inputs:
18
+ lesion_vol: [B,S,1,H,W]
19
+ space_vol : [B,S,1,H,W]
20
+ radiomics : [B,128]
21
+ pet : [B,5]
22
+ clinical : [B,C]
23
+
24
+ Outputs:
25
+ subtype_logits: [B, K_sub]
26
+ tnm_logits : [B, K_tnm]
27
+ dfs_risk : [B]
28
+ os_risk : [B]
29
+ dfs_logits : [B,3] (1y,3y,5y)
30
+ os_logits : [B,3]
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ num_subtypes: int,
36
+ num_tnm: int,
37
+ img_feat_dim: int = 128,
38
+ radiomics_dim: int = 128,
39
+ pet_dim: int = 5,
40
+ clinical_dim: int = 6,
41
+ task_emb_dim: int = 32,
42
+ dropout: float = 0.3,
43
+ ):
44
+ super().__init__()
45
+
46
+ self.lesion_encoder = LesionEncoder(input_channels=1, feature_dim=img_feat_dim)
47
+ self.space_encoder = SpaceEncoder(input_channels=1, feature_dim=img_feat_dim)
48
+
49
+ self.lesion_fuser = LesionAttentionFusion(img_feat_dim, img_feat_dim)
50
+ self.space_fuser = LesionAttentionFusion(img_feat_dim, img_feat_dim)
51
+
52
+ fused_base_dim = img_feat_dim * 2 + radiomics_dim + pet_dim + clinical_dim
53
+
54
+ self.shared_up = nn.Sequential(
55
+ nn.Linear(fused_base_dim, 256),
56
+ nn.ReLU(),
57
+ nn.Dropout(dropout),
58
+ nn.Linear(256, 128),
59
+ nn.ReLU(),
60
+ )
61
+
62
+ self.subtype_head = nn.Linear(128, num_subtypes)
63
+ self.tnm_head = nn.Linear(128, num_tnm)
64
+
65
+ self.subtype_emb = nn.Embedding(num_subtypes, task_emb_dim)
66
+ self.tnm_emb = nn.Embedding(num_tnm, task_emb_dim)
67
+
68
+ surv_in = 128 + task_emb_dim * 2
69
+ self.surv_mlp = nn.Sequential(
70
+ nn.Linear(surv_in, 128),
71
+ nn.ReLU(),
72
+ nn.Dropout(dropout),
73
+ )
74
+
75
+ # Cox risks
76
+ self.dfs_head = nn.Linear(128, 1)
77
+ self.os_head = nn.Linear(128, 1)
78
+
79
+ # Fixed-horizon classification logits (1y/3y/5y)
80
+ self.dfs_cls = nn.Linear(128, 3)
81
+ self.os_cls = nn.Linear(128, 3)
82
+
83
+ def _encode_volume(self, encoder, vol):
84
+ # vol: [B,S,1,H,W]
85
+ B, S, C, H, W = vol.shape
86
+ x = vol.view(B * S, C, H, W)
87
+ feat = encoder(x) # [B*S, D]
88
+ feat = feat.view(B, S, -1) # [B,S,D]
89
+ return feat
90
+
91
+ def forward(self, lesion_vol, space_vol, radiomics, pet, clinical):
92
+ lesion_seq = self._encode_volume(self.lesion_encoder, lesion_vol)
93
+ space_seq = self._encode_volume(self.space_encoder, space_vol)
94
+
95
+ lesion_f = self.lesion_fuser(lesion_seq) # [B,D]
96
+ space_f = self.space_fuser(space_seq) # [B,D]
97
+
98
+ base = torch.cat([lesion_f, space_f, radiomics, pet, clinical], dim=1)
99
+ up = self.shared_up(base) # [B,128]
100
+
101
+ subtype_logits = self.subtype_head(up) # [B,Ks]
102
+ tnm_logits = self.tnm_head(up) # [B,Kt]
103
+
104
+ subtype_prob = F.softmax(subtype_logits, dim=1)
105
+ tnm_prob = F.softmax(tnm_logits, dim=1)
106
+
107
+ subtype_e = subtype_prob @ self.subtype_emb.weight # [B,E]
108
+ tnm_e = tnm_prob @ self.tnm_emb.weight # [B,E]
109
+
110
+ surv_x = torch.cat([up, subtype_e, tnm_e], dim=1)
111
+ surv_h = self.surv_mlp(surv_x) # [B,128]
112
+
113
+ dfs_risk = self.dfs_head(surv_h).squeeze(1)
114
+ os_risk = self.os_head(surv_h).squeeze(1)
115
+
116
+ dfs_logits = self.dfs_cls(surv_h) # [B,3]
117
+ os_logits = self.os_cls(surv_h) # [B,3]
118
+
119
+ return subtype_logits, tnm_logits, dfs_risk, os_risk, dfs_logits, os_logits
mm-dls/ImageDataLoader.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch.utils.data import DataLoader
2
+ from PatientDataset import PatientMultiModalDataset
3
+
4
+ def make_loader(
5
+ split_dir: str,
6
+ batch_size: int = 4,
7
+ n_slices: int = 10,
8
+ img_size: int = 64,
9
+ num_workers: int = 4,
10
+ shuffle: bool = True,
11
+ pin_memory: bool = True,
12
+ ):
13
+ ds = PatientMultiModalDataset(
14
+ split_dir=split_dir,
15
+ n_slices=n_slices,
16
+ img_size=(img_size, img_size),
17
+ clinical_dim=6,
18
+ radiomics_dim=128,
19
+ pet_dim=5,
20
+ seed=0,
21
+ require_space=True,
22
+ )
23
+ return DataLoader(
24
+ ds,
25
+ batch_size=batch_size,
26
+ shuffle=shuffle,
27
+ num_workers=num_workers,
28
+ pin_memory=pin_memory,
29
+ drop_last=False,
30
+ )
mm-dls/LesionAttentionFusion.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ class LesionAttentionFusion(nn.Module):
6
+ def __init__(self, input_dim, output_dim, heads=4, dropout=0.1):
7
+ super().__init__()
8
+ self.heads = heads
9
+ self.scale = (input_dim // heads) ** 0.5
10
+ self.q_proj = nn.Linear(input_dim, input_dim)
11
+ self.k_proj = nn.Linear(input_dim, input_dim)
12
+ self.v_proj = nn.Linear(input_dim, input_dim)
13
+ self.out_proj = nn.Linear(input_dim, output_dim)
14
+ self.dropout = nn.Dropout(dropout)
15
+
16
+ def forward(self, lesion_feat, lung_feat=None):
17
+ """
18
+ lesion_feat: [B, N, D] 或 [N, D] 单个病人时
19
+ lung_feat: [B, N, D] 或 [N, D]
20
+ """
21
+ if lung_feat is None:
22
+ lung_feat = lesion_feat
23
+
24
+ # 允许单个病人输入:自动添加 batch 维度
25
+ added_batch = False
26
+ if lesion_feat.dim() == 2:
27
+ lesion_feat = lesion_feat.unsqueeze(0) # -> [1, N, D]
28
+ lung_feat = lung_feat.unsqueeze(0)
29
+ added_batch = True
30
+
31
+ B, N, D = lesion_feat.shape
32
+ H = self.heads
33
+
34
+ Q = self.q_proj(lesion_feat).view(B, N, H, -1).transpose(1, 2) # [B, H, N, d]
35
+ K = self.k_proj(lung_feat).view(B, N, H, -1).transpose(1, 2) # [B, H, N, d]
36
+ V = self.v_proj(lung_feat).view(B, N, H, -1).transpose(1, 2) # [B, H, N, d]
37
+
38
+ attn_weights = (Q @ K.transpose(-2, -1)) / self.scale
39
+ attn_weights = self.dropout(F.softmax(attn_weights, dim=-1)) # [B, H, N, N]
40
+
41
+ attn_output = attn_weights @ V # [B, H, N, d]
42
+ attn_output = attn_output.transpose(1, 2).reshape(B, N, D)
43
+ output = self.out_proj(attn_output) + lesion_feat # residual connection
44
+
45
+ # 做平均池化(每个病人输出一个 [D] 向量)
46
+ output = output.mean(dim=1) # [B, D]
47
+
48
+ if added_batch:
49
+ return output[0] # 去掉 batch
50
+ return output
mm-dls/ModelLesionEncoder.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class LesionEncoder(nn.Module):
4
+ def __init__(self, input_channels=1, feature_dim=128):
5
+ super().__init__()
6
+ self.encoder = nn.Sequential(
7
+ nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
8
+ nn.ReLU(inplace=True),
9
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
10
+ nn.ReLU(inplace=True),
11
+ nn.AdaptiveAvgPool2d((1, 1)), # 输出 [B, 64, 1, 1]
12
+ nn.Flatten(), # [B, 64]
13
+ nn.Linear(64, feature_dim), # → [B, 128]
14
+ nn.ReLU(inplace=True)
15
+ )
16
+
17
+ def forward(self, x): # x: [B, 1, H, W]
18
+ return self.encoder(x) # [B, 128]
mm-dls/ModelSpaceEncoder.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+ class SpaceEncoder(nn.Module):
4
+ def __init__(self, input_channels=1, feature_dim=128):
5
+ super().__init__()
6
+ self.encoder = nn.Sequential(
7
+ nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
8
+ nn.ReLU(inplace=True),
9
+ nn.Conv2d(32, 64, kernel_size=3, padding=1),
10
+ nn.ReLU(inplace=True),
11
+ nn.AdaptiveAvgPool2d((1, 1)), # 输出 [B, 64, 1, 1]
12
+ nn.Flatten(), # [B, 64]
13
+ nn.Linear(64, feature_dim), # → [B, 128]
14
+ nn.ReLU(inplace=True)
15
+ )
16
+
17
+ def forward(self, x): # x: [B, 1, H, W]
18
+ return self.encoder(x) # [B, 128]
mm-dls/PatientDataset.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # mm_dls/PatientDataset.py
2
+ import os
3
+ import numpy as np
4
+ import pandas as pd
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+
10
+
11
+ class PatientDataset(Dataset):
12
+ def __init__(
13
+ self,
14
+ data_root,
15
+ clinical_csv,
16
+ radiomics_npy,
17
+ pet_npy,
18
+ n_slices=30,
19
+ img_size=224
20
+ ):
21
+ super().__init__()
22
+
23
+ self.data_root = data_root
24
+ self.df = pd.read_csv(clinical_csv)
25
+ self.radiomics = np.load(radiomics_npy)
26
+ self.pet = np.load(pet_npy)
27
+
28
+ self.n_slices = n_slices
29
+
30
+ self.transform = transforms.Compose([
31
+ transforms.Resize((img_size, img_size)),
32
+ transforms.ToTensor(),
33
+ ])
34
+
35
+ def __len__(self):
36
+ return len(self.df)
37
+
38
+ def _load_slices(self, folder):
39
+ files = sorted(os.listdir(folder))[: self.n_slices]
40
+ imgs = []
41
+ for f in files:
42
+ img = Image.open(os.path.join(folder, f)).convert("L")
43
+ imgs.append(self.transform(img))
44
+ imgs = torch.stack(imgs, dim=0) # [S,1,H,W]
45
+ return imgs
46
+
47
+ def __getitem__(self, idx):
48
+ row = self.df.iloc[idx]
49
+ pid = row["pid"]
50
+
51
+ # -------- images --------
52
+ lesion_dir = os.path.join(self.data_root, "images", pid, "lesion")
53
+ space_dir = os.path.join(self.data_root, "images", pid, "space")
54
+
55
+ lesion = self._load_slices(lesion_dir)
56
+ space = self._load_slices(space_dir)
57
+
58
+ # -------- tabular --------
59
+ radiomics = torch.tensor(self.radiomics[idx], dtype=torch.float32)
60
+ pet = torch.tensor(self.pet[idx], dtype=torch.float32)
61
+ clinical = torch.zeros(6)
62
+
63
+ # -------- labels --------
64
+ y_sub = torch.tensor(row["subtype"], dtype=torch.long)
65
+ y_tnm = torch.tensor(row["tnm_stage"], dtype=torch.long)
66
+
67
+ dfs_time = torch.tensor(row["dfs_time"], dtype=torch.float32)
68
+ dfs_event = torch.tensor(row["dfs_event"], dtype=torch.float32)
69
+
70
+ os_time = torch.tensor(row["os_time"], dtype=torch.float32)
71
+ os_event = torch.tensor(row["os_event"], dtype=torch.float32)
72
+
73
+ # 1y / 3y / 5y
74
+ dfs_1y = torch.tensor(row["dfs_time"] <= 12, dtype=torch.float32)
75
+ dfs_3y = torch.tensor(row["dfs_time"] <= 36, dtype=torch.float32)
76
+ dfs_5y = torch.tensor(row["dfs_time"] <= 60, dtype=torch.float32)
77
+
78
+ os_1y = torch.tensor(row["os_time"] <= 12, dtype=torch.float32)
79
+ os_3y = torch.tensor(row["os_time"] <= 36, dtype=torch.float32)
80
+ os_5y = torch.tensor(row["os_time"] <= 60, dtype=torch.float32)
81
+
82
+ treatment = torch.tensor(row["treatment"], dtype=torch.long)
83
+
84
+ return (
85
+ pid,
86
+ lesion,
87
+ space,
88
+ radiomics,
89
+ pet,
90
+ clinical,
91
+ y_sub,
92
+ y_tnm,
93
+ dfs_time,
94
+ dfs_event,
95
+ os_time,
96
+ os_event,
97
+ dfs_1y,
98
+ dfs_3y,
99
+ dfs_5y,
100
+ os_1y,
101
+ os_3y,
102
+ os_5y,
103
+ treatment,
104
+ )
mm-dls/__init__.py ADDED
File without changes
mm-dls/plot_results.py ADDED
@@ -0,0 +1,733 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # code/plot_results.py
2
+ # ============================================================
3
+ # End-to-end paper-style plotting (curves + tables)
4
+ # - Subtype (binary): ROC + PR + Calibration (with tables)
5
+ # - TNM (multiclass OVR): ROC + PR + Calibration (with tables, per class)
6
+ # - DFS/OS survival: KM + Cox HR + log-rank + C-index/Brier (with at-risk text)
7
+ #
8
+ # IMPORTANT:
9
+ # - Safe to import (NO plotting on import)
10
+ # - Call plot_all(result_dir, fig_dir) after main.py saves outputs
11
+ # ============================================================
12
+
13
+ import os
14
+ import numpy as np
15
+ import pandas as pd
16
+ import matplotlib.pyplot as plt
17
+
18
+ from sklearn.preprocessing import label_binarize
19
+ from sklearn.metrics import (
20
+ roc_curve, auc,
21
+ precision_recall_curve, average_precision_score,
22
+ confusion_matrix,
23
+ brier_score_loss
24
+ )
25
+ from sklearn.calibration import calibration_curve
26
+
27
+ from lifelines import KaplanMeierFitter, CoxPHFitter
28
+ from lifelines.statistics import multivariate_logrank_test
29
+ from lifelines.utils import concordance_index
30
+ from scipy.stats import norm
31
+
32
+
33
+ # ============================================================
34
+ # Basic I/O helpers
35
+ # ============================================================
36
+ def _ensure_dir(path: str):
37
+ os.makedirs(path, exist_ok=True)
38
+
39
+
40
+ def _exists(path: str) -> bool:
41
+ return os.path.exists(path) and os.path.isfile(path)
42
+
43
+
44
+ def _load_npy(path: str):
45
+ if not _exists(path):
46
+ return None
47
+ return np.load(path, allow_pickle=True)
48
+
49
+
50
+ def _maybe_sim_ext(labels, scores, noise=0.03, seed=42):
51
+ """
52
+ Simulate an external test split when not provided.
53
+ Keeps labels same; adds small noise to scores then clips to [0,1].
54
+ """
55
+ rng = np.random.RandomState(seed)
56
+ if scores is None:
57
+ return None, None
58
+ s = scores.copy()
59
+ s = np.clip(s + rng.normal(0, noise, s.shape), 0.0, 1.0)
60
+ return labels.copy(), s
61
+
62
+
63
+ # ============================================================
64
+ # Metrics helpers
65
+ # ============================================================
66
+ def _calc_binary_roc(y_true, y_score):
67
+ fpr, tpr, _ = roc_curve(y_true, y_score)
68
+ roc_auc = auc(fpr, tpr)
69
+ brier = brier_score_loss(y_true, y_score)
70
+ return fpr, tpr, roc_auc, brier
71
+
72
+
73
+ def _calc_binary_pr(y_true, y_score):
74
+ p, r, _ = precision_recall_curve(y_true, y_score)
75
+ ap = average_precision_score(y_true, y_score)
76
+ return p, r, ap
77
+
78
+
79
+ def _spec_npv_binary(y_true, y_score, thresh=0.5):
80
+ y_pred = (y_score >= thresh).astype(int)
81
+ tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
82
+ specificity = tn / (tn + fp) if (tn + fp) else 0.0
83
+ npv = tn / (tn + fn) if (tn + fn) else 0.0
84
+ return specificity, npv
85
+
86
+
87
+ def _ece(y_true, y_score, n_bins=10):
88
+ bins = np.linspace(0.0, 1.0, n_bins + 1)
89
+ binids = np.digitize(y_score, bins) - 1
90
+ ece = 0.0
91
+ for i in range(n_bins):
92
+ m = binids == i
93
+ if m.sum() > 0:
94
+ prob_true = np.mean(y_true[m])
95
+ prob_pred = np.mean(y_score[m])
96
+ ece += (m.sum() / len(y_score)) * abs(prob_pred - prob_true)
97
+ return float(ece)
98
+
99
+
100
+ def _calc_ovr_auc(y_bin, y_score):
101
+ """One-vs-rest ROC for multiclass. Returns dict: {class_i: (fpr,tpr,auc)}"""
102
+ out = {}
103
+ for i in range(y_bin.shape[1]):
104
+ fpr, tpr, _ = roc_curve(y_bin[:, i], y_score[:, i])
105
+ out[i] = (fpr, tpr, auc(fpr, tpr))
106
+ return out
107
+
108
+
109
+ def _calc_ovr_pr(y_bin, y_score):
110
+ """One-vs-rest PR for multiclass. Returns dict: {class_i: (p,r,ap)}"""
111
+ out = {}
112
+ for i in range(y_bin.shape[1]):
113
+ p, r, _ = precision_recall_curve(y_bin[:, i], y_score[:, i])
114
+ ap = average_precision_score(y_bin[:, i], y_score[:, i])
115
+ out[i] = (p, r, ap)
116
+ return out
117
+
118
+
119
+ def _acc_ovr(y_true_bin, y_score, thresh=0.5):
120
+ y_pred = (y_score >= thresh).astype(int)
121
+ return float((y_pred == y_true_bin).mean())
122
+
123
+
124
+ # ============================================================
125
+ # Table helpers (paper-style)
126
+ # ============================================================
127
+ def _auto_col_widths(col_labels, bbox_w):
128
+ lens = np.array([max(4, len(c)) for c in col_labels], dtype=float)
129
+ ratio = lens / lens.sum()
130
+ return bbox_w * ratio
131
+
132
+
133
+ def _add_table(ax, table_data, row_labels, col_labels, colors=None,
134
+ bbox=(0.05, -0.50, 0.95, 0.30),
135
+ fontsize=13, rowlabel_width=0.18):
136
+ """
137
+ colors: list[str] length = len(row_labels) (for per-row coloring)
138
+ """
139
+ tbl = plt.table(
140
+ cellText=table_data,
141
+ rowLabels=row_labels,
142
+ colLabels=col_labels,
143
+ cellLoc='center',
144
+ rowLoc='left',
145
+ colLoc='center',
146
+ bbox=list(bbox),
147
+ )
148
+ tbl.auto_set_font_size(False)
149
+ tbl.set_fontsize(fontsize)
150
+
151
+ cells = tbl.get_celld()
152
+ # set column widths (excluding row label col=-1)
153
+ col_widths = _auto_col_widths(col_labels, bbox[2])
154
+ for col in range(len(col_labels)):
155
+ for row in range(len(row_labels) + 1): # header included
156
+ cells[(row, col)].set_width(col_widths[col])
157
+
158
+ # row label width
159
+ for row in range(1, len(row_labels) + 1):
160
+ if (row, -1) in cells:
161
+ cells[(row, -1)].set_width(rowlabel_width)
162
+
163
+ # styling: no grid lines
164
+ for (r, c), cell in cells.items():
165
+ cell.set_linewidth(0)
166
+
167
+ # optional per-row color
168
+ if colors is not None:
169
+ for r in range(1, len(row_labels) + 1):
170
+ # color values (not the header)
171
+ for c in range(len(col_labels)):
172
+ if (r, c) in cells:
173
+ cells[(r, c)].get_text().set_color(colors[r - 1])
174
+ # row label
175
+ if (r, -1) in cells:
176
+ cells[(r, -1)].get_text().set_color(colors[r - 1])
177
+
178
+ return tbl
179
+
180
+
181
+ # ============================================================
182
+ # Subtype (binary) plots: ROC / PR / Calibration
183
+ # ============================================================
184
+ def plot_subtype_binary(result_dir="./results", fig_dir="./figures",
185
+ title_suffix="(LUAD vs LUSC)"):
186
+ _ensure_dir(fig_dir)
187
+
188
+ # Required: train/val/test
189
+ paths = {
190
+ "Train": (os.path.join(result_dir, "subtype_train_labels.npy"),
191
+ os.path.join(result_dir, "subtype_train_scores.npy")),
192
+ "Int.Valid": (os.path.join(result_dir, "subtype_val_labels.npy"),
193
+ os.path.join(result_dir, "subtype_val_scores.npy")),
194
+ "Int.Test": (os.path.join(result_dir, "subtype_test_labels.npy"),
195
+ os.path.join(result_dir, "subtype_test_scores.npy")),
196
+ }
197
+
198
+ data = {}
199
+ missing_core = False
200
+ for k, (lp, sp) in paths.items():
201
+ y = _load_npy(lp)
202
+ s = _load_npy(sp)
203
+ if y is None or s is None:
204
+ print(f"[plot_subtype_binary] Skip: missing {lp} or {sp}")
205
+ missing_core = True
206
+ break
207
+ data[k] = (y.astype(int), s.astype(float))
208
+
209
+ if missing_core:
210
+ return
211
+
212
+ # External (simulated) if not present
213
+ ext_lp = os.path.join(result_dir, "subtype_test2_labels.npy")
214
+ ext_sp = os.path.join(result_dir, "subtype_test2_scores.npy")
215
+ ext_y = _load_npy(ext_lp)
216
+ ext_s = _load_npy(ext_sp)
217
+ if ext_y is None or ext_s is None:
218
+ ext_y, ext_s = _maybe_sim_ext(data["Int.Test"][0], data["Int.Test"][1], noise=0.04, seed=7)
219
+ data["Ext.Test"] = (ext_y.astype(int), ext_s.astype(float))
220
+
221
+ # Colors (match your style)
222
+ colors = {
223
+ "Train": "#0074B7",
224
+ "Int.Valid": "#60A3D9",
225
+ "Int.Test": "#6CC4DC",
226
+ "Ext.Test": "#61649f",
227
+ }
228
+ row_colors = [colors["Train"], colors["Int.Valid"], colors["Int.Test"], colors["Ext.Test"]]
229
+
230
+ # ---------- ROC (Figure 4a-like) ----------
231
+ roc_items = {}
232
+ for k, (y, s) in data.items():
233
+ fpr, tpr, auc_k, brier_k = _calc_binary_roc(y, s)
234
+ roc_items[k] = dict(fpr=fpr, tpr=tpr, auc=auc_k, brier=brier_k, y=y, s=s)
235
+
236
+ auc_list = np.array([roc_items[k]["auc"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
237
+ auc_cv = float(np.std(auc_list) / np.mean(auc_list)) if np.mean(auc_list) > 0 else 0.0
238
+
239
+ fig, ax = plt.subplots(figsize=(5, 7), facecolor="white")
240
+ ax.set_facecolor("white")
241
+
242
+ for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
243
+ ax.plot(roc_items[k]["fpr"], roc_items[k]["tpr"],
244
+ label=f"{k} (AUC = {roc_items[k]['auc']:.2f})",
245
+ color=colors[k], linewidth=3)
246
+
247
+ ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
248
+ ax.set_xlim([-0.01, 1.0])
249
+ ax.set_ylim([0.0, 1.05])
250
+ ax.set_xticks(np.linspace(0, 1, 6))
251
+ ax.set_yticks(np.linspace(0, 1, 6))
252
+ ax.set_xlabel("False Positive Rate", fontsize=14)
253
+ ax.set_ylabel("True Positive Rate", fontsize=14)
254
+ ax.set_title(f"Pathological Subtype Classification ROC Curves\n{title_suffix}", fontsize=14)
255
+ ax.legend(loc="lower right", fontsize=12)
256
+ ax.grid(alpha=0.3)
257
+
258
+ # Table: Number / AUC CV / Brier Score
259
+ def _posneg(y):
260
+ neg = int((y == 0).sum())
261
+ pos = int((y == 1).sum())
262
+ return f"{neg} vs {pos}"
263
+
264
+ row_labels = ["Train", "Int.Valid", "Int.Test", "Ext.Test"]
265
+ col_labels = ["Number", "AUC CV", "Brier Score"]
266
+ table_data = [
267
+ [_posneg(roc_items["Train"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Train']['brier']:.3f}"],
268
+ [_posneg(roc_items["Int.Valid"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Int.Valid']['brier']:.3f}"],
269
+ [_posneg(roc_items["Int.Test"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Int.Test']['brier']:.3f}"],
270
+ [_posneg(roc_items["Ext.Test"]["y"]), f"{auc_cv:.2f}", f"{roc_items['Ext.Test']['brier']:.3f}"],
271
+ ]
272
+ _add_table(ax, table_data, row_labels, col_labels, colors=row_colors,
273
+ bbox=(0.05, -0.52, 0.98, 0.30), fontsize=12, rowlabel_width=0.20)
274
+
275
+ plt.subplots_adjust(bottom=0.42)
276
+ plt.savefig(os.path.join(fig_dir, "Figure4a_subtype_ROC.png"), dpi=600, bbox_inches="tight")
277
+ plt.savefig(os.path.join(fig_dir, "Figure4a_subtype_ROC.pdf"), dpi=600, bbox_inches="tight")
278
+ plt.close()
279
+
280
+ # ---------- PR (Figure 4b-like) ----------
281
+ pr_items = {}
282
+ for k, (y, s) in data.items():
283
+ p, r, ap = _calc_binary_pr(y, s)
284
+ spec, npv = _spec_npv_binary(y, s, thresh=0.5)
285
+ pr_items[k] = dict(p=p, r=r, ap=ap, spec=spec, npv=npv, y=y, s=s)
286
+
287
+ ap_vals = np.array([pr_items[k]["ap"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
288
+ ap_cv = float(np.std(ap_vals) / np.mean(ap_vals)) if np.mean(ap_vals) > 0 else 0.0
289
+
290
+ fig, ax = plt.subplots(figsize=(7, 5.3), facecolor="white")
291
+ ax.set_facecolor("white")
292
+
293
+ for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
294
+ ax.plot(pr_items[k]["r"], pr_items[k]["p"],
295
+ label=f"{k} (AP={pr_items[k]['ap']:.2f})",
296
+ color={
297
+ "Train": "#7F8FA3",
298
+ "Int.Valid": "#FFA0A3",
299
+ "Int.Test": "#77DDF9",
300
+ "Ext.Test": "#61649f",
301
+ }[k],
302
+ linewidth=3)
303
+ ax.fill_between(pr_items[k]["r"], pr_items[k]["p"], step='post', alpha=0.1,
304
+ color={
305
+ "Train": "#7F8FA3",
306
+ "Int.Valid": "#FFA0A3",
307
+ "Int.Test": "#77DDF9",
308
+ "Ext.Test": "#61649f",
309
+ }[k])
310
+
311
+ ax.set_xlim(-0.01, 1.01)
312
+ ax.set_ylim(-0.01, 1.01)
313
+ ax.set_xlabel("Recall", fontsize=14)
314
+ ax.set_ylabel("Precision", fontsize=14)
315
+ ax.set_title(f"Pathological Subtype Classification Precision-Recall Curves\n{title_suffix}", fontsize=14)
316
+ ax.legend(loc="lower left", fontsize=12)
317
+ ax.grid(alpha=0.3)
318
+
319
+ row_labels = [
320
+ f"Train (n={len(pr_items['Train']['y'])})",
321
+ f"Int.Valid (n={len(pr_items['Int.Valid']['y'])})",
322
+ f"Int.Test (n={len(pr_items['Int.Test']['y'])})",
323
+ f"Ext.Test (n={len(pr_items['Ext.Test']['y'])})",
324
+ ]
325
+ col_labels = ["AP CV", "Specificity", "NPV", "Average Precision"]
326
+ table_data = [
327
+ [f"{ap_cv:.2f}", f"{pr_items['Train']['spec']:.2f}", f"{pr_items['Train']['npv']:.2f}", f"{pr_items['Train']['ap']:.2f}"],
328
+ [f"{ap_cv:.2f}", f"{pr_items['Int.Valid']['spec']:.2f}", f"{pr_items['Int.Valid']['npv']:.2f}", f"{pr_items['Int.Valid']['ap']:.2f}"],
329
+ [f"{ap_cv:.2f}", f"{pr_items['Int.Test']['spec']:.2f}", f"{pr_items['Int.Test']['npv']:.2f}", f"{pr_items['Int.Test']['ap']:.2f}"],
330
+ [f"{ap_cv:.2f}", f"{pr_items['Ext.Test']['spec']:.2f}", f"{pr_items['Ext.Test']['npv']:.2f}", f"{pr_items['Ext.Test']['ap']:.2f}"],
331
+ ]
332
+ pr_row_colors = ["#7F8FA3", "#FFA0A3", "#77DDF9", "#61649f"]
333
+ _add_table(ax, table_data, row_labels, col_labels, colors=pr_row_colors,
334
+ bbox=(0.10, -0.55, 0.90, 0.30), fontsize=12, rowlabel_width=0.28)
335
+
336
+ plt.subplots_adjust(bottom=0.45)
337
+ plt.savefig(os.path.join(fig_dir, "Figure4b_subtype_PR.png"), dpi=600, bbox_inches="tight")
338
+ plt.savefig(os.path.join(fig_dir, "Figure4b_subtype_PR.pdf"), dpi=600, bbox_inches="tight")
339
+ plt.close()
340
+
341
+ # ---------- Calibration (Figure 4c-like) ----------
342
+ fig, ax = plt.subplots(figsize=(5, 5.4), facecolor="white")
343
+ ax.set_facecolor("white")
344
+
345
+ calib_colors = {
346
+ "Train": "#7F8FA3",
347
+ "Int.Valid": "#FFA0A3",
348
+ "Int.Test": "#77DDF9",
349
+ "Ext.Test": "#61649f",
350
+ }
351
+ eces = {}
352
+ for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
353
+ y, s = data[k]
354
+ prob_true, prob_pred = calibration_curve(y, s, n_bins=10)
355
+ ax.plot(prob_pred, prob_true, marker='o', label=k, color=calib_colors[k])
356
+ eces[k] = _ece(y, s, n_bins=10)
357
+
358
+ ax.plot([0, 1], [0, 1], 'k--', label='Perfect')
359
+ ax.set_xlim(-0.01, 1.01)
360
+ ax.set_ylim(-0.01, 1.01)
361
+ ax.set_xlabel("Mean Predicted Probability", fontsize=14)
362
+ ax.set_ylabel("Fraction of Positives", fontsize=14)
363
+ ax.set_title(f"Pathological Subtype Classification Calibration Curves\n{title_suffix}", fontsize=14)
364
+ ax.legend(loc="lower right", fontsize=12)
365
+ ax.grid(alpha=0.3)
366
+
367
+ row_labels = [
368
+ f"Train (n={len(data['Train'][0])})",
369
+ f"Int.Valid (n={len(data['Int.Valid'][0])})",
370
+ f"Int.Test (n={len(data['Int.Test'][0])})",
371
+ f"Ext.Test (n={len(data['Ext.Test'][0])})",
372
+ ]
373
+ col_labels = ["ECE"]
374
+ table_data = [
375
+ [f"{eces['Train']:.3f}"],
376
+ [f"{eces['Int.Valid']:.3f}"],
377
+ [f"{eces['Int.Test']:.3f}"],
378
+ [f"{eces['Ext.Test']:.3f}"],
379
+ ]
380
+ _add_table(ax, table_data, row_labels, col_labels, colors=pr_row_colors,
381
+ bbox=(0.30, -0.55, 0.65, 0.30), fontsize=12, rowlabel_width=0.40)
382
+
383
+ plt.subplots_adjust(bottom=0.42)
384
+ plt.savefig(os.path.join(fig_dir, "Figure4c_subtype_Calibration.png"), dpi=600, bbox_inches="tight")
385
+ plt.savefig(os.path.join(fig_dir, "Figure4c_subtype_Calibration.pdf"), dpi=600, bbox_inches="tight")
386
+ plt.close()
387
+
388
+ print("✔ Subtype (binary) figures generated.")
389
+
390
+
391
+ # ============================================================
392
+ # TNM (multiclass OVR) plots: ROC / PR / Calibration + tables
393
+ # ============================================================
394
+ def plot_tnm_multiclass(result_dir="./results", fig_dir="./figures"):
395
+ _ensure_dir(fig_dir)
396
+
397
+ req = [
398
+ "tnm_train_labels.npy", "tnm_train_scores.npy",
399
+ "tnm_val_labels.npy", "tnm_val_scores.npy",
400
+ "tnm_test_labels.npy", "tnm_test_scores.npy",
401
+ ]
402
+ for f in req:
403
+ if not _exists(os.path.join(result_dir, f)):
404
+ print(f"[plot_tnm_multiclass] Skip: missing {os.path.join(result_dir, f)}")
405
+ return
406
+
407
+ train_y = np.load(os.path.join(result_dir, "tnm_train_labels.npy")).astype(int)
408
+ train_s = np.load(os.path.join(result_dir, "tnm_train_scores.npy")).astype(float)
409
+
410
+ val_y = np.load(os.path.join(result_dir, "tnm_val_labels.npy")).astype(int)
411
+ val_s = np.load(os.path.join(result_dir, "tnm_val_scores.npy")).astype(float)
412
+
413
+ test_y = np.load(os.path.join(result_dir, "tnm_test_labels.npy")).astype(int)
414
+ test_s = np.load(os.path.join(result_dir, "tnm_test_scores.npy")).astype(float)
415
+
416
+ # external (simulated unless provided)
417
+ test2_lp = os.path.join(result_dir, "tnm_test2_labels.npy")
418
+ test2_sp = os.path.join(result_dir, "tnm_test2_scores.npy")
419
+ test2_y = _load_npy(test2_lp)
420
+ test2_s = _load_npy(test2_sp)
421
+ if test2_y is None or test2_s is None:
422
+ test2_y, test2_s = _maybe_sim_ext(test_y, test_s, noise=0.05, seed=9)
423
+ test2_y = test2_y.astype(int)
424
+ test2_s = test2_s.astype(float)
425
+
426
+ classes = [0, 1, 2]
427
+ names = ['Stage I-II', 'Stage III', 'Stage IV']
428
+ colors = ['#0074B7', '#60A3D9', '#6CC4DC']
429
+
430
+ bins = {
431
+ "Train": (label_binarize(train_y, classes), train_s, train_y),
432
+ "Int.Valid": (label_binarize(val_y, classes), val_s, val_y),
433
+ "Int.Test": (label_binarize(test_y, classes), test_s, test_y),
434
+ "Ext.Test": (label_binarize(test2_y, classes), test2_s, test2_y),
435
+ }
436
+ row_labels_base = ["Train", "Int.Valid", "Int.Test", "Ext.Test"]
437
+ row_colors = ["#0074B7", "#60A3D9", "#6CC4DC", "#22a2c3"]
438
+
439
+ # ---------- Figure 5a1: ROC per class + table ----------
440
+ for i, cname in enumerate(names):
441
+ fig, ax = plt.subplots(figsize=(5, 6), facecolor="white")
442
+ ax.set_facecolor("white")
443
+
444
+ aucs = {}
445
+ fprs = {}
446
+ tprs = {}
447
+ sample_counts = {}
448
+ accs = {}
449
+
450
+ for key, (yb, ys, ylab) in bins.items():
451
+ ovr = _calc_ovr_auc(yb, ys)
452
+ fpr, tpr, auc_i = ovr[i]
453
+ fprs[key], tprs[key], aucs[key] = fpr, tpr, float(auc_i)
454
+
455
+ sample_counts[key] = str(int((ylab == i).sum()))
456
+ accs[key] = _acc_ovr(yb[:, i], ys[:, i], thresh=0.5)
457
+
458
+ # plot 4 curves with different linestyles like your original
459
+ styles = {"Train": "-", "Int.Valid": "--", "Int.Test": ":", "Ext.Test": "-."}
460
+ for key in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
461
+ ax.plot(fprs[key], tprs[key], linestyle=styles[key],
462
+ label=f"{key} (AUC = {aucs[key]:.2f})",
463
+ color=colors[i], linewidth=2.5)
464
+
465
+ ax.plot([0, 1], [0, 1], 'k--', alpha=0.3)
466
+ ax.set_xlim([-0.01, 1.0])
467
+ ax.set_ylim([0.0, 1.05])
468
+ ax.set_xticks(np.linspace(0, 1, 6))
469
+ ax.set_yticks(np.linspace(0, 1, 6))
470
+ ax.set_xlabel('False Positive Rate', fontsize=13)
471
+ ax.set_ylabel('True Positive Rate', fontsize=13)
472
+ ax.set_title(f'TNM stage Classification ROC Curve \nfor {cname}', fontsize=14)
473
+ ax.legend(loc="lower right", fontsize=11)
474
+ ax.grid(alpha=0.3)
475
+
476
+ # table (Sample Count / AUC / Accuracy) — same spirit as your original
477
+ col_labels = ["Sample Count", "AUC", "Accuracy"]
478
+ table_data = [
479
+ [sample_counts["Train"], f"{aucs['Train']:.2f}", f"{accs['Train']:.3f}"],
480
+ [sample_counts["Int.Valid"], f"{aucs['Int.Valid']:.2f}", f"{accs['Int.Valid']:.3f}"],
481
+ [sample_counts["Int.Test"], f"{aucs['Int.Test']:.2f}", f"{accs['Int.Test']:.3f}"],
482
+ [sample_counts["Ext.Test"], f"{aucs['Ext.Test']:.2f}", f"{accs['Ext.Test']:.3f}"],
483
+ ]
484
+ _add_table(ax, table_data, row_labels_base, col_labels, colors=[colors[i]]*4,
485
+ bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)
486
+
487
+ plt.subplots_adjust(bottom=0.38)
488
+ safe_name = cname.replace(" ", "_").replace("-", "_")
489
+ plt.savefig(os.path.join(fig_dir, f"Figure5a1_{safe_name}.png"), dpi=600, bbox_inches="tight")
490
+ plt.savefig(os.path.join(fig_dir, f"Figure5a1_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
491
+ plt.close()
492
+
493
+ # ---------- Figure 5a2: PR per class + table ----------
494
+ for i, cname in enumerate(names):
495
+ fig, ax = plt.subplots(figsize=(5, 6.5), facecolor="white")
496
+ ax.set_facecolor("white")
497
+
498
+ # PR curves for each split
499
+ pr = {}
500
+ for key, (yb, ys, ylab) in bins.items():
501
+ p, r, ap = _calc_ovr_pr(yb, ys)[i]
502
+ spec, npv = _spec_npv_binary(yb[:, i], ys[:, i], thresh=0.5)
503
+ pr[key] = dict(p=p, r=r, ap=float(ap), spec=spec, npv=npv)
504
+
505
+ # AP CV across splits (per class)
506
+ ap_vals = np.array([pr[k]["ap"] for k in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]], dtype=float)
507
+ ap_cv = float(np.std(ap_vals) / np.mean(ap_vals)) if np.mean(ap_vals) > 0 else 0.0
508
+
509
+ styles = {"Train": "-", "Int.Valid": "--", "Int.Test": ":", "Ext.Test": "-."}
510
+ colors_pr = ['#7F8FA3', '#FFA0A3', '#77DDF9'] # your TNM PR palette (3 classes)
511
+ c_use = colors_pr[i]
512
+
513
+ for key in ["Train", "Int.Valid", "Int.Test", "Ext.Test"]:
514
+ ax.plot(pr[key]["r"], pr[key]["p"], linestyle=styles[key],
515
+ label=f"{key} (AP={pr[key]['ap']:.2f})",
516
+ color=c_use, linewidth=2.5)
517
+
518
+ ax.set_xlim([-0.01, 1.0])
519
+ ax.set_ylim([0.0, 1.05])
520
+ ax.set_xticks(np.linspace(0, 1, 6))
521
+ ax.set_yticks(np.linspace(0, 1, 6))
522
+ ax.set_xlabel('Recall', fontsize=14)
523
+ ax.set_ylabel('Precision', fontsize=14)
524
+ ax.set_title(f'TNM stage Classification Precision-Recall Curve \nfor {cname}', fontsize=14)
525
+ ax.legend(loc="lower left", fontsize=12)
526
+ ax.grid(alpha=0.3)
527
+
528
+ col_labels = ["AP CV", "Specificity", "NPV", "Average Precision"]
529
+ table_data = [
530
+ [f"{ap_cv:.2f}", f"{pr['Train']['spec']:.2f}", f"{pr['Train']['npv']:.2f}", f"{pr['Train']['ap']:.2f}"],
531
+ [f"{ap_cv:.2f}", f"{pr['Int.Valid']['spec']:.2f}", f"{pr['Int.Valid']['npv']:.2f}", f"{pr['Int.Valid']['ap']:.2f}"],
532
+ [f"{ap_cv:.2f}", f"{pr['Int.Test']['spec']:.2f}", f"{pr['Int.Test']['npv']:.2f}", f"{pr['Int.Test']['ap']:.2f}"],
533
+ [f"{ap_cv:.2f}", f"{pr['Ext.Test']['spec']:.2f}", f"{pr['Ext.Test']['npv']:.2f}", f"{pr['Ext.Test']['ap']:.2f}"],
534
+ ]
535
+ _add_table(ax, table_data, row_labels_base, col_labels, colors=[c_use]*4,
536
+ bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)
537
+
538
+ plt.subplots_adjust(bottom=0.40)
539
+ safe_name = cname.replace(" ", "_").replace("-", "_")
540
+ plt.savefig(os.path.join(fig_dir, f"Figure5a2_{safe_name}.png"), dpi=600, bbox_inches="tight")
541
+ plt.savefig(os.path.join(fig_dir, f"Figure5a2_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
542
+ plt.close()
543
+
544
+ # ---------- Figure 5a3: Calibration per class + table (ECE) ----------
545
+ for i, cname in enumerate(names):
546
+ fig, ax = plt.subplots(figsize=(5, 6.3), facecolor="white")
547
+ ax.set_facecolor("white")
548
+
549
+ calib_cols = ["#0074B7", "#60A3D9", "#6CC4DC", "#22a2c3"] # split colors
550
+ eces = {}
551
+
552
+ for (key, (yb, ys, _)), c in zip(bins.items(), calib_cols):
553
+ pt, pp = calibration_curve(yb[:, i], ys[:, i], n_bins=10, strategy="uniform")
554
+ ax.plot(pp, pt, marker='o', label=key, color=c)
555
+ eces[key] = _ece(yb[:, i], ys[:, i], n_bins=10)
556
+
557
+ ax.plot([0, 1], [0, 1], 'k--', label='Perfectly Calibrated')
558
+ ax.set_xlim(-0.01, 1.01)
559
+ ax.set_ylim(-0.01, 1.01)
560
+ ax.set_xlabel('Mean Predicted Probability', fontsize=13)
561
+ ax.set_ylabel('Fraction of Positives', fontsize=13)
562
+ ax.set_title(f'TNM stage Classification Calibration Curve \nfor {cname}', fontsize=14)
563
+ ax.legend(loc='upper left', fontsize=11)
564
+ ax.grid(alpha=0.3)
565
+
566
+ col_labels = ["ECE"]
567
+ table_data = [
568
+ [f"{eces['Train']:.3f}"],
569
+ [f"{eces['Int.Valid']:.3f}"],
570
+ [f"{eces['Int.Test']:.3f}"],
571
+ [f"{eces['Ext.Test']:.3f}"],
572
+ ]
573
+ _add_table(ax, table_data, row_labels_base, col_labels, colors=calib_cols,
574
+ bbox=(0.10, -0.52, 0.90, 0.30), fontsize=12, rowlabel_width=0.18)
575
+
576
+ plt.subplots_adjust(bottom=0.38)
577
+ safe_name = cname.replace(" ", "_").replace("-", "_")
578
+ plt.savefig(os.path.join(fig_dir, f"Figure5a3_{safe_name}.png"), dpi=600, bbox_inches="tight")
579
+ plt.savefig(os.path.join(fig_dir, f"Figure5a3_{safe_name}.pdf"), dpi=600, bbox_inches="tight")
580
+ plt.close()
581
+
582
+ print("✔ TNM multiclass figures generated.")
583
+
584
+
585
+ # ============================================================
586
+ # Survival plots (DFS/OS): KM + Cox HR + log-rank + at-risk text
587
+ # ============================================================
588
+ def _evaluate_survival(df):
589
+ df = df.copy()
590
+ df["risk_score"] = df["group"].map({"Low": 0, "Mediate": 1, "High": 2})
591
+ c_index = concordance_index(df["time"], -df["risk_score"], df["event"])
592
+ time_point = 30
593
+ y_true = (df["time"] > time_point).astype(int)
594
+ y_prob = 1 - df["risk_score"] / 2.0
595
+ brier = brier_score_loss(y_true, y_prob)
596
+ return float(c_index), float(brier)
597
+
598
+
599
+ def _plot_km_with_hr_and_atrisk(df, title, save_path, n_total=None):
600
+ kmf = KaplanMeierFitter()
601
+ fig, ax = plt.subplots(figsize=(8, 6), facecolor="white")
602
+ ax.set_facecolor("white")
603
+
604
+ colors = {"Low": "#91c7ae", "Mediate": "#f7b977", "High": "#d87c7c"}
605
+ groups = ["Low", "Mediate", "High"]
606
+
607
+ # curves + capture handles
608
+ lines = {}
609
+ at_risk_table = []
610
+ times = np.arange(0, 70, 10)
611
+
612
+ for g in groups:
613
+ m = (df["group"] == g)
614
+ if m.sum() == 0:
615
+ at_risk_table.append([0 for _ in times])
616
+ continue
617
+ kmf.fit(df.loc[m, "time"], event_observed=df.loc[m, "event"], label=g)
618
+ kmf.plot_survival_function(ci_show=True, linewidth=2, color=colors[g], ax=ax)
619
+ lines[g] = ax.get_lines()[-1]
620
+ at_risk_table.append([int(np.sum(df.loc[m, "time"] >= t)) for t in times])
621
+
622
+ handles = [lines.get("Low"), lines.get("Mediate"), lines.get("High")]
623
+ labels = ["Low", "Medium", "High"]
624
+ ax.legend(handles, labels, title="Groups", loc="upper right", framealpha=0.5, fontsize=12, title_fontsize=12)
625
+
626
+ # at-risk text (match your style)
627
+ # place below x-axis
628
+ for i, t in enumerate(times):
629
+ l, m, h = at_risk_table[0][i], at_risk_table[1][i], at_risk_table[2][i]
630
+ ax.text(t, -0.38, str(l), color="#207f4c", fontsize=13, ha='center')
631
+ ax.text(t, -0.48, str(m), color="#fca106", fontsize=13, ha='center')
632
+ ax.text(t, -0.58, str(h), color="#cc163a", fontsize=13, ha='center')
633
+
634
+ ax.text(-1, -0.28, 'Number at risk', color='black', ha='center', fontsize=13)
635
+ ax.text(-10, -0.38, "Low", color="#207f4c", fontsize=13)
636
+ ax.text(-10, -0.48, "Medium", color="#fca106", fontsize=13)
637
+ ax.text(-10, -0.58, "High", color="#cc163a", fontsize=13)
638
+
639
+ # Cox HR + Wald p
640
+ dfx = df.copy()
641
+ dfx["group_code"] = dfx["group"].map({"Low": 0, "Mediate": 1, "High": 2})
642
+ cph = CoxPHFitter()
643
+ cph.fit(dfx[["time", "event", "group_code"]], duration_col="time", event_col="event")
644
+ coef = float(cph.params_["group_code"])
645
+ se = float(cph.standard_errors_["group_code"])
646
+
647
+ hr_med_vs_low = float(np.exp(coef))
648
+ hr_high_vs_low = float(np.exp(2 * coef))
649
+
650
+ z_med = (coef) / se
651
+ p_med = float(2 * (1 - norm.cdf(abs(z_med))))
652
+ z_high = (2 * coef) / se
653
+ p_high = float(2 * (1 - norm.cdf(abs(z_high))))
654
+
655
+ # global stats
656
+ c_index, brier = _evaluate_survival(df)
657
+ logrank_p = float(multivariate_logrank_test(df["time"], df["group"], df["event"]).p_value)
658
+
659
+ ax.text(25, 0.46, f"P={logrank_p:.3f}", fontsize=12)
660
+ ax.text(25, 0.36, f"C-index={c_index:.3f}", fontsize=12)
661
+ ax.text(25, 0.26, f"Brier Score={brier:.3f}", fontsize=12)
662
+ ax.text(25, 0.16, f"HR Intermediate vs Low = {hr_med_vs_low:.2f}, P={p_med:.3f}", fontsize=12)
663
+ ax.text(25, 0.06, f"HR High vs Low = {hr_high_vs_low:.2f}, P={p_high:.3f}", fontsize=12)
664
+
665
+ ax.spines['top'].set_visible(False)
666
+ ax.spines['right'].set_visible(False)
667
+
668
+ if n_total is None:
669
+ n_total = len(df)
670
+
671
+ ax.set_title(f"{title}\n(n={n_total})", fontsize=14)
672
+ ax.set_xlabel("Time since treatment start (months)", fontsize=13)
673
+ ax.set_ylabel("Survival probability", fontsize=13)
674
+ ax.set_ylim(0, 1.05)
675
+ ax.grid(alpha=0.3)
676
+
677
+ plt.tight_layout()
678
+ plt.savefig(save_path + ".png", dpi=600, bbox_inches="tight")
679
+ plt.savefig(save_path + ".pdf", dpi=600, bbox_inches="tight")
680
+ plt.close()
681
+
682
+
683
+ def plot_survival(result_dir="./results", fig_dir="./figures"):
684
+ _ensure_dir(fig_dir)
685
+
686
+ # DFS/OS for train/val/test; ext optional
687
+ for split in ["train", "val", "test"]:
688
+ dfs_path = os.path.join(result_dir, f"dfs_{split}.csv")
689
+ os_path = os.path.join(result_dir, f"os_{split}.csv")
690
+
691
+ if _exists(dfs_path):
692
+ df = pd.read_csv(dfs_path)
693
+ _plot_km_with_hr_and_atrisk(df,
694
+ title=f"Disease-Free Survival (DFS) — Kaplan-Meier Curves ({split})",
695
+ save_path=os.path.join(fig_dir, f"DFS_{split}"),
696
+ n_total=len(df))
697
+ else:
698
+ print(f"[plot_survival] Skip DFS {split}: missing {dfs_path}")
699
+
700
+ if _exists(os_path):
701
+ df = pd.read_csv(os_path)
702
+ _plot_km_with_hr_and_atrisk(df,
703
+ title=f"Overall Survival (OS) — Kaplan-Meier Curves ({split})",
704
+ save_path=os.path.join(fig_dir, f"OS_{split}"),
705
+ n_total=len(df))
706
+ else:
707
+ print(f"[plot_survival] Skip OS {split}: missing {os_path}")
708
+
709
+ print("✔ DFS / OS KM figures generated (where available).")
710
+
711
+
712
+ # ============================================================
713
+ # Public entry: plot_all
714
+ # ============================================================
715
+ def plot_all(result_dir="./results", fig_dir="./figures",
716
+ do_subtype=True, do_tnm=True, do_survival=True):
717
+ _ensure_dir(fig_dir)
718
+
719
+ if do_subtype:
720
+ plot_subtype_binary(result_dir=result_dir, fig_dir=fig_dir)
721
+
722
+ if do_tnm:
723
+ plot_tnm_multiclass(result_dir=result_dir, fig_dir=fig_dir)
724
+
725
+ if do_survival:
726
+ plot_survival(result_dir=result_dir, fig_dir=fig_dir)
727
+
728
+
729
+ # ============================================================
730
+ # CLI usage (optional)
731
+ # ============================================================
732
+ if __name__ == "__main__":
733
+ plot_all("./results", "./figures")