File size: 3,116 Bytes
a19a7aa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
# mm_dls/PatientDataset.py
import os
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from PIL import Image
from torchvision import transforms


class PatientDataset(Dataset):
    """Per-patient dataset combining image slices, tabular features and labels.

    Each item is built from one row of the clinical CSV (selected by position)
    plus the rows at the same position in the radiomics and PET arrays, and
    two image folders located at ``<data_root>/images/<pid>/lesion`` and
    ``<data_root>/images/<pid>/space``.

    The CSV must provide the columns ``pid``, ``subtype``, ``tnm_stage``,
    ``dfs_time``, ``dfs_event``, ``os_time``, ``os_event`` and ``treatment``.
    """

    # Thresholds behind the "1y / 3y / 5y" labels. The 12/36/60 values suggest
    # that the time columns are expressed in months -- TODO confirm.
    _HORIZONS = (12, 36, 60)

    def __init__(
        self,
        data_root,
        clinical_csv,
        radiomics_npy,
        pet_npy,
        n_slices=30,
        img_size=224
    ):
        """Load the clinical table and feature arrays and build the transform.

        Args:
            data_root: Directory that contains the ``images`` folder.
            clinical_csv: Path of the CSV holding one row per patient.
            radiomics_npy: Path of a ``.npy`` array indexed like the CSV rows.
            pet_npy: Path of a ``.npy`` array indexed like the CSV rows.
            n_slices: Maximum number of image files read per folder.
            img_size: Side length every slice is resized to.
        """
        super().__init__()

        self.data_root = data_root
        # Read ``pid`` as text: it is used as a folder name, and numeric
        # parsing would drop leading zeros and make ``os.path.join`` fail.
        self.df = pd.read_csv(clinical_csv, dtype={"pid": str})
        self.radiomics = np.load(radiomics_npy)
        self.pet = np.load(pet_npy)

        self.n_slices = n_slices

        self.transform = transforms.Compose([
            transforms.Resize((img_size, img_size)),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of rows in the clinical table."""
        return len(self.df)

    def _load_slices(self, folder):
        """Load up to ``n_slices`` images from ``folder`` as one stacked tensor.

        Entries are taken in sorted file-name order and converted to
        single-channel ("L") images before the transform is applied. Every
        directory entry is treated as an image; nothing is filtered out. When
        the folder holds fewer than ``n_slices`` entries the result simply has
        fewer slices (no padding).

        Raises:
            RuntimeError: If ``folder`` contains no entries.
        """
        files = sorted(os.listdir(folder))[: self.n_slices]
        if not files:
            # torch.stack would fail on an empty list with a message that does
            # not say which folder is at fault.
            raise RuntimeError(f"No slice images found in {folder!r}")

        imgs = []
        for name in files:
            # The context manager releases the file handle as soon as the
            # pixel data has been converted.
            with Image.open(os.path.join(folder, name)) as img:
                imgs.append(self.transform(img.convert("L")))
        return torch.stack(imgs, dim=0)  # [S,1,H,W]

    def _horizon_labels(self, time_value):
        """Return one float32 flag per horizon: 1.0 when ``time_value <= horizon``.

        NOTE(review): the flags depend on the time only; the matching event
        column is not consulted, so a censored patient with a short follow-up
        gets the same label as one with an observed event -- confirm intended.
        """
        return tuple(
            torch.tensor(time_value <= horizon, dtype=torch.float32)
            for horizon in self._HORIZONS
        )

    def __getitem__(self, idx):
        """Assemble every input and label for the patient at position ``idx``.

        Returns:
            A 19-tuple ``(pid, lesion, space, radiomics, pet, clinical, y_sub,
            y_tnm, dfs_time, dfs_event, os_time, os_event, dfs_1y, dfs_3y,
            dfs_5y, os_1y, os_3y, os_5y, treatment)``. ``pid`` is a string;
            all other members are tensors.
        """
        row = self.df.iloc[idx]
        # Take the id from the column rather than from ``row``: a row of an
        # all-numeric table is upcast to float, which would turn 7 into 7.0.
        pid = str(self.df["pid"].iloc[idx])

        # -------- images --------
        lesion_dir = os.path.join(self.data_root, "images", pid, "lesion")
        space_dir = os.path.join(self.data_root, "images", pid, "space")

        lesion = self._load_slices(lesion_dir)
        space = self._load_slices(space_dir)

        # -------- tabular --------
        radiomics = torch.tensor(self.radiomics[idx], dtype=torch.float32)
        pet = torch.tensor(self.pet[idx], dtype=torch.float32)
        # NOTE(review): placeholder -- no clinical column is read here, the
        # model always receives a zero vector of length 6.
        clinical = torch.zeros(6)

        # -------- labels --------
        y_sub = torch.tensor(row["subtype"], dtype=torch.long)
        y_tnm = torch.tensor(row["tnm_stage"], dtype=torch.long)

        dfs_time = torch.tensor(row["dfs_time"], dtype=torch.float32)
        dfs_event = torch.tensor(row["dfs_event"], dtype=torch.float32)

        os_time = torch.tensor(row["os_time"], dtype=torch.float32)
        os_event = torch.tensor(row["os_event"], dtype=torch.float32)

        # 1y / 3y / 5y
        dfs_1y, dfs_3y, dfs_5y = self._horizon_labels(row["dfs_time"])
        os_1y, os_3y, os_5y = self._horizon_labels(row["os_time"])

        treatment = torch.tensor(row["treatment"], dtype=torch.long)

        return (
            pid,
            lesion,
            space,
            radiomics,
            pet,
            clinical,
            y_sub,
            y_tnm,
            dfs_time,
            dfs_event,
            os_time,
            os_event,
            dfs_1y,
            dfs_3y,
            dfs_5y,
            os_1y,
            os_3y,
            os_5y,
            treatment,
        )