Esten Leonardsen committed on
Commit
55880f9
·
1 Parent(s): 539bc34

Finished first version of scripts necessary to finetune models

Browse files
new_packages.txt CHANGED
@@ -13,4 +13,5 @@ tqdm==4.66.4
13
  plotly==5.24.1
14
  pytest==8.3.3
15
  scikit-learn==1.5.1
16
- xlrd==2.0.1
 
 
13
  plotly==5.24.1
14
  pytest==8.3.3
15
  scikit-learn==1.5.1
16
+ xlrd==2.0.1
17
+ pydantic==2.10
pyment/configurations/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .dataset_configuration import DatasetConfiguration
2
+ from .finetuning_configuration import FinetuningConfiguration
pyment/configurations/data_split_configuration.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
from typing import List, Optional

from pydantic import BaseModel


class DataSplitConfiguration(BaseModel):
    """Configuration for splitting a dataset into training and validation.

    Fix: 'stratification' defaulted to None while being annotated as a
    required List[str]; it is now explicitly Optional.
    """

    # Fraction of subjects assigned to the training fold, e.g. 0.8. The
    # number of folds is derived from this downstream.
    training_fraction: float
    # Optional column names used to sort/summarize the split.
    stratification: Optional[List[str]] = None
pyment/configurations/dataset_configuration.py ADDED
@@ -0,0 +1,287 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ import re
6
+ import numpy as np
7
+ import pandas as pd
8
+ from collections import Counter
9
+ from typing import Dict, List, Tuple, Union
10
+
11
+ import nibabel as nib
12
+ from pydantic import model_validator, BaseModel
13
+
14
+ from .data_split_configuration import DataSplitConfiguration
15
+
16
+
17
+ logging.basicConfig(
18
+ format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
19
+ level=logging.INFO
20
+ )
21
+ logger = logging.getLogger(__name__)
22
+
23
+ def _extract_run(filename: str) -> Union[str, None]:
24
+ match = re.fullmatch(r'.*_run-(?P<run>[^_.]*)(?:_.*)?\.mgz', filename)
25
+
26
+ if match:
27
+ return match.group('run')
28
+
29
+ logger.warning('Unable to extract run from filename %s', filename)
30
+
31
+ return None
32
+
33
def _parse_bids_folder(root: str) -> pd.DataFrame:
    """Collect T1 images from a single BIDS folder into a dataframe.

    Walks sub-*/ses-*/anat and returns one row per file containing 'T1'
    in its name, with columns ['subject', 'session', 'run', 'path'].
    Malformed subject/session folder names are skipped with a warning.
    """
    entries = []

    for subject_folder in os.listdir(root):
        subject_match = re.fullmatch(r'sub-(?P<subject>.*)', subject_folder)

        if not subject_match:
            logger.warning(
                'Subject folder %s in %s does not have the expected sub-XXX '
                'format. Skipping', subject_folder, root
            )
            continue

        subject = subject_match.group('subject')

        for session_folder in os.listdir(os.path.join(root, subject_folder)):
            session_match = re.fullmatch(
                r'ses-(?P<session>.*)', session_folder
            )

            if not session_match:
                logger.warning(
                    'Session folder %s in subject %s in folder %s does not '
                    'match the expected ses-XXX format. Skipping',
                    session_folder, subject_folder, root
                )
                continue

            session = session_match.group('session')

            anat_folder = os.path.join(
                root, subject_folder, session_folder, 'anat'
            )

            # BUG FIX: sessions without an anat folder previously made
            # os.listdir raise FileNotFoundError; skip them instead.
            if not os.path.isdir(anat_folder):
                logger.warning(
                    'Missing anat folder %s. Skipping', anat_folder
                )
                continue

            t1s = [
                filename for filename in os.listdir(anat_folder)
                if 'T1' in filename
            ]

            for filename in t1s:
                run = _extract_run(filename)
                entries.append({
                    'subject': subject,
                    'session': session,
                    'run': run,
                    'path': os.path.join(anat_folder, filename)
                })

    return pd.DataFrame(entries, columns=['subject', 'session', 'run', 'path'])
82
+
83
def _parse_bids_folders(folders: List[str]) -> pd.DataFrame:
    """Parse several BIDS folders and concatenate them into one dataframe.

    BUG FIX: reset_index(drop=True) — the original plain reset_index()
    leaked the per-folder indices into a spurious 'index' column.
    """
    df = pd.concat([_parse_bids_folder(folder) for folder in folders])
    df = df.reset_index(drop=True)
    logger.info('Parsed %d images', len(df))

    return df
89
+
90
+ def _parse_fastsurfer_name(name: str) -> Tuple[str, str, str]:
91
+ match = re.fullmatch(r'sub-(.*)_ses-(.*)_run-(.*)(?:T1w?)?', name)
92
+
93
+ if not match:
94
+ raise ValueError(
95
+ 'Unable to extract subject, session, run from folder %s', name
96
+ )
97
+
98
+ return match.groups()
99
+
100
def _parse_fastsurfer_folder(folder: str):
    """Collect skull-stripped images from a FastSurfer output folder.

    For every subject subfolder this ensures mri/brainmask.mgz exists —
    computing and caching it as orig * mask via nibabel when missing —
    and returns a dataframe with columns
    ['subject', 'session', 'run', 'path'].

    Raises ValueError (from _parse_fastsurfer_name) on unexpected folder
    names; subfolders missing orig or mask images are skipped with an
    error log.
    """
    entries = []

    for subfolder in os.listdir(folder):
        # Folder names encode identity, e.g. 'sub-01_ses-1_run-1'.
        subject, session, run = _parse_fastsurfer_name(subfolder)

        mri_folder = os.path.join(folder, subfolder, 'mri')
        brainmask = os.path.join(mri_folder, 'brainmask.mgz')

        if not os.path.isfile(brainmask):
            logger.info('Brainmask does not exist in folder %s', subfolder)

            orig = os.path.join(mri_folder, 'orig.mgz')
            mask = os.path.join(mri_folder, 'mask.mgz')

            if not os.path.isfile(orig):
                logger.error('Orig does not exist in folder %s', subfolder)
                continue
            elif not os.path.isfile(mask):
                logger.error('Mask does not exist in folder %s', subfolder)
                continue

            # Apply the brain mask voxel-wise to the original image and
            # persist the result so subsequent runs reuse it.
            orig_data = nib.load(orig)
            mask_data = nib.load(mask)
            brainmask_data = nib.Nifti1Image(
                orig_data.get_fdata() * mask_data.get_fdata(),
                header=orig_data.header,
                affine=orig_data.affine
            )

            nib.save(brainmask_data, brainmask)

        entries.append({
            'subject': subject,
            'session': session,
            'run': run,
            'path': brainmask
        })

    return pd.DataFrame(entries, columns=['subject', 'session', 'run', 'path'])
140
+
141
def _parse_fastsurfer_folders(folders: List[str]) -> pd.DataFrame:
    """Parse several FastSurfer folders into one concatenated dataframe.

    BUG FIX: reset_index(drop=True) — the original plain reset_index()
    leaked the per-folder indices into a spurious 'index' column.
    """
    df = pd.concat([_parse_fastsurfer_folder(folder) for folder in folders])
    df = df.reset_index(drop=True)
    logger.info('Parsed %d images', len(df))

    return df
147
+
148
def _summarize_values(values: np.ndarray, name: str):
    """Log a compact summary of a column of values.

    Non-numeric data is reported as category counts, binary {0, 1} data
    as counts with a NA tally, and other numeric data as mean +/- std.
    """
    if not np.issubdtype(values.dtype, np.number):
        # Categorical data: report the frequency of each value.
        logger.info('%s: %s', name, Counter(values))
        return

    finite = values[~np.isnan(values)]
    num_nans = len(np.where(np.isnan(values))[0])

    if np.array_equal(np.unique(finite), np.asarray([0, 1])):
        # Binary variable: counts of 0/1 plus the number of missing values.
        logger.info('%s: %s (%d NAs)', name, Counter(finite), num_nans)
    else:
        # Continuous variable: mean +/- standard deviation, ignoring NaNs.
        mean = np.round(np.nanmean(values), 2)
        std = np.round(np.nanstd(values), 2)
        logger.info('%s: %.2f+/-%.2f (%d NAs)', name, mean, std, num_nans)
164
+
165
def _summarize(df: pd.DataFrame, variables: List[str], name: str):
    """Log the size of *df* followed by a per-variable value summary."""
    logger.info('%s n=%d', name, len(df))

    for variable in variables:
        column_values = df[variable].values
        _summarize_values(column_values, name=variable)
170
+
171
+
172
def _split_training_validation_fold(
    df: pd.DataFrame,
    labels: str,
    training_fraction: float,
    target: str = None,
    stratification: List[str] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Merge image rows with labels and split subjects into folds.

    Subjects (not individual images) are assigned round-robin to
    int(1 / (1 - training_fraction)) folds; one fold becomes the
    validation set, so all images of a subject land on the same side.

    Parameters
    ----------
    df : dataframe with 'subject', 'session', 'run', 'path' columns.
    labels : path to a CSV with one row per (subject, session).
    training_fraction : fraction of subjects used for training, in (0, 1).
    target : optional label column to load alongside the keys.
    stratification : optional columns used to sort before assignment and
        to summarize the resulting folds.

    Returns
    -------
    (training, validation) dataframes.

    Raises
    ------
    ValueError
        On duplicate (subject, session) label rows, an unusable
        training_fraction, or overlapping folds.
    """
    columns = {'subject', 'session', 'run'}

    if target:
        columns.add(target)

    if stratification:
        columns |= set(stratification)

    labels = pd.read_csv(
        labels,
        usecols=list(columns),
        dtype={'subject': object, 'session': object, 'run': object},
    )

    logger.info('Parsed %d labels', len(labels))

    if len(labels) != len(labels.drop_duplicates(['subject', 'session'])):
        raise ValueError(
            'There are duplicates (subject, session)-pairs in the labels file'
        )

    df = pd.merge(
        df, labels,
        how='inner',
        on=['subject', 'session']
    )

    logger.info('Merged %d data points', len(df))

    if stratification is not None:
        # Sorting first makes the round-robin assignment approximately
        # stratified on these columns.
        df = df.sort_values(stratification)

    # .copy() so the 'fold' assignment below does not hit a view of df.
    subjects = df.drop_duplicates('subject').copy()

    # BUG FIX: guard against training_fraction >= 1, which previously
    # raised ZeroDivisionError (or nonsense folds) below.
    if not 0.0 < training_fraction < 1.0:
        raise ValueError(
            f'Training fraction {training_fraction:.2f} must be in (0, 1)'
        )

    num_folds = int(1.0 / (1 - training_fraction))

    if num_folds == 1:
        # BUG FIX: the original passed printf-style args to ValueError,
        # which are never formatted.
        raise ValueError(
            f'Training fraction {training_fraction:.2f} yields a single fold'
        )

    # BUG FIX: the original used np.arange(len(df)), which crashes with a
    # length mismatch whenever a subject has more than one image.
    subjects['fold'] = np.arange(len(subjects)) % num_folds
    df['fold'] = df['subject'].map(subjects.set_index('subject')['fold'])

    validation_fold = num_folds // 2
    training = df[df['fold'] != validation_fold]
    validation = df[df['fold'] == validation_fold]

    if set(training['subject'].values) & set(validation['subject'].values):
        raise ValueError('Overlap between training and validation folds')

    if stratification:
        # Renamed loop variable: the original shadowed 'df' here.
        for fold_name, fold_df in [
            ('Training', training), ('Validation', validation)
        ]:
            _summarize(fold_df, variables=stratification, name=fold_name)

    return training, validation
239
+
240
class DatasetConfiguration(BaseModel):
    """Configuration describing where a dataset lives and how to split it.

    Exactly one of 'bids' or 'fastsurfer' must be set; this is enforced
    by the model validator below.
    """

    # Spatial shape images are expected in / cropped to.
    input_shape: Tuple[int, int, int]
    # Root folders of raw BIDS datasets.
    bids: List[str] | None = None
    # Root folders of FastSurfer output datasets.
    fastsurfer: List[str] | None = None
    # Path to a CSV file with per-(subject, session) label rows.
    labels: str
    # Optional train/validation split settings. Fix: previously annotated
    # as a required DataSplitConfiguration while defaulting to None.
    split: DataSplitConfiguration | None = None

    @model_validator(mode='after')
    def check_fastsurfer_or_bids(self):
        """Ensure exactly one of 'bids'/'fastsurfer' is configured."""
        if self.bids is not None and self.fastsurfer is not None:
            raise ValueError(
                'Either \'bids\' or \'fastsurfer\'-property must be set, not '
                'both'
            )
        elif self.bids is None and self.fastsurfer is None:
            raise ValueError(
                'Either \'bids or \'fastsurfer\'-property must be set'
            )

        return self

    @staticmethod
    def parse(
        configuration: DatasetConfiguration,
        target: str = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Parse the configured folders into (training, validation) frames.

        Fix: the return annotation previously claimed Dict[str,
        pd.DataFrame] although a (training, validation) tuple is returned
        and unpacked by callers.
        """
        # Guard clause instead of nesting everything under 'if split'.
        if not configuration.split:
            raise NotImplementedError(
                'Not sure how to parse dataset without a split configuration'
            )

        if configuration.bids:
            df = _parse_bids_folders(configuration.bids)
        elif configuration.fastsurfer:
            df = _parse_fastsurfer_folders(configuration.fastsurfer)
        else:
            raise ValueError(
                'Unable to parse DatasetConfiguration without either '
                '\'bids\' or \'fastsurfer\' set'
            )

        return _split_training_validation_fold(
            df=df,
            labels=configuration.labels,
            training_fraction=configuration.split.training_fraction,
            target=target,
            stratification=configuration.split.stratification
        )
pyment/configurations/finetuning_configuration.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pydantic import BaseModel
2
+
3
+ from .dataset_configuration import DatasetConfiguration
4
+ from .model_configuration import ModelConfiguration
5
+ from .training_configuration import TrainingConfiguration
6
+
7
+
8
class FinetuningConfiguration(BaseModel):
    """Top-level configuration for one finetuning run.

    Aggregates the model, dataset, and training sub-configurations,
    typically validated from a single JSON document.
    """

    # Architecture and pretrained-weights settings.
    model: ModelConfiguration
    # Data sources, labels, and train/validation split.
    data: DatasetConfiguration
    # Optimization settings (loss, optimizer, schedule, epochs, ...).
    training: TrainingConfiguration
pyment/configurations/learning_rate_schedule_configuration.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import abstractmethod
2
+ from typing import Annotated, Literal, Union
3
+
4
+ from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau
5
+ from pydantic import BaseModel, ConfigDict, Field
6
+
7
+
8
class LearningRateScheduleBaseConfiguration(BaseModel):
    """Base class for learning-rate schedule configurations.

    NOTE(review): @abstractmethod is not enforced here because pydantic's
    BaseModel does not use ABCMeta, so a subclass omitting instantiate()
    will not fail at class creation — confirm this is acceptable.
    """

    # Reject unknown keys so typos in configuration files fail loudly.
    model_config = ConfigDict(extra='forbid')

    @abstractmethod
    def instantiate(self) -> Callback:
        """Build and return the Keras callback implementing the schedule."""
        pass
14
+
15
class AnnealingLearningRateScheduleConfiguration(
    LearningRateScheduleBaseConfiguration
):
    """Plateau-based annealing: reduce the learning rate by *factor* after
    *patience* epochs without improvement, down to a floor."""

    kind: Literal['annealing']
    factor: float
    patience: int
    minimum_learning_rate: float

    def instantiate(self) -> Callback:
        """Create the ReduceLROnPlateau callback for this configuration."""
        callback = ReduceLROnPlateau(
            factor=self.factor,
            patience=self.patience,
            min_lr=self.minimum_learning_rate,
            verbose=True
        )

        return callback
30
+
31
class StepWiseLearningRateScheduleConfiguration(
    LearningRateScheduleBaseConfiguration
):
    """Placeholder configuration for a step-wise learning-rate schedule.

    The schedule itself is not implemented yet.
    """

    kind: Literal['stepwise']

    def instantiate(self) -> Callback:
        # BUG FIX: the original returned the undefined, truncated name
        # 'ReduceLROnPl', which raised NameError at runtime. Fail with an
        # explicit error until the step-wise schedule is implemented.
        raise NotImplementedError(
            'Step-wise learning rate schedules are not implemented yet'
        )
38
+
39
# Discriminated union: pydantic picks the concrete schedule configuration
# class based on the value of the 'kind' field in the document.
LearningRateScheduleConfiguration = Annotated[
    Union[
        AnnealingLearningRateScheduleConfiguration,
        StepWiseLearningRateScheduleConfiguration
    ],
    Field(discriminator='kind')
]
pyment/configurations/model_configuration.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field


class ModelConfiguration(BaseModel):
    """Configuration of the model architecture and initial weights."""

    # Model-type identifier understood by sfcn_factory, e.g. 'sfcn-reg'.
    type: str
    # Keyword arguments forwarded to the model constructor.
    hyperparameters: Dict[str, Any] = Field(default_factory=dict)
    # Optional path to pretrained weights. Fix: previously annotated as a
    # required str while defaulting to None.
    weights: Optional[str] = None
pyment/configurations/training_configuration.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from typing import List, Optional

from pydantic import BaseModel

from .learning_rate_schedule_configuration import (
    LearningRateScheduleConfiguration
)


class TrainingConfiguration(BaseModel):
    """Configuration of the optimization loop for finetuning.

    Fix: 'metrics', 'learning_rate_schedule', and 'destination' defaulted
    to None while being annotated as required; they are now Optional.
    """

    # Name of the label column to predict.
    target: str
    # Loss identifier understood by loss_factory, e.g. 'mse'.
    loss: str
    # Optional metric identifiers understood by metric_factory.
    metrics: Optional[List[str]] = None
    # Optimizer identifier understood by optimizer_factory, e.g. 'adam'.
    optimizer: str
    learning_rate: float
    # Optional schedule; None keeps the learning rate constant.
    learning_rate_schedule: Optional[LearningRateScheduleConfiguration] = None
    batch_size: int
    epochs: int
    # Output folder for checkpoints and training history.
    destination: Optional[str] = None
pyment/factories/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .loss_factory import loss_factory
2
+ from .metric_factory import metric_factory
3
+ from .optimizer_factory import optimizer_factory
pyment/factories/loss_factory.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ import tensorflow as tf
4
+
5
+
6
def loss_factory(name: str) -> type:
    """Look up a Keras loss class by (case-insensitive) name.

    Note: this returns the loss *class* — callers instantiate it — so the
    annotation is `type` rather than a callable of tensors.

    Raises
    ------
    KeyError
        If the name is not a known loss identifier.
    """
    if name.lower() == 'mse':
        return tf.keras.losses.MeanSquaredError

    raise KeyError(f'Unknown loss {name}')
pyment/factories/metric_factory.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable
2
+
3
+ import tensorflow as tf
4
+
5
+
6
def metric_factory(name: str) -> tf.keras.metrics.Metric:
    """Instantiate a Keras metric by (case-insensitive) name.

    Unlike loss_factory/optimizer_factory this returns an *instance*,
    matching how callers use the result directly in model.compile.

    Raises
    ------
    KeyError
        If the name is unknown. BUG FIX: the original silently returned
        None for unknown names, deferring the failure to model.compile.
    """
    if name.lower() == 'mae':
        return tf.keras.metrics.MeanAbsoluteError()

    raise KeyError(f'Unknown metric {name}')
pyment/factories/optimizer_factory.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+
3
+
4
def optimizer_factory(name: str) -> tf.optimizers.Optimizer:
    """Look up an optimizer class by (case-insensitive) name.

    Returns the optimizer *class*; callers instantiate it with a
    learning rate.

    Raises
    ------
    KeyError
        If the name is not a known optimizer identifier.
    """
    normalized = name.lower()

    if normalized == 'adam':
        return tf.optimizers.Adam

    raise KeyError(f'Unknown optimizer {name}')
pyment/models/sfcn/__init__.py CHANGED
@@ -1,4 +1,12 @@
1
  from .sfcn import SFCN
2
  from .sfcn_multi import MultiTaskSFCN
 
3
 
4
- __all__ = ['SFCN', 'MultiTaskSFCN']
 
 
 
 
 
 
 
 
1
  from .sfcn import SFCN
2
  from .sfcn_multi import MultiTaskSFCN
3
+ from .sfcn_reg import RegressionSFCN
4
 
5
+
6
def sfcn_factory(model_type: str):
    """Map a model-type identifier to the corresponding SFCN class.

    Raises
    ------
    ValueError
        If the identifier is unknown.
    """
    if model_type in ['sfcn-reg', 'regression']:
        return RegressionSFCN

    # BUG FIX: the prediction scripts default to model_name='sfcn-multi',
    # which the original factory rejected with ValueError.
    if model_type in ['sfcn-multi', 'multi']:
        return MultiTaskSFCN

    if model_type == 'sfcn':
        return SFCN

    raise ValueError(f'Unknown SFCN type {model_type}')
11
+
12
+ __all__ = ['sfcn_factory', 'SFCN', 'MultiTaskSFCN', 'RegressionSFCN']
pyment/models/sfcn/sfcn.py CHANGED
@@ -82,6 +82,8 @@ class SFCN(Model):
82
  weights = ensure_weights(weights)
83
  status = self.load_weights(weights)
84
 
85
- # Silences warnings about optimizer-status not being loaded
86
- status.expect_partial()
87
- status.assert_existing_objects_matched()
 
 
 
82
  weights = ensure_weights(weights)
83
  status = self.load_weights(weights)
84
 
85
+ print(weights)
86
+ if not weights.endswith('hdf5'):
87
+ # Silences warnings about optimizer-status not being loaded
88
+ status.expect_partial()
89
+ status.assert_existing_objects_matched()
pyment/models/sfcn/sfcn_multi.py CHANGED
@@ -7,8 +7,8 @@ from .sfcn import SFCN
7
  class MultiTaskSFCN(SFCN):
8
  @classmethod
9
  def construct_prediction_head(
10
- cls,
11
- bottleneck: Tensor,
12
  name: str
13
  ) -> Tensor:
14
  x = bottleneck
 
7
  class MultiTaskSFCN(SFCN):
8
  @classmethod
9
  def construct_prediction_head(
10
+ cls,
11
+ bottleneck: Tensor,
12
  name: str
13
  ) -> Tensor:
14
  x = bottleneck
pyment/models/sfcn/sfcn_reg.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from tensorflow import Tensor
2
+ from tensorflow.keras.layers import Dense
3
+
4
+ from .sfcn import SFCN
5
+
6
+
7
+ class RegressionSFCN(SFCN):
8
+ @classmethod
9
+ def construct_prediction_head(
10
+ cls,
11
+ bottleneck: Tensor,
12
+ name: str
13
+ ) -> Tensor:
14
+ layer = Dense(1, activation=None, name=f'{name}/predictions')
15
+
16
+ return layer(bottleneck)
pyment/models/utils/ensure_weights.py CHANGED
@@ -21,12 +21,17 @@ def ensure_weights(identifier: str) -> str:
21
  ------
22
  KeyError
23
  If the identifier is not a valid identifier and there does not
24
- exist files <identifier>.index and
25
- <identifier>.data-00000-of-00001 on the local file system.
 
26
  """
27
  if not (
28
- os.path.isfile(f'{identifier}.index') and
29
- os.path.isfile(f'{identifier}.data-00000-of-00001')
 
 
 
 
30
  ):
31
  raise NotImplementedError(
32
  f'Identifier-based lookups are not supported'
 
21
  ------
22
  KeyError
23
  If the identifier is not a valid identifier and there does not
24
+ exist either a single file <identifier> or files
25
+ <identifier>.index and <identifier>.data-00000-of-00001 on the
26
+ local file system.
27
  """
28
  if not (
29
+ (
30
+ os.path.isfile(f'{identifier}.index') and
31
+ os.path.isfile(f'{identifier}.data-00000-of-00001')
32
+ ) or (
33
+ os.path.isfile(identifier)
34
+ )
35
  ):
36
  raise NotImplementedError(
37
  f'Identifier-based lookups are not supported'
pyment/models/utils/load_select_pretrained_weights.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+
3
+ import tensorflow as tf
4
+
5
+ from ..sfcn import MultiTaskSFCN
6
+
7
+
8
+ logging.basicConfig(
9
+ format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
10
+ level=logging.DEBUG
11
+ )
12
+ logger = logging.getLogger(__name__)
13
+
14
def load_select_pretrained_weights(
    model: tf.keras.Model,
    weights: str,
    target: str = None
) -> tf.keras.Model:
    """Copy pretrained multi-task SFCN weights into *model* in place.

    Restores a MultiTaskSFCN backbone from the checkpoint at *weights*
    and copies its convolution/normalization layer weights — and, for the
    known targets 'age' and 'sex', one dense prediction head — into
    *model* by layer index. Returns the same *model* object.

    NOTE(review): the hard-coded indices below assume both models share
    the exact MultiTaskSFCN layer ordering built with input_shape
    (224, 192, 224) and max pooling — confirm if the architecture changes.
    """
    logger.info('Loading pretrained weights from %s', weights)

    backbone = MultiTaskSFCN(input_shape=(224, 192, 224), pooling='max')
    checkpoint = tf.train.Checkpoint(backbone)

    # expect_partial() silences warnings about checkpoint contents (e.g.
    # optimizer slots) that are deliberately not restored here.
    checkpoint.restore(weights).expect_partial()

    # Indices of the convolutional and normalization layers shared
    # between the backbone and the target model.
    conv_layers = [2, 6, 10, 14, 18, 22]
    norm_layers = [3, 7, 11, 15, 19, 23]

    for idx in conv_layers + norm_layers:
        model.layers[idx].set_weights(backbone.layers[idx].get_weights())

    # Loading weights from the specific dense-layer corresponding to the
    # given prediction-task in the multi-task model
    if target == 'age':
        logger.info('Loaded age weights for the prediction head')
        model.layers[27].set_weights(backbone.layers[27].get_weights())
    elif target == 'sex':
        logger.info('Loaded sex weights for the prediction head')
        # The sex head sits at index 28 in the multi-task backbone but at
        # 27 in the single-task target model.
        model.layers[27].set_weights(backbone.layers[28].get_weights())
    else:
        logger.warning(
            'Unknown target %s. Not loading weights for prediction layer',
            target
        )

    return model
pyment/utils/json_serialize.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from typing import Any
3
+
4
def json_serialize(obj: Any) -> Any:
    """Recursively convert *obj* into JSON-serializable built-in types.

    Handles nested dicts, lists and tuples, plus the common numpy scalar
    and array types; anything else is returned unchanged.

    Fixes/generalizations: np.bool_ is converted to bool (json.dump
    cannot serialize it) and tuples are converted like lists.
    """
    if isinstance(obj, dict):
        return {json_serialize(k): json_serialize(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        return [json_serialize(v) for v in obj]
    elif isinstance(obj, np.bool_):
        return bool(obj)
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj
scripts/finetune_from_bids_folder.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import pandas as pd
6
+ from typing import Any, Callable, Dict, List, Tuple
7
+
8
+ import tensorflow as tf
9
+ from tensorflow_neuroimaging.preprocessing import center_crop_or_pad
10
+ from tensorflow_neuroimaging.loaders.mgh import load_mgh
11
+
12
+ from pyment.configurations import DatasetConfiguration, FinetuningConfiguration
13
+ from pyment.factories import loss_factory, metric_factory, optimizer_factory
14
+ from pyment.models.sfcn import sfcn_factory
15
+ from pyment.models.utils.load_select_pretrained_weights import (
16
+ load_select_pretrained_weights
17
+ )
18
+ from pyment.utils.json_serialize import json_serialize
19
+
20
+
21
+ logging.basicConfig(
22
+ format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
23
+ level=logging.DEBUG
24
+ )
25
+ logger = logging.getLogger(__name__)
26
+
27
def _create_tensorflow_dataset(
    df: pd.DataFrame, *,
    target: str,
    input_shape: Tuple[int, int, int],
    batch_size: int,
    shuffle: bool = False
) -> tf.data.Dataset:
    """Build a batched tf.data pipeline of (image, label) pairs.

    Parameters
    ----------
    df : dataframe with a 'path' column (MGH image paths) and the
        *target* label column.
    target : name of the label column.
    input_shape : spatial shape images are center-cropped/padded to.
    batch_size : number of samples per batch. (Fix: was annotated str.)
    shuffle : if True, additionally shuffle with a rolling buffer of
        5 * batch_size elements each epoch.
    """
    input_shape = tf.constant(input_shape)

    # NOTE(review): this permutes the rows even when shuffle=False, so
    # validation data is in random (per-call fixed) order — confirm that
    # this is intended.
    df = df.copy()
    df = df.sample(frac=1.)

    dataset = tf.data.Dataset.from_tensor_slices((df['path'], df[target]))

    if shuffle:
        dataset = dataset.shuffle(buffer_size=5 * batch_size)

    # Load images lazily and in parallel, then crop/pad to a fixed shape.
    dataset = dataset.map(
        lambda path, label: (load_mgh(path), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.map(
        lambda image, label: (center_crop_or_pad(image, input_shape), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
57
+
58
def _create_checkpointing_callback(
    destination: str,
    metrics: List[tf.keras.metrics.Metric] = None
):
    """Create a ModelCheckpoint callback writing into *destination*.

    The checkpoint filename template encodes epoch, loss, and every
    metric (training and validation values) so checkpoints are
    self-describing. Only the best model by val_loss is kept, weights
    only, in HDF5 format.

    Raises
    ------
    FileExistsError
        If *destination* already exists.
    """
    # BUG FIX: os.mkdir fails when intermediate directories are missing;
    # makedirs creates the full path (and still fails if it exists).
    os.makedirs(destination)

    train_metrics = []
    val_metrics = []

    if metrics is not None:
        for metric in metrics:
            # Hyphenate display names; the braces survive into Keras'
            # epoch-end filename formatting.
            name = metric.name.replace('_', '-')
            train_metrics.append(f'{name}={{{metric.name}:.2f}}')
            val_metrics.append(f'val-{name}={{val_{metric.name}:.2f}}')

    terms = [
        'epoch={epoch:03d}',
        'loss={loss:.2f}'
    ] + train_metrics + [
        'val-loss={val_loss:.2f}'
    ] + val_metrics
    filename = '_'.join(terms) + '.hdf5'
    filepath = os.path.join(destination, filename)

    return tf.keras.callbacks.ModelCheckpoint(
        filepath,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True
    )
88
+
89
def finetune(
    model_type: str,
    model_constructor_arguments: Dict[str, Any],
    weights: str,
    input_shape: Tuple[int, int, int],
    target: str,
    loss: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
    metrics: List[tf.keras.metrics.Metric],
    optimizer: tf.optimizers.Optimizer,
    learning_rate_scheduler: tf.keras.callbacks.Callback,
    training: pd.DataFrame,
    validation: pd.DataFrame,
    batch_size: int,
    epochs: int,
    destination: str
):
    """Finetune a pretrained SFCN and write checkpoints and history.

    Builds the model via sfcn_factory, loads selected pretrained weights,
    trains on tf.data pipelines built from the given dataframes, and
    writes checkpoints plus a history.json into *destination*.

    Raises
    ------
    ValueError
        If *destination* is None or already exists.
    """
    # BUG FIX: the original only guarded folder creation against
    # destination=None but then used it unconditionally for checkpoints
    # and history, crashing with TypeError. Fail fast with a clear error.
    if destination is None:
        raise ValueError('finetune requires a destination folder')

    if os.path.isdir(destination):
        raise ValueError(f'Destination {destination} already exists')

    logger.info('Creating destination folder %s', destination)
    os.mkdir(destination)

    model_class = sfcn_factory(model_type)
    model = model_class(
        input_shape=input_shape,
        **model_constructor_arguments
    )
    load_select_pretrained_weights(model, weights, target=target)

    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

    training_dataset = _create_tensorflow_dataset(
        training,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=True
    )
    validation_dataset = _create_tensorflow_dataset(
        validation,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=False
    )

    callbacks = [
        _create_checkpointing_callback(
            os.path.join(destination, 'checkpoints'),
            metrics=metrics
        )
    ]

    # BUG FIX: learning_rate_scheduler is None when no schedule is
    # configured; the original put None into the callbacks list, which
    # Keras rejects.
    if learning_rate_scheduler is not None:
        callbacks.append(learning_rate_scheduler)

    history = model.fit(
        training_dataset,
        validation_data=validation_dataset,
        epochs=epochs,
        callbacks=callbacks
    )

    # Persist the per-epoch metrics (numpy types converted for json).
    with open(os.path.join(destination, 'history.json'), 'w') as f:
        json.dump(json_serialize(history.history), f)
153
+
154
def finetune_from_configuration(configuration: str):
    """Run finetuning as described by a configuration JSON file.

    Parameters
    ----------
    configuration : path to a JSON file matching FinetuningConfiguration.
    """
    with open(configuration, 'r') as f:
        configuration = json.load(f)

    # Validate the raw dict into a typed configuration object.
    configuration = FinetuningConfiguration.model_validate(configuration)

    # Resolve image folders and split into training/validation frames.
    training, validation = DatasetConfiguration.parse(
        configuration.data,
        target=configuration.training.target
    )

    # strategy = tf.distribute.MirroredStrategy()

    # with strategy.scope():

    # Factories translate configuration strings into Keras objects. The
    # loss and optimizer factories return classes to instantiate here.
    loss_cls = loss_factory(configuration.training.loss)
    loss = loss_cls()

    optimizer_cls = optimizer_factory(configuration.training.optimizer)
    optimizer = optimizer_cls(configuration.training.learning_rate)

    metrics = None

    if configuration.training.metrics is not None:
        metrics = [
            metric_factory(metric)
            for metric in configuration.training.metrics
        ]

    learning_rate_scheduler = None

    if configuration.training.learning_rate_schedule:
        learning_rate_scheduler = (
            configuration.training.learning_rate_schedule.instantiate()
        )

    finetune(
        model_type=configuration.model.type,
        model_constructor_arguments=configuration.model.hyperparameters,
        weights=configuration.model.weights,
        input_shape=configuration.data.input_shape,
        target=configuration.training.target,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        learning_rate_scheduler=learning_rate_scheduler,
        training=training,
        validation=validation,
        batch_size=configuration.training.batch_size,
        epochs=configuration.training.epochs,
        destination=configuration.training.destination
    )
206
+
207
if __name__ == '__main__':
    # BUG FIX: the first positional argument of ArgumentParser is 'prog'
    # (the program name), not the help text; pass description= explicitly
    # so --help reads correctly.
    parser = argparse.ArgumentParser(
        description=(
            'Finetunes a multi-task SFCN according to the given configuration'
        )
    )

    parser.add_argument('configuration', help='Path to configuration JSON')

    args = parser.parse_args()

    finetune_from_configuration(args.configuration)
scripts/finetune_from_fastsurfer_folder.py ADDED
File without changes
scripts/predict_from_bids_folder.py CHANGED
@@ -4,11 +4,12 @@ import os
4
  import re
5
  import numpy as np
6
  import pandas as pd
 
7
  from tqdm import tqdm
8
 
9
  import nibabel as nib
10
 
11
- from pyment.models import MultiTaskSFCN
12
  from pyment.preprocessing.conform import conform
13
 
14
 
@@ -29,7 +30,11 @@ def _extract_run(filename: str) -> str:
29
 
30
  def predict_from_bids_folder(
31
  source: str,
32
- weights: str,
 
 
 
 
33
  destination: str = None,
34
  per_image_normalization: bool = False
35
  ) -> pd.DataFrame:
@@ -37,7 +42,8 @@ def predict_from_bids_folder(
37
  raise ValueError(f'Destination {destination} already exists')
38
 
39
  logger.info('Loading multi-task model with weights %s', weights)
40
- model = MultiTaskSFCN(weights=weights)
 
41
 
42
  results = []
43
 
@@ -73,16 +79,13 @@ def predict_from_bids_folder(
73
  )
74
 
75
  results.append({
76
- 'source': os.path.join(anat_folder, filename),
77
- 'subject': subject,
78
- 'session': session,
79
- 'run': run,
80
- 'age': predictions[0],
81
- 'sex': predictions[1],
82
- 'handedness': predictions[2],
83
- 'bmi': predictions[3],
84
- 'fluid_intelligence': predictions[4],
85
- 'neuroticism': predictions[5]
86
  })
87
 
88
  results = pd.DataFrame(results)
@@ -108,6 +111,24 @@ if __name__ == '__main__':
108
  'exist files named <path>.index and <path>.data-00000-of-00001'
109
  )
110
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  parser.add_argument(
112
  '-d', '--destination',
113
  required=False,
@@ -128,6 +149,8 @@ if __name__ == '__main__':
128
  predict_from_bids_folder(
129
  source=args.bids,
130
  weights=args.weights,
 
 
131
  destination=args.destination,
132
  per_image_normalization=args.per_image_normalization
133
  )
 
4
  import re
5
  import numpy as np
6
  import pandas as pd
7
+ from typing import List
8
  from tqdm import tqdm
9
 
10
  import nibabel as nib
11
 
12
+ from pyment.models.sfcn import sfcn_factory
13
  from pyment.preprocessing.conform import conform
14
 
15
 
 
30
 
31
  def predict_from_bids_folder(
32
  source: str,
33
+ weights: str,
34
+ model_name: str = 'sfcn-multi',
35
+ targets: List[str] = [
36
+ 'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
37
+ ],
38
  destination: str = None,
39
  per_image_normalization: bool = False
40
  ) -> pd.DataFrame:
 
42
  raise ValueError(f'Destination {destination} already exists')
43
 
44
  logger.info('Loading multi-task model with weights %s', weights)
45
+ model_class = sfcn_factory(model_name)
46
+ model = model_class(weights=weights)
47
 
48
  results = []
49
 
 
79
  )
80
 
81
  results.append({
82
+ **{
83
+ 'source': path,
84
+ 'subject': subject,
85
+ 'session': session,
86
+ 'run': run
87
+ },
88
+ **{targets[i]: predictions[i] for i in range(len(targets))}
 
 
 
89
  })
90
 
91
  results = pd.DataFrame(results)
 
111
  'exist files named <path>.index and <path>.data-00000-of-00001'
112
  )
113
  )
114
+ parser.add_argument(
115
+ '-m', '--model',
116
+ required=False,
117
+ default='sfcn-multi',
118
+ help=(
119
+ 'Name of the model to use'
120
+ )
121
+ )
122
+ parser.add_argument(
123
+ '-t', '--targets',
124
+ required=False,
125
+ nargs='+',
126
+ default=[
127
+ 'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
128
+ 'neuroticism'
129
+ ],
130
+ help='Name to use for each of the prediction heads in the output CSV'
131
+ )
132
  parser.add_argument(
133
  '-d', '--destination',
134
  required=False,
 
149
  predict_from_bids_folder(
150
  source=args.bids,
151
  weights=args.weights,
152
+ model_name=args.model,
153
+ targets=args.targets,
154
  destination=args.destination,
155
  per_image_normalization=args.per_image_normalization
156
  )
scripts/predict_from_fastsurfer_folder.py CHANGED
@@ -5,11 +5,11 @@ import re
5
  import numpy as np
6
  import pandas as pd
7
  from tqdm import tqdm
8
- from typing import Tuple
9
 
10
  import nibabel as nib
11
 
12
- from pyment.models import MultiTaskSFCN
13
  from pyment.preprocessing.conform import conform
14
 
15
 
@@ -29,14 +29,20 @@ def _parse_folder_name(name: str) -> Tuple[str, str, str]:
29
 
30
  def predict_from_fastsurfer_folder(
31
  source: str,
32
- weights: str,
 
 
 
 
33
  destination: str = None
34
  ) -> pd.DataFrame:
35
  if destination is not None and os.path.isfile(destination):
36
  raise ValueError(f'Destination {destination} already exists')
37
 
38
  logger.info('Loading multi-task model with weights %s', weights)
39
- model = MultiTaskSFCN(weights=weights)
 
 
40
 
41
  results = []
42
 
@@ -66,19 +72,18 @@ def predict_from_fastsurfer_folder(
66
  image = conform(image)
67
 
68
  predictions = model.predict(np.expand_dims(image, axis=0))[0]
 
 
69
  logger.debug('Predictions for %s: %s', folder, str(predictions))
70
 
71
  results.append({
72
- 'source': os.path.join(source, folder),
73
- 'subject': subject,
74
- 'session': session,
75
- 'run': run,
76
- 'age': predictions[0],
77
- 'sex': predictions[1],
78
- 'handedness': predictions[2],
79
- 'bmi': predictions[3],
80
- 'fluid_intelligence': predictions[4],
81
- 'neuroticism': predictions[5]
82
  })
83
 
84
  results = pd.DataFrame(results)
@@ -110,6 +115,24 @@ if __name__ == '__main__':
110
  'exist files named <path>.index and <path>.data-00000-of-00001'
111
  )
112
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  parser.add_argument(
114
  '-d', '--destination',
115
  required=False,
@@ -121,7 +144,9 @@ if __name__ == '__main__':
121
 
122
  predict_from_fastsurfer_folder(
123
  source=args.root,
 
124
  weights=args.weights,
 
125
  destination=args.destination
126
  )
127
 
 
5
  import numpy as np
6
  import pandas as pd
7
  from tqdm import tqdm
8
+ from typing import List, Tuple
9
 
10
  import nibabel as nib
11
 
12
+ from pyment.models.sfcn import sfcn_factory
13
  from pyment.preprocessing.conform import conform
14
 
15
 
 
29
 
30
  def predict_from_fastsurfer_folder(
31
  source: str,
32
+ weights: str,
33
+ model_name: str = 'sfcn-multi',
34
+ targets: List[str] = [
35
+ 'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
36
+ ],
37
  destination: str = None
38
  ) -> pd.DataFrame:
39
  if destination is not None and os.path.isfile(destination):
40
  raise ValueError(f'Destination {destination} already exists')
41
 
42
  logger.info('Loading multi-task model with weights %s', weights)
43
+
44
+ model_class = sfcn_factory(model_name)
45
+ model = model_class(weights=weights)
46
 
47
  results = []
48
 
 
72
  image = conform(image)
73
 
74
  predictions = model.predict(np.expand_dims(image, axis=0))[0]
75
+ print(predictions.shape)
76
+ print(predictions)
77
  logger.debug('Predictions for %s: %s', folder, str(predictions))
78
 
79
  results.append({
80
+ **{
81
+ 'source': os.path.join(source, folder),
82
+ 'subject': subject,
83
+ 'session': session,
84
+ 'run': run
85
+ },
86
+ **{targets[i]: predictions[i] for i in range(len(targets))}
 
 
 
87
  })
88
 
89
  results = pd.DataFrame(results)
 
115
  'exist files named <path>.index and <path>.data-00000-of-00001'
116
  )
117
  )
118
+ parser.add_argument(
119
+ '-m', '--model',
120
+ required=False,
121
+ default='sfcn-multi',
122
+ help=(
123
+ 'Name of the model to use'
124
+ )
125
+ )
126
+ parser.add_argument(
127
+ '-t', '--targets',
128
+ required=False,
129
+ nargs='+',
130
+ default=[
131
+ 'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
132
+ 'neuroticism'
133
+ ],
134
+ help='Name to use for each of the prediction heads in the output CSV'
135
+ )
136
  parser.add_argument(
137
  '-d', '--destination',
138
  required=False,
 
144
 
145
  predict_from_fastsurfer_folder(
146
  source=args.root,
147
+ model_name=args.model,
148
  weights=args.weights,
149
+ targets=args.targets,
150
  destination=args.destination
151
  )
152