File size: 9,533 Bytes
9b0ac82
 
 
 
 
 
 
 
 
 
 
 
 
 
1338ce4
9b0ac82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1338ce4
9b0ac82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
#%%
import pandas as pd
import numpy as np
import toml
import datetime
import torch
import random
import ast
import os
import glob
import nibabel as nib
import json

from tqdm import tqdm
#from .MRI_load import load_mris
from collections import defaultdict
from time import time
from copy import deepcopy
from icecream import ic
#%%


# Raw categorical strings -> integer codes, keyed by column name.
# CSVDataset applies these to any feature/label column listed here; a code's
# integer value is its position in the tuple below.
value_mapping = {
    column: {code: index for index, code in enumerate(codes)}
    for column, codes in (
        ('his_SEX', ('female', 'male')),
        ('his_HISPANIC', ('no', 'yes')),
        ('his_NACCNIHR', ('whi', 'blk', 'asi', 'ind', 'haw', 'mul')),
    )
}

class CSVDataset:
    ''' Tabular dataset assembled from a CSV file (or DataFrame) and a TOML configuration.

    The configuration must provide ``feature`` and ``label`` tables keyed by
    column name. Every feature entry carries a ``type`` field; entries with
    ``type == 'imaging'`` are handled according to ``img_mode``.
    Indexing the dataset yields a ``(features, labels)`` pair of dicts.
    '''

    def __init__(self, dat_file, cnf_file, img_mode=0, arch=None, transforms=None, stripped=None):
        ''' Load the data, reconcile it with the configuration and build per-sample dicts.

        Parameters
        ----------
        dat_file : str, os.PathLike or pandas.DataFrame
            Path to a CSV file, or an already-loaded DataFrame. A DataFrame is
            copied, so the caller's object is never modified.
        cnf_file :
            Anything accepted by ``toml.load``. Must contain ``feature`` and
            ``label`` tables.
        img_mode : int
            Handling of features whose config ``type`` is ``'imaging'``:
            -1  drop them (train a non-imaging model);
             0  keep the path values; paths missing on disk are appended to
                ``notexists.txt``;
             1  replace each path with the array ``np.load``-ed (memory-mapped)
                from it. Unloadable arrays, arrays containing NaN, or arrays
                whose last axis is 9 or 10 become None;
             2  unavailable in this module (raises NotImplementedError).
        arch, transforms, stripped :
            Only relevant to ``img_mode=2``; kept for interface compatibility.

        Raises
        ------
        NotImplementedError
            If ``img_mode == 2`` and an imaging feature is configured.
        ValueError
            If a column listed in ``value_mapping`` holds an unmapped value.
        '''
        # load data csv
        if isinstance(dat_file, (str, os.PathLike)):
            print(dat_file)
            df = pd.read_csv(dat_file)
        else:
            # Copy: the index reset and imaging-column overwrites below would
            # otherwise leak into the caller's DataFrame.
            df = dat_file.copy()

        # load configuration file
        self.cnf = toml.load(cnf_file)

        # NOTE(review): ids are captured before any row is dropped below, so
        # they may not line up with self.df / self.features. The attribute is
        # also only set when an 'ID' column exists. Confirm how callers use it.
        if 'ID' in df.columns:
            self.ids = list(df['ID'])

        df.reset_index(drop=True, inplace=True)

        # check feature availability in data file
        print('Out of {} features in configuration file, '.format(len(self.cnf['feature'])), end='')
        tmp = [fea for fea in self.cnf['feature'] if fea not in df.columns]
        print('{} are unavailable in data file.'.format(tmp))

        # check label availability in data file
        print('Out of {} labels in configuration file, '.format(len(self.cnf['label'])), end='')
        tmp = [lbl for lbl in self.cnf['label'] if lbl not in df.columns]
        print('{} are unavailable in data file.'.format(len(tmp)))

        # restrict the configuration to columns that actually exist
        self.cnf['feature'] = {k: v for k, v in self.cnf['feature'].items() if k in df.columns}
        self.cnf['label'] = {k: v for k, v in self.cnf['label'].items() if k in df.columns}

        features = list(self.cnf['feature'])
        labels = list(self.cnf['label'])

        # -- imaging features -------------------------------------------------
        shapes = []          # raw shapes of accepted embeddings (img_mode 1)
        error_features = []  # imaging features without a usable embedding (img_mode 1)
        img_fea_to_pop = []  # imaging features removed when img_mode == -1
        total = 0
        for fea, fea_cnf in self.cnf['feature'].items():
            if fea_cnf['type'] != 'imaging':
                continue
            print('imaging..')
            if img_mode == -1:
                # to train non imaging model
                img_fea_to_pop.append(fea)
            elif img_mode == 0:
                # keep the paths; log the ones that do not exist on disk
                print("fea: ", fea)
                filenames = df[fea].dropna().to_list()
                with open('notexists.txt', 'a') as f:
                    for fn in filenames:
                        if fn is None:
                            continue
                        if not os.path.exists(fn):
                            f.write(fn + '\n')
                        total += 1
            elif img_mode == 1:
                # replace every path with the embedding array stored at it
                print("fea: ", fea)
                if df[fea].isna().all():
                    continue
                npy = []
                n = 0
                for fn in tqdm(df[fea].to_list()):
                    try:
                        data = np.load(fn, mmap_mode='r')
                        # NOTE(review): last-axis sizes 9/10 are rejected; the
                        # reason is not visible here — presumably a known-bad layout.
                        usable = not np.isnan(data).any() and data.shape[-1] not in (9, 10)
                    except Exception:
                        # best effort: NaN/missing path, unreadable or non-numeric file
                        usable = False
                    if not usable:
                        npy.append(None)
                        continue
                    shapes.append(data.shape)
                    # add a leading axis to arrays with fewer than 5 dims
                    if data.ndim < 5:
                        data = np.expand_dims(data, axis=0)
                    npy.append(data)
                    fea_cnf['shape'] = data.shape
                    fea_cnf['img_shape'] = data.shape
                    n += 1
                # 'shape' may be absent when nothing loaded and the TOML gave none
                shape = fea_cnf.get('shape')
                print(f"{n} MRI embeddings found with shape {shape}")
                if n == 0 or shape is None or len(shape) == 1:
                    error_features.append(fea)
                total += n
                print(len(df), len(npy))
                df[fea] = npy
            elif img_mode == 2:
                # Computing embeddings on the fly needs load_mris.get_emb, whose
                # import is commented out in this module; the former body used
                # undefined names and always failed with NameError.
                raise NotImplementedError(
                    'img_mode=2 needs load_mris.get_emb, which is not imported in this module.')

        print(f"Total mri embeddings found: {total}")

        for fea in img_fea_to_pop:
            self.cnf['feature'].pop(fea)

        df = df.drop(img_fea_to_pop, axis=1)
        features = [fea for fea in features if fea in df.columns]
        labels = [lab for lab in labels if lab in df.columns]

        # keep only configured columns: features first, then labels. The
        # per-sample dict construction below relies on this column order.
        df = df[features + labels]

        # drop rows where ALL features are missing
        df_fea = df[features].dropna(how='all')
        print('Out of {} samples, {} are dropped due to complete feature missing.'.format(len(df), len(df) - len(df_fea)))
        df = df[df.index.isin(df_fea.index)].reset_index(drop=True)

        # drop rows where ALL labels are missing
        df_lbl = df[labels].dropna(how='all')
        print('Out of {} samples, {} are dropped due to complete label missing.'.format(len(df), len(df) - len(df_lbl)))
        df = df[df.index.isin(df_lbl.index)].reset_index(drop=True)

        print(set(shapes))
        print("Error features ")
        print(error_features)

        # map categorical strings to integer codes (missing values stay missing)
        for name in features + labels:
            if name not in value_mapping:
                continue
            mapping = value_mapping[name]
            try:
                df[name] = [mapping[s] if not pd.isnull(s) else None for s in df[name].to_list()]
            except KeyError as err:
                raise ValueError(
                    f"Column '{name}' contains value {err} with no entry in value_mapping "
                    f"(expected one of {list(mapping)})."
                ) from err

        # drop columns that are entirely missing after filtering/mapping
        df = df.dropna(axis=1, how='all')
        features = [fea for fea in features if fea in df.columns]
        labels = [lab for lab in labels if lab in df.columns]

        # change np.nan to None
        df.replace({np.nan: None}, inplace=True)

        self.df = df

        # per-sample dicts; None-valued features are omitted entirely
        n_fea = len(features)
        keys = df.columns.values.tolist()
        self.features, self.labels = [], []
        for i in range(len(df)):
            vals = df.iloc[i].to_list()
            row_fea = dict(zip(keys[:n_fea], vals[:n_fea]))
            self.features.append({k: v for k, v in row_fea.items() if v is not None})
            self.labels.append(dict(zip(keys[n_fea:], vals[n_fea:])))

        # Per-label value frequencies. The denominator is the total number of
        # samples, including rows where the label is missing, so the fractions
        # need not sum to 1.
        self.label_fractions = {}
        for label in labels:
            self.label_fractions[label] = dict(self.df[label].value_counts() / len(self.df))

    def __len__(self):
        ''' Number of samples (rows of ``self.df``). '''
        return len(self.df)

    def __getitem__(self, idx):
        ''' Return the ``(features, labels)`` dict pair for sample ``idx``. '''
        return self.features[idx], self.labels[idx]

    @property
    def feature_modalities(self):
        ''' Feature section of the (filtered) configuration, keyed by column name. '''
        return self.cnf['feature']

    @property
    def label_modalities(self):
        ''' Label section of the (filtered) configuration, keyed by column name. '''
        return self.cnf['label']