File size: 6,436 Bytes
a19a7aa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 |
import torch
from torch.utils.data import Dataset
import numpy as np
import random
class FakePatientDataset(Dataset):
    """
    Controllable synthetic multimodal + survival dataset.

    You can explicitly control:
    - Final AUC (classification)
    - Final C-index (DFS / OS)
    via interpretable hyperparameters.

    Cohort-level ground truth (treatment arm, subtype, TNM stage, survival
    times, event indicators, time-point labels) is sampled once in
    ``__init__``. Per-patient features (image, radiomics, PET, clinical) and
    the label noise are generated on access from a random generator keyed on
    ``(idx, seed)``, so a given patient always yields the same sample no
    matter how often, in which order, or from which DataLoader worker it is
    requested.

    Output: 19 items (aligned with run_epoch_verbose)

    Args:
        n_patients: Number of synthetic patients (dataset length).
        n_slices: Number of identical image slices stacked per patient.
        img_size: Height and width of each square image slice.
        num_subtypes: Number of subtype classes; labels are in
            ``[0, num_subtypes)``.
        num_tnm: Number of TNM classes; labels are in ``[0, num_tnm)``.
        seed: Seed for the global ``random`` / ``numpy.random`` state used
            during construction and root seed of the per-item generators.
            ``None`` draws fresh entropy.
        tabular_signal_dims: How many leading radiomics dimensions carry
            signal (0..128). More dims -> higher AUC.
        tabular_signal_strength: Size of the class-dependent shift added to
            the tabular features. Stronger -> higher AUC.
        label_flip_rate: Probability that the emitted subtype label is
            replaced by a different class. More noise -> lower AUC.
        risk_noise: Std of the Gaussian noise on the latent risk.
            More noise -> lower C-index.
        dfs_time_noise: Std of the Gaussian noise added to the DFS time.
        os_time_noise: Std of the Gaussian noise added to the OS time.
        event_sharpness: Slope of the logistic link from risk to event
            probability. Higher -> events follow risk more sharply
            (more pronounced hazard ratio).

    Raises:
        ValueError: If ``tabular_signal_dims`` is outside ``[0, 128]``.
    """

    # Widths of the three tabular modalities returned by __getitem__.
    RADIOMICS_DIM = 128
    PET_DIM = 5
    CLINICAL_DIM = 6
    # Relative weights used to sample treatment arm 0 vs. arm 1.
    TREATMENT_ARM_WEIGHTS = (2374, 1790)

    def __init__(
        self,
        n_patients=3000,
        n_slices=30,
        img_size=224,
        num_subtypes=2,
        num_tnm=3,
        seed=2131,
        # =========================
        # ---- AUC controllers ----
        # =========================
        tabular_signal_dims=16,          # more dims -> higher AUC
        tabular_signal_strength=0.40,    # more strength -> higher AUC
        label_flip_rate=0.10,            # more noise -> lower AUC
        # =========================
        # ---- C-index controllers
        # =========================
        risk_noise=1.0,                  # more noise -> lower C-index
        dfs_time_noise=6.0,
        os_time_noise=7.0,
        event_sharpness=1.3,             # higher -> more pronounced HR
    ):
        super().__init__()

        # Fail at construction instead of with a broadcasting error on the
        # first item access.
        if not 0 <= tabular_signal_dims <= self.RADIOMICS_DIM:
            raise ValueError(
                f"tabular_signal_dims must be in [0, {self.RADIOMICS_DIM}], "
                f"got {tabular_signal_dims}"
            )

        # Seeding the global generators is kept for backward compatibility:
        # the cohort arrays below are drawn from np.random in a fixed order.
        random.seed(seed)
        np.random.seed(seed)

        # Root seed for the per-item generators created in __getitem__.
        if seed is None:
            self._item_seed = int(np.random.SeedSequence().generate_state(1)[0])
        else:
            self._item_seed = int(seed)

        self.n = n_patients
        self.n_slices = n_slices
        self.img_size = img_size
        self.num_subtypes = num_subtypes
        self.num_tnm = num_tnm

        self.tabular_signal_dims = tabular_signal_dims
        self.tabular_signal_strength = tabular_signal_strength
        self.label_flip_rate = label_flip_rate

        self.risk_noise = risk_noise
        self.dfs_time_noise = dfs_time_noise
        self.os_time_noise = os_time_noise
        self.event_sharpness = event_sharpness

        # =========================
        # Treatment cohort
        # =========================
        total_weight = sum(self.TREATMENT_ARM_WEIGHTS)
        self.treatment = np.random.choice(
            [0, 1],
            size=self.n,
            p=[w / total_weight for w in self.TREATMENT_ARM_WEIGHTS],
        ).astype(np.int64)

        # =========================
        # Ground-truth labels
        # =========================
        self.subtype = np.random.randint(0, num_subtypes, size=self.n).astype(np.int64)
        self.tnm = np.random.randint(0, num_tnm, size=self.n).astype(np.int64)

        # =========================
        # Latent biological risk
        # =========================
        base_risk = (
            0.6 * self.subtype
            + 0.5 * self.tnm
            + 0.4 * self.treatment
            + np.random.normal(0, self.risk_noise, size=self.n)
        )

        # =========================
        # Survival times (higher risk -> shorter time, clipped to a range)
        # =========================
        self.dfs_time = np.clip(
            60 - 7.0 * base_risk + np.random.normal(0, self.dfs_time_noise, size=self.n),
            3, 96,
        )
        self.os_time = np.clip(
            75 - 8.5 * base_risk + np.random.normal(0, self.os_time_noise, size=self.n),
            6, 120,
        )

        # =========================
        # Event indicators (soft): Bernoulli draws from a logistic of risk
        # =========================
        p_dfs = 1 / (1 + np.exp(-(base_risk - 0.2) * self.event_sharpness))
        p_os = 1 / (1 + np.exp(-(base_risk - 0.4) * self.event_sharpness))
        self.dfs_event = (np.random.rand(self.n) < p_dfs).astype(np.float32)
        self.os_event = (np.random.rand(self.n) < p_os).astype(np.float32)

        # =========================
        # Time-point labels: 1.0 when the time is within 12 / 36 / 60
        # =========================
        self.dfs_1y = (self.dfs_time <= 12).astype(np.float32)
        self.dfs_3y = (self.dfs_time <= 36).astype(np.float32)
        self.dfs_5y = (self.dfs_time <= 60).astype(np.float32)
        self.os_1y = (self.os_time <= 12).astype(np.float32)
        self.os_3y = (self.os_time <= 36).astype(np.float32)
        self.os_5y = (self.os_time <= 60).astype(np.float32)

    def __len__(self):
        """Return the number of synthetic patients."""
        return self.n

    def __getitem__(self, idx):
        """
        Build the sample for patient ``idx``.

        Returns:
            A 19-tuple, in this order: patient id string, lesion image
            ``(n_slices, 1, img_size, img_size)``, space image (a copy of
            the lesion image), radiomics (128), PET (5), clinical (6),
            subtype label (after label noise), TNM label, DFS time,
            DFS event, OS time, OS event, DFS 1y/3y/5y labels,
            OS 1y/3y/5y labels, treatment arm.

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``.
        """
        idx = int(idx)
        if idx < 0:
            idx += self.n
        if not 0 <= idx < self.n:
            raise IndexError(f"patient index out of range for {self.n} patients")

        # Private generator per patient: the sample does not depend on the
        # global NumPy state, on access order, or on the DataLoader worker.
        rng = np.random.default_rng([idx, self._item_seed])

        s = int(self.subtype[idx])
        t = int(self.tnm[idx])
        tr = int(self.treatment[idx])

        # =========================
        # Label noise (controls AUC ceiling)
        # =========================
        # Only the emitted label is corrupted; the features below are built
        # from the true subtype, otherwise the noise would be invisible to
        # a classifier. A non-zero shift modulo the class count always gives
        # a different valid class (equal to 1 - s in the binary case).
        noisy_subtype = s
        if self.num_subtypes > 1 and rng.random() < self.label_flip_rate:
            shift = int(rng.integers(1, self.num_subtypes))
            noisy_subtype = (s + shift) % self.num_subtypes

        # =========================
        # IMAGE: very weak signal
        # =========================
        base_img = rng.normal(0.5, 0.30, (self.img_size, self.img_size)).astype(np.float32)
        base_img += 0.03 * s + 0.02 * t + 0.02 * tr
        base_img = np.clip(base_img, 0, 1)
        # One slice repeated n_slices times with a singleton channel axis.
        lesion = torch.from_numpy(
            np.repeat(base_img[None, None, ...], self.n_slices, axis=0)
        )
        space = lesion.clone()

        # =========================
        # TABULAR: main discriminative signal
        # =========================
        strength = self.tabular_signal_strength
        n_signal = self.tabular_signal_dims

        radiomics = rng.normal(0, 1.0, self.RADIOMICS_DIM).astype(np.float32)
        radiomics[:n_signal] += (
            strength * s
            + 0.7 * strength * t
            + rng.normal(0, 0.8, n_signal)
        )

        pet = rng.normal(0, 1.0, self.PET_DIM).astype(np.float32)
        pet[:2] += 0.5 * strength * s + rng.normal(0, 0.7, 2)

        clinical = rng.normal(0, 1.0, self.CLINICAL_DIM).astype(np.float32)
        clinical[:3] += 0.5 * strength * t + rng.normal(0, 0.7, 3)

        return (
            f"P{idx:04d}",
            lesion.float(),
            space.float(),
            torch.from_numpy(radiomics),
            torch.from_numpy(pet),
            torch.from_numpy(clinical),
            torch.tensor(noisy_subtype, dtype=torch.long),
            torch.tensor(t, dtype=torch.long),
            torch.tensor(self.dfs_time[idx], dtype=torch.float32),
            torch.tensor(self.dfs_event[idx], dtype=torch.float32),
            torch.tensor(self.os_time[idx], dtype=torch.float32),
            torch.tensor(self.os_event[idx], dtype=torch.float32),
            torch.tensor(self.dfs_1y[idx], dtype=torch.float32),
            torch.tensor(self.dfs_3y[idx], dtype=torch.float32),
            torch.tensor(self.dfs_5y[idx], dtype=torch.float32),
            torch.tensor(self.os_1y[idx], dtype=torch.float32),
            torch.tensor(self.os_3y[idx], dtype=torch.float32),
            torch.tensor(self.os_5y[idx], dtype=torch.float32),
            torch.tensor(tr, dtype=torch.long),
        )
|