# MM-DLS / mm-dls / FakePatientDataset.py
# Uploaded by FangDai in commit a19a7aa ("Upload 11 files", verified).
import torch
from torch.utils.data import Dataset
import numpy as np
import random
class FakePatientDataset(Dataset):
    """
    Controllable synthetic multimodal + survival dataset.

    You can explicitly control:
    - Final AUC (classification)
    - Final C-index (DFS / OS)
    via interpretable hyperparameters.

    Output: 19 items (aligned with run_epoch_verbose), in this order:
        patient_id, lesion, space, radiomics, pet, clinical,
        subtype_label, tnm, dfs_time, dfs_event, os_time, os_event,
        dfs_1y, dfs_3y, dfs_5y, os_1y, os_3y, os_5y, treatment

    Notes:
    - The constructor seeds the *global* ``random`` and ``np.random``
      generators. ``__getitem__`` draws fresh image/tabular noise from the
      global ``np.random`` on every call.
    - Label noise is drawn once per patient at construction. The features
      are always generated from the TRUE subtype, so ``label_flip_rate``
      genuinely decouples the returned label from the input signal.
    """

    # Length of the radiomics feature vector returned per patient.
    _N_RADIOMICS = 128

    def __init__(
        self,
        n_patients=3000,
        n_slices=30,
        img_size=224,
        num_subtypes=2,
        num_tnm=3,
        seed=2131,
        # =========================
        # ---- AUC controllers ----
        # =========================
        tabular_signal_dims=16,        # more signal dims -> higher AUC
        tabular_signal_strength=0.40,  # stronger signal -> higher AUC
        label_flip_rate=0.10,          # more label noise -> lower AUC
        # =========================
        # ---- C-index controllers
        # =========================
        risk_noise=1.0,                # more risk noise -> lower C-index
        dfs_time_noise=6.0,
        os_time_noise=7.0,
        event_sharpness=1.3,           # higher -> event probability rises more steeply with risk
    ):
        super().__init__()
        # Fail early instead of with a broadcasting error inside __getitem__.
        if not 0 <= tabular_signal_dims <= self._N_RADIOMICS:
            raise ValueError(
                f"tabular_signal_dims must be in [0, {self._N_RADIOMICS}], "
                f"got {tabular_signal_dims}"
            )

        # Seeds the global RNGs (a deliberate side effect relied on for
        # reproducibility); every draw below uses the global np.random.
        random.seed(seed)
        np.random.seed(seed)

        self.n = n_patients
        self.n_slices = n_slices
        self.img_size = img_size
        self.num_subtypes = num_subtypes
        self.num_tnm = num_tnm

        self.tabular_signal_dims = tabular_signal_dims
        self.tabular_signal_strength = tabular_signal_strength
        self.label_flip_rate = label_flip_rate

        self.risk_noise = risk_noise
        self.dfs_time_noise = dfs_time_noise
        self.os_time_noise = os_time_noise
        self.event_sharpness = event_sharpness

        # =========================
        # Treatment arm: 0/1 sampled with fixed proportions 2374 : 1790
        # =========================
        self.treatment = np.random.choice(
            [0, 1],
            size=self.n,
            p=[2374 / (2374 + 1790), 1790 / (2374 + 1790)]
        ).astype(np.int64)

        # =========================
        # Ground-truth labels
        # =========================
        self.subtype = np.random.randint(0, num_subtypes, size=self.n).astype(np.int64)
        self.tnm = np.random.randint(0, num_tnm, size=self.n).astype(np.int64)

        # =========================
        # Latent risk: linear in subtype / stage / treatment plus Gaussian noise
        # =========================
        base_risk = (
            0.6 * self.subtype +
            0.5 * self.tnm +
            0.4 * self.treatment +
            np.random.normal(0, self.risk_noise, size=self.n)
        )

        # =========================
        # Survival times (months presumably; clipped to fixed ranges).
        # Higher risk -> shorter time.
        # =========================
        self.dfs_time = np.clip(
            60 - 7.0 * base_risk + np.random.normal(0, self.dfs_time_noise, size=self.n),
            3, 96
        )
        self.os_time = np.clip(
            75 - 8.5 * base_risk + np.random.normal(0, self.os_time_noise, size=self.n),
            6, 120
        )

        # =========================
        # Event indicators: Bernoulli draws with a logistic probability in
        # risk, whose slope is event_sharpness.
        # =========================
        p_dfs = 1 / (1 + np.exp(-(base_risk - 0.2) * self.event_sharpness))
        p_os = 1 / (1 + np.exp(-(base_risk - 0.4) * self.event_sharpness))
        self.dfs_event = (np.random.rand(self.n) < p_dfs).astype(np.float32)
        self.os_event = (np.random.rand(self.n) < p_os).astype(np.float32)

        # =========================
        # Landmark labels (1y / 3y / 5y = 12 / 36 / 60).
        # NOTE(review): these depend only on the time, not on the event
        # indicator, so censored patients are also marked positive.
        # =========================
        self.dfs_1y = (self.dfs_time <= 12).astype(np.float32)
        self.dfs_3y = (self.dfs_time <= 36).astype(np.float32)
        self.dfs_5y = (self.dfs_time <= 60).astype(np.float32)
        self.os_1y = (self.os_time <= 12).astype(np.float32)
        self.os_3y = (self.os_time <= 36).astype(np.float32)
        self.os_5y = (self.os_time <= 60).astype(np.float32)

        # =========================
        # Observed (noisy) subtype labels; fixed per patient so every epoch
        # and every evaluation pass sees the same label.
        # =========================
        self.observed_subtype = self._corrupt_labels(self.subtype)

    def _corrupt_labels(self, labels):
        """Return a copy of ``labels`` with each entry, with probability
        ``label_flip_rate``, replaced by a uniformly chosen *different* class.

        For two classes this is exactly ``1 - label``; for more classes the
        result always stays within ``[0, num_subtypes)``.
        """
        if self.num_subtypes < 2:
            # No other class exists to flip to.
            return labels.copy()
        flip = np.random.rand(self.n) < self.label_flip_rate
        # A shift in [1, K-1] modulo K always lands on a different valid class.
        shift = np.random.randint(1, self.num_subtypes, size=self.n)
        return np.where(flip, (labels + shift) % self.num_subtypes, labels).astype(np.int64)

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        s_true = int(self.subtype[idx])            # drives the input features
        s_label = int(self.observed_subtype[idx])  # returned target (noisy)
        t = int(self.tnm[idx])
        tr = int(self.treatment[idx])

        # =========================
        # IMAGE: very weak signal, generated from the TRUE subtype
        # =========================
        base_img = np.random.normal(0.5, 0.30, (self.img_size, self.img_size)).astype(np.float32)
        base_img += 0.03 * s_true + 0.02 * t + 0.02 * tr
        base_img = np.clip(base_img, 0, 1)
        # The same 2-D image repeated for every slice: (n_slices, 1, H, W).
        lesion = torch.from_numpy(
            np.repeat(base_img[None, None, ...], self.n_slices, axis=0)
        )
        space = lesion.clone()

        # =========================
        # TABULAR: main discriminative signal, from the TRUE subtype
        # =========================
        radiomics = np.random.normal(0, 1.0, self._N_RADIOMICS).astype(np.float32)
        radiomics[:self.tabular_signal_dims] += (
            self.tabular_signal_strength * s_true +
            0.7 * self.tabular_signal_strength * t +
            np.random.normal(0, 0.8, self.tabular_signal_dims)
        )

        pet = np.random.normal(0, 1.0, 5).astype(np.float32)
        pet[:2] += 0.5 * self.tabular_signal_strength * s_true + np.random.normal(0, 0.7, 2)

        clinical = np.random.normal(0, 1.0, 6).astype(np.float32)
        clinical[:3] += 0.5 * self.tabular_signal_strength * t + np.random.normal(0, 0.7, 3)

        return (
            f"P{idx:04d}",
            lesion.float(),
            space.float(),
            torch.from_numpy(radiomics),
            torch.from_numpy(pet),
            torch.from_numpy(clinical),
            torch.tensor(s_label, dtype=torch.long),
            torch.tensor(t, dtype=torch.long),
            torch.tensor(self.dfs_time[idx], dtype=torch.float32),
            torch.tensor(self.dfs_event[idx], dtype=torch.float32),
            torch.tensor(self.os_time[idx], dtype=torch.float32),
            torch.tensor(self.os_event[idx], dtype=torch.float32),
            torch.tensor(self.dfs_1y[idx], dtype=torch.float32),
            torch.tensor(self.dfs_3y[idx], dtype=torch.float32),
            torch.tensor(self.dfs_5y[idx], dtype=torch.float32),
            torch.tensor(self.os_1y[idx], dtype=torch.float32),
            torch.tensor(self.os_3y[idx], dtype=torch.float32),
            torch.tensor(self.os_5y[idx], dtype=torch.float32),
            torch.tensor(tr, dtype=torch.long),
        )