import random

import numpy as np
import torch
from torch.utils.data import Dataset


class FakePatientDataset(Dataset):
    """
    Controllable synthetic multimodal + survival dataset.

    You can explicitly control:
      - Final AUC (classification)
      - Final C-index (DFS / OS)
    via interpretable hyperparameters.

    The cohort-level variables (treatment, subtype, TNM, survival times,
    event flags and horizon labels) are drawn once in ``__init__``.  The
    per-sample inputs (image stack, radiomics, PET, clinical vectors and the
    label-noise decision) are re-drawn from NumPy's global RNG on every
    ``__getitem__`` call, so indexing the same patient twice yields different
    noise realisations.

    Output: 19 items (aligned with run_epoch_verbose).

    Args:
        n_patients: Number of synthetic patients (``len(dataset)``).
        n_slices: Number of times the 2-D image is repeated along the first
            axis of the returned image tensors.
        img_size: Height and width of the square synthetic image.
        num_subtypes: Subtype labels are drawn uniformly from
            ``[0, num_subtypes)``.
        num_tnm: TNM labels are drawn uniformly from ``[0, num_tnm)``.
        seed: Seed applied to the *global* ``random`` and ``numpy.random``
            generators (a process-wide side effect).
        tabular_signal_dims: Number of leading radiomics dimensions that
            carry subtype / TNM signal; must lie in ``[0, RADIOMICS_DIM]``.
        tabular_signal_strength: Mean shift per unit of subtype in the signal
            dimensions (TNM uses 0.7x in radiomics; PET and clinical use
            0.5x).
        label_flip_rate: Probability, per access, that the returned subtype
            label is replaced by a different class.
        risk_noise: Standard deviation of the Gaussian noise in the latent
            risk score.
        dfs_time_noise: Standard deviation of the Gaussian noise added to the
            DFS time.
        os_time_noise: Standard deviation of the Gaussian noise added to the
            OS time.
        event_sharpness: Slope of the logistic curve mapping the latent risk
            to the event probabilities.

    Raises:
        ValueError: If ``tabular_signal_dims`` is outside
            ``[0, RADIOMICS_DIM]``.
    """

    # Fixed widths of the synthetic tabular feature vectors.
    RADIOMICS_DIM = 128
    PET_DIM = 5
    CLINICAL_DIM = 6

    def __init__(
        self,
        n_patients=3000,
        n_slices=30,
        img_size=224,
        num_subtypes=2,
        num_tnm=3,
        seed=2131,
        tabular_signal_dims=16,
        tabular_signal_strength=0.40,
        label_flip_rate=0.10,
        risk_noise=1.0,
        dfs_time_noise=6.0,
        os_time_noise=7.0,
        event_sharpness=1.3,
    ):
        super().__init__()

        # Fail fast: an out-of-range value would otherwise only surface as a
        # broadcasting error on the first item access.
        if not 0 <= tabular_signal_dims <= self.RADIOMICS_DIM:
            raise ValueError(
                "tabular_signal_dims must be in "
                f"[0, {self.RADIOMICS_DIM}], got {tabular_signal_dims}"
            )

        # NOTE: this seeds the process-wide generators, not private ones.
        random.seed(seed)
        np.random.seed(seed)

        self.n = n_patients
        self.n_slices = n_slices
        self.img_size = img_size
        self.num_subtypes = num_subtypes
        self.num_tnm = num_tnm

        self.tabular_signal_dims = tabular_signal_dims
        self.tabular_signal_strength = tabular_signal_strength
        self.label_flip_rate = label_flip_rate

        self.risk_noise = risk_noise
        self.dfs_time_noise = dfs_time_noise
        self.os_time_noise = os_time_noise
        self.event_sharpness = event_sharpness

        # Binary treatment arm with a fixed 2374 : 1790 prior.
        # NOTE(review): the counts look like arm sizes of a reference cohort;
        # confirm their origin.
        self.treatment = np.random.choice(
            [0, 1],
            size=self.n,
            p=[2374 / (2374 + 1790), 1790 / (2374 + 1790)],
        ).astype(np.int64)

        # Ground-truth categorical labels, uniform over their classes.
        self.subtype = np.random.randint(0, num_subtypes, size=self.n).astype(np.int64)
        self.tnm = np.random.randint(0, num_tnm, size=self.n).astype(np.int64)

        # Latent risk: linear in the three covariates plus Gaussian noise.
        base_risk = (
            0.6 * self.subtype
            + 0.5 * self.tnm
            + 0.4 * self.treatment
            + np.random.normal(0, self.risk_noise, size=self.n)
        )

        # Higher risk -> shorter time; times are clipped to fixed ranges.
        # The 12 / 36 / 60 horizons below imply the unit is months.
        self.dfs_time = np.clip(
            60 - 7.0 * base_risk + np.random.normal(0, self.dfs_time_noise, size=self.n),
            3, 96,
        )
        self.os_time = np.clip(
            75 - 8.5 * base_risk + np.random.normal(0, self.os_time_noise, size=self.n),
            6, 120,
        )

        # Event indicators: Bernoulli draws with a logistic link on the risk.
        p_dfs = 1 / (1 + np.exp(-(base_risk - 0.2) * self.event_sharpness))
        p_os = 1 / (1 + np.exp(-(base_risk - 0.4) * self.event_sharpness))

        self.dfs_event = (np.random.rand(self.n) < p_dfs).astype(np.float32)
        self.os_event = (np.random.rand(self.n) < p_os).astype(np.float32)

        # Horizon labels: 1.0 when the time falls at or before the horizon.
        # They are derived from the time alone (event flags are not used).
        self.dfs_1y = (self.dfs_time <= 12).astype(np.float32)
        self.dfs_3y = (self.dfs_time <= 36).astype(np.float32)
        self.dfs_5y = (self.dfs_time <= 60).astype(np.float32)

        self.os_1y = (self.os_time <= 12).astype(np.float32)
        self.os_3y = (self.os_time <= 36).astype(np.float32)
        self.os_5y = (self.os_time <= 60).astype(np.float32)

    def __len__(self):
        """Return the number of synthetic patients."""
        return self.n

    def _noisy_subtype_label(self, subtype):
        """Return ``subtype``, or a different class with prob. ``label_flip_rate``."""
        # Always consume exactly one uniform draw so the RNG stream does not
        # depend on the number of classes.
        flip = np.random.rand() < self.label_flip_rate
        if not flip or self.num_subtypes < 2:
            return subtype
        if self.num_subtypes == 2:
            return 1 - subtype
        # Multi-class: a non-zero cyclic offset always lands on another
        # valid class (``1 - subtype`` would go negative for subtype >= 2).
        offset = int(np.random.randint(1, self.num_subtypes))
        return (subtype + offset) % self.num_subtypes

    def __getitem__(self, idx):
        """
        Build one synthetic sample.

        Returns:
            A 19-tuple:
              0      patient id string ``"P{idx:04d}"``
              1-2    lesion / space image tensors, float32, shape
                     ``(n_slices, 1, img_size, img_size)`` (space is a copy
                     of lesion)
              3-5    radiomics (128), PET (5), clinical (6) float32 vectors
              6-7    subtype label (possibly flipped) and TNM label, long
              8-9    DFS time and DFS event, float32 scalars
              10-11  OS time and OS event, float32 scalars
              12-14  DFS 1y / 3y / 5y horizon labels, float32 scalars
              15-17  OS 1y / 3y / 5y horizon labels, float32 scalars
              18     treatment arm, long
        """
        true_subtype = int(self.subtype[idx])
        tnm_stage = int(self.tnm[idx])
        treatment = int(self.treatment[idx])
        strength = self.tabular_signal_strength

        # Label noise corrupts only the *returned* label.  Features below are
        # generated from the true subtype; if they followed the flipped label
        # the features and label would stay perfectly consistent and
        # ``label_flip_rate`` could not limit the achievable AUC.
        subtype_label = self._noisy_subtype_label(true_subtype)

        # Image: mostly noise with a small mean shift per covariate.
        image = np.random.normal(
            0.5, 0.30, (self.img_size, self.img_size)
        ).astype(np.float32)
        image += 0.03 * true_subtype + 0.02 * tnm_stage + 0.02 * treatment
        image = np.clip(image, 0, 1)

        # The same 2-D image is repeated for every slice.
        lesion = torch.from_numpy(
            np.repeat(image[None, None, ...], self.n_slices, axis=0)
        )
        space = lesion.clone()

        # Radiomics: only the leading dimensions carry signal.
        signal_dims = self.tabular_signal_dims
        radiomics = np.random.normal(0, 1.0, self.RADIOMICS_DIM).astype(np.float32)
        radiomics[:signal_dims] += (
            strength * true_subtype
            + 0.7 * strength * tnm_stage
            + np.random.normal(0, 0.8, signal_dims)
        )

        # PET: first two dimensions carry a weaker subtype signal.
        pet = np.random.normal(0, 1.0, self.PET_DIM).astype(np.float32)
        pet[:2] += 0.5 * strength * true_subtype + np.random.normal(0, 0.7, 2)

        # Clinical: first three dimensions carry a weaker TNM signal.
        clinical = np.random.normal(0, 1.0, self.CLINICAL_DIM).astype(np.float32)
        clinical[:3] += 0.5 * strength * tnm_stage + np.random.normal(0, 0.7, 3)

        def as_float(value):
            return torch.tensor(value, dtype=torch.float32)

        return (
            f"P{idx:04d}",

            lesion.float(),
            space.float(),

            torch.from_numpy(radiomics),
            torch.from_numpy(pet),
            torch.from_numpy(clinical),

            torch.tensor(subtype_label, dtype=torch.long),
            torch.tensor(tnm_stage, dtype=torch.long),

            as_float(self.dfs_time[idx]),
            as_float(self.dfs_event[idx]),

            as_float(self.os_time[idx]),
            as_float(self.os_event[idx]),

            as_float(self.dfs_1y[idx]),
            as_float(self.dfs_3y[idx]),
            as_float(self.dfs_5y[idx]),

            as_float(self.os_1y[idx]),
            as_float(self.os_3y[idx]),
            as_float(self.os_5y[idx]),

            torch.tensor(treatment, dtype=torch.long),
        )