Esten Leonardsen commited on
Commit ·
55880f9
1
Parent(s): 539bc34
Finished first version of scripts necessary to finetune models
Browse files- new_packages.txt +2 -1
- pyment/configurations/__init__.py +2 -0
- pyment/configurations/data_split_configuration.py +8 -0
- pyment/configurations/dataset_configuration.py +287 -0
- pyment/configurations/finetuning_configuration.py +11 -0
- pyment/configurations/learning_rate_schedule_configuration.py +45 -0
- pyment/configurations/model_configuration.py +9 -0
- pyment/configurations/training_configuration.py +19 -0
- pyment/factories/__init__.py +3 -0
- pyment/factories/loss_factory.py +10 -0
- pyment/factories/metric_factory.py +8 -0
- pyment/factories/optimizer_factory.py +8 -0
- pyment/models/sfcn/__init__.py +9 -1
- pyment/models/sfcn/sfcn.py +5 -3
- pyment/models/sfcn/sfcn_multi.py +2 -2
- pyment/models/sfcn/sfcn_reg.py +16 -0
- pyment/models/utils/ensure_weights.py +9 -4
- pyment/models/utils/load_select_pretrained_weights.py +46 -0
- pyment/utils/json_serialize.py +16 -0
- scripts/finetune_from_bids_folder.py +216 -0
- scripts/finetune_from_fastsurfer_folder.py +0 -0
- scripts/predict_from_bids_folder.py +36 -13
- scripts/predict_from_fastsurfer_folder.py +39 -14
new_packages.txt
CHANGED
|
@@ -13,4 +13,5 @@ tqdm==4.66.4
|
|
| 13 |
plotly==5.24.1
|
| 14 |
pytest==8.3.3
|
| 15 |
scikit-learn==1.5.1
|
| 16 |
-
xlrd==2.0.1
|
|
|
|
|
|
| 13 |
plotly==5.24.1
|
| 14 |
pytest==8.3.3
|
| 15 |
scikit-learn==1.5.1
|
| 16 |
+
xlrd==2.0.1
|
| 17 |
+
pydantic==2.10
|
pyment/configurations/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .dataset_configuration import DatasetConfiguration
|
| 2 |
+
from .finetuning_configuration import FinetuningConfiguration
|
pyment/configurations/data_split_configuration.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class DataSplitConfiguration(BaseModel):
    """Configuration describing how a dataset is split into folds.

    Attributes:
        training_fraction: Fraction of subjects assigned to the training
            split; the remainder forms the validation fold.
        stratification: Optional column names used to sort the data before
            fold assignment, balancing these variables across folds.
    """

    training_fraction: float
    # Explicit '| None': pydantic v2 does not treat a None default as
    # implicitly optional, so the bare 'List[str] = None' annotation would
    # reject stratification=None at validation time
    stratification: List[str] | None = None
|
pyment/configurations/dataset_configuration.py
ADDED
|
@@ -0,0 +1,287 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import re
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
from collections import Counter
|
| 9 |
+
from typing import Dict, List, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import nibabel as nib
|
| 12 |
+
from pydantic import model_validator, BaseModel
|
| 13 |
+
|
| 14 |
+
from .data_split_configuration import DataSplitConfiguration
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logging.basicConfig(
|
| 18 |
+
format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
|
| 19 |
+
level=logging.INFO
|
| 20 |
+
)
|
| 21 |
+
logger = logging.getLogger(__name__)
|
| 22 |
+
|
| 23 |
+
def _extract_run(filename: str) -> Union[str, None]:
|
| 24 |
+
match = re.fullmatch(r'.*_run-(?P<run>[^_.]*)(?:_.*)?\.mgz', filename)
|
| 25 |
+
|
| 26 |
+
if match:
|
| 27 |
+
return match.group('run')
|
| 28 |
+
|
| 29 |
+
logger.warning('Unable to extract run from filename %s', filename)
|
| 30 |
+
|
| 31 |
+
return None
|
| 32 |
+
|
| 33 |
+
def _parse_bids_folder(root: str):
    """Collect all T1-weighted images from a single BIDS root folder.

    Walks root/sub-*/ses-*/anat and returns a DataFrame with one row per
    T1 image (columns: subject, session, run, path). Folders not following
    the BIDS naming scheme are skipped with a warning.
    """
    entries = []

    for subject_folder in os.listdir(root):
        subject_match = re.fullmatch(r'sub-(?P<subject>.*)', subject_folder)

        if not subject_match:
            logger.warning(
                'Subject folder %s in %s does not have the expected sub-XXX '
                'format. Skipping', subject_folder, root
            )
            continue

        subject = subject_match.group('subject')

        for session_folder in os.listdir(os.path.join(root, subject_folder)):
            session_match = re.fullmatch(
                r'ses-(?P<session>.*)', session_folder
            )

            if not session_match:
                logger.warning(
                    'Session folder %s in subject %s in folder %s does not '
                    'match the expected ses-XXX format. Skipping',
                    session_folder, subject_folder, root
                )
                continue

            session = session_match.group('session')

            anat_folder = os.path.join(
                root, subject_folder, session_folder, 'anat'
            )

            # Robustness fix: not every session has anatomical images.
            # Previously a missing anat-folder crashed os.listdir below
            if not os.path.isdir(anat_folder):
                logger.warning(
                    'Session folder %s in subject %s has no anat folder. '
                    'Skipping', session_folder, subject_folder
                )
                continue

            t1s = [
                filename for filename in os.listdir(anat_folder)
                if 'T1' in filename
            ]

            for filename in t1s:
                run = _extract_run(filename)
                entries.append({
                    'subject': subject,
                    'session': session,
                    'run': run,
                    'path': os.path.join(anat_folder, filename)
                })

    return pd.DataFrame(entries, columns=['subject', 'session', 'run', 'path'])
|
| 82 |
+
|
| 83 |
+
def _parse_bids_folders(folders: List[str]):
    """Parse and concatenate several BIDS folders into a single DataFrame."""
    df = pd.concat([_parse_bids_folder(folder) for folder in folders])
    # drop=True: the per-folder indices are meaningless after concatenation,
    # and without it reset_index() leaks an extra 'index' column into the
    # DataFrame (and thus into downstream merges)
    df = df.reset_index(drop=True)
    logger.info('Parsed %d images', len(df))

    return df
|
| 89 |
+
|
| 90 |
+
def _parse_fastsurfer_name(name: str) -> Tuple[str, str, str]:
|
| 91 |
+
match = re.fullmatch(r'sub-(.*)_ses-(.*)_run-(.*)(?:T1w?)?', name)
|
| 92 |
+
|
| 93 |
+
if not match:
|
| 94 |
+
raise ValueError(
|
| 95 |
+
'Unable to extract subject, session, run from folder %s', name
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
return match.groups()
|
| 99 |
+
|
| 100 |
+
def _parse_fastsurfer_folder(folder: str):
    """Collect brainmask images from a FastSurfer output folder.

    Each subfolder is expected to encode one subject/session/run. The
    subfolder's mri/brainmask.mgz is used as the image path; when it is
    missing it is generated on the fly by voxel-wise multiplying orig.mgz
    with mask.mgz. Subfolders missing orig or mask are skipped with an
    error log. Returns a DataFrame with columns subject, session, run, path.
    """
    entries = []

    for subfolder in os.listdir(folder):
        # Raises ValueError on non-conforming names (no skip here, unlike
        # the BIDS parser)
        subject, session, run = _parse_fastsurfer_name(subfolder)

        mri_folder = os.path.join(folder, subfolder, 'mri')
        brainmask = os.path.join(mri_folder, 'brainmask.mgz')

        if not os.path.isfile(brainmask):
            logger.info('Brainmask does not exist in folder %s', subfolder)

            orig = os.path.join(mri_folder, 'orig.mgz')
            mask = os.path.join(mri_folder, 'mask.mgz')

            if not os.path.isfile(orig):
                logger.error('Orig does not exist in folder %s', subfolder)
                continue
            elif not os.path.isfile(mask):
                logger.error('Mask does not exist in folder %s', subfolder)
                continue

            # Apply the binary(?) mask to the original image and cache the
            # result as brainmask.mgz so the work is done only once
            orig_data = nib.load(orig)
            mask_data = nib.load(mask)
            brainmask_data = nib.Nifti1Image(
                orig_data.get_fdata() * mask_data.get_fdata(),
                header=orig_data.header,
                affine=orig_data.affine
            )

            nib.save(brainmask_data, brainmask)

        entries.append({
            'subject': subject,
            'session': session,
            'run': run,
            'path': brainmask
        })

    return pd.DataFrame(entries, columns=['subject', 'session', 'run', 'path'])
|
| 140 |
+
|
| 141 |
+
def _parse_fastsurfer_folders(folders: List[str]):
    """Parse and concatenate several FastSurfer folders into one DataFrame."""
    df = pd.concat([_parse_fastsurfer_folder(folder) for folder in folders])
    # drop=True: avoid leaking the stale per-folder index as an extra
    # 'index' column into downstream merges
    df = df.reset_index(drop=True)
    logger.info('Parsed %d images', len(df))

    return df
|
| 147 |
+
|
| 148 |
+
def _summarize_values(values: np.ndarray, name: str):
    """Log a one-line summary of an array of values.

    Non-numeric data is summarized as a frequency table, binary (0/1)
    data as counts plus the NaN count, and other numeric data as
    mean+/-std with the NaN count.
    """
    if not np.issubdtype(values.dtype, np.number):
        # Non-numeric values cannot be NaN-filtered or averaged
        logger.info('%s: %s', name, Counter(values))
        return

    observed = values[~np.isnan(values)]
    nans = len(np.where(np.isnan(values))[0])

    if np.array_equal(np.unique(observed), np.asarray([0, 1])):
        logger.info('%s: %s (%d NAs)', name, Counter(observed), nans)
    else:
        mean = np.round(np.nanmean(values), 2)
        std = np.round(np.nanstd(values), 2)
        logger.info('%s: %.2f+/-%.2f (%d NAs)', name, mean, std, nans)
|
| 164 |
+
|
| 165 |
+
def _summarize(df: pd.DataFrame, variables: List[str], name: str):
    """Log the size of *df* plus a per-variable summary for *variables*."""
    logger.info('%s n=%d', name, len(df))

    for column in variables:
        _summarize_values(df[column].values, name=column)
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _split_training_validation_fold(
    df: pd.DataFrame,
    labels: str,
    training_fraction: float,
    target: str = None,
    stratification: List[str] = None
) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Merge image paths with labels and split subjects into train/val folds.

    Args:
        df: DataFrame with columns subject, session, run, path.
        labels: Path to a CSV file with one row per (subject, session).
        training_fraction: Fraction of subjects in the training fold.
        target: Optional label column to include in the merge.
        stratification: Optional columns to sort by before fold
            assignment, balancing these variables across folds.

    Returns:
        (training, validation) DataFrames with disjoint subject sets.

    Raises:
        ValueError: on duplicate (subject, session)-pairs in the labels
            file, a training fraction yielding a single fold, or subject
            overlap between the folds.
    """
    columns = set(['subject', 'session', 'run'])

    if target:
        columns.add(target)

    if stratification:
        columns |= set(stratification)

    # Renamed from 'labels' to avoid shadowing the path parameter
    label_df = pd.read_csv(
        labels,
        usecols=list(columns),
        dtype={'subject': object, 'session': object, 'run': object},
    )

    logger.info('Parsed %d labels', len(label_df))

    if not len(label_df) == len(
        label_df.drop_duplicates(['subject', 'session'])
    ):
        raise ValueError(
            f'There are duplicates (subject, session)-pairs in the labels file'
        )

    df = pd.merge(
        df, label_df,
        how='inner',
        left_on=['subject', 'session'],
        right_on=['subject', 'session']
    )

    logger.info('Merged %d data points', len(df))

    if stratification is not None:
        df = df.sort_values(stratification)

    # Folds are assigned per subject (not per image) so repeated scans of
    # the same subject never straddle the train/validation boundary.
    # .copy() avoids pandas' SettingWithCopy ambiguity on the assignment
    subjects = df.drop_duplicates('subject').copy()
    num_folds = int(1.0 / (1 - training_fraction))

    if num_folds == 1:
        # Bug fix: ValueError does not interpolate printf-style arguments
        raise ValueError(
            f'Training fraction {training_fraction:.2f} yields a single fold'
        )

    # Bug fix: the fold vector must match the number of unique subjects,
    # not the number of rows — np.arange(len(df)) raised a length-mismatch
    # error whenever any subject had more than one image
    subjects['fold'] = np.arange(len(subjects)) % num_folds
    folds = {row['subject']: row['fold'] for _, row in subjects.iterrows()}
    df['fold'] = df['subject'].map(folds)

    validation_fold = num_folds // 2
    training = df[df['fold'] != validation_fold]
    validation = df[df['fold'] == validation_fold]

    if len(
        set(training['subject'].values) & set(validation['subject'].values)
    ) > 0:
        raise ValueError('Overlap between training and validation folds')

    if stratification:
        # Separate loop variable: the original reused 'df' here
        for name, fold_df in [('Training', training), ('Validation', validation)]:
            _summarize(fold_df, variables=stratification, name=name)

    return training, validation
|
| 239 |
+
|
| 240 |
+
class DatasetConfiguration(BaseModel):
    """Configuration describing where a dataset lives and how to split it.

    Exactly one of 'bids' and 'fastsurfer' must be set.
    """

    # Spatial shape images are cropped/padded to before training
    input_shape: Tuple[int, int, int]
    # Root folders of BIDS-formatted datasets
    bids: List[str] | None = None
    # Root folders of FastSurfer output trees
    fastsurfer: List[str] | None = None
    # Path to a CSV file with one label row per (subject, session)
    labels: str
    # Explicit '| None': pydantic v2 does not infer optionality from a
    # None default. A split is required by parse()
    split: DataSplitConfiguration | None = None

    @model_validator(mode='after')
    def check_fastsurfer_or_bids(self):
        """Reject configurations setting both or neither data source."""
        if self.bids is not None and self.fastsurfer is not None:
            raise ValueError(
                'Either \'bids\' or \'fastsurfer\'-property must be set, not '
                'both'
            )
        elif self.bids is None and self.fastsurfer is None:
            raise ValueError(
                'Either \'bids or \'fastsurfer\'-property must be set'
            )

        return self

    @staticmethod
    def parse(
        configuration: DatasetConfiguration,
        target: str = None
    ) -> Tuple[pd.DataFrame, pd.DataFrame]:
        """Parse the configured folders and split into (training, validation).

        Annotation fix: the function has always returned a tuple of two
        DataFrames, not the previously declared Dict[str, pd.DataFrame].

        Raises:
            ValueError: if neither 'bids' nor 'fastsurfer' is set.
            NotImplementedError: if no split configuration is given.
        """
        if configuration.split:
            if configuration.bids:
                df = _parse_bids_folders(configuration.bids)
            elif configuration.fastsurfer:
                df = _parse_fastsurfer_folders(configuration.fastsurfer)
            else:
                raise ValueError(
                    'Unable to parse DatasetConfiguration without either '
                    '\'bids\' or \'fastsurfer\' set'
                )

            return _split_training_validation_fold(
                df=df,
                labels=configuration.labels,
                training_fraction=configuration.split.training_fraction,
                target=target,
                stratification=configuration.split.stratification
            )

        raise NotImplementedError(
            f'Not sure how to parse dataset without a split configuration'
        )
|
pyment/configurations/finetuning_configuration.py
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pydantic import BaseModel
|
| 2 |
+
|
| 3 |
+
from .dataset_configuration import DatasetConfiguration
|
| 4 |
+
from .model_configuration import ModelConfiguration
|
| 5 |
+
from .training_configuration import TrainingConfiguration
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class FinetuningConfiguration(BaseModel):
    """Root configuration for a finetuning run.

    Bundles the model to build, the data to train on and the training
    regime; typically deserialized from a single JSON file by the
    finetuning scripts.
    """

    model: ModelConfiguration
    data: DatasetConfiguration
    training: TrainingConfiguration
|
pyment/configurations/learning_rate_schedule_configuration.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import abstractmethod
|
| 2 |
+
from typing import Annotated, Literal, Union
|
| 3 |
+
|
| 4 |
+
from tensorflow.keras.callbacks import Callback, ReduceLROnPlateau
|
| 5 |
+
from pydantic import BaseModel, ConfigDict, Field
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class LearningRateScheduleBaseConfiguration(BaseModel):
    """Common base for learning rate schedule configurations.

    Subclasses declare a discriminating 'kind' field and implement
    instantiate() to build the matching Keras callback.
    """

    # Reject unknown keys so configuration-file typos fail loudly
    model_config = ConfigDict(extra='forbid')

    @abstractmethod
    def instantiate(self) -> Callback:
        """Build the Keras callback implementing this schedule."""
        pass
|
| 14 |
+
|
| 15 |
+
class AnnealingLearningRateScheduleConfiguration(
    LearningRateScheduleBaseConfiguration
):
    """Configuration for plateau-based learning rate annealing."""

    # Discriminator value selecting this schedule in the tagged union
    kind: Literal['annealing']
    # Multiplicative factor applied to the learning rate on plateau
    factor: float
    # Epochs without improvement before the rate is reduced
    patience: int
    # Lower bound the learning rate is never reduced below
    minimum_learning_rate: float

    def instantiate(self) -> Callback:
        """Build the ReduceLROnPlateau callback for this configuration."""
        return ReduceLROnPlateau(
            factor=self.factor,
            patience=self.patience,
            min_lr=self.minimum_learning_rate,
            verbose=True
        )
|
| 30 |
+
|
| 31 |
+
class StepWiseLearningRateScheduleConfiguration(
    LearningRateScheduleBaseConfiguration
):
    """Placeholder configuration for a step-wise learning rate schedule."""

    # Discriminator value selecting this schedule in the tagged union
    kind: Literal['stepwise']

    def instantiate(self) -> Callback:
        # Bug fix: the previous body was 'return ReduceLROnPl' — a
        # truncated, undefined name that raised NameError at call time.
        # Fail explicitly until a real step-wise schedule exists.
        raise NotImplementedError(
            'Step-wise learning rate schedules are not implemented yet'
        )
|
| 38 |
+
|
| 39 |
+
# Discriminated ("tagged") union: pydantic picks the concrete schedule
# configuration class based on the value of the 'kind' field
LearningRateScheduleConfiguration = Annotated[
    Union[
        AnnealingLearningRateScheduleConfiguration,
        StepWiseLearningRateScheduleConfiguration
    ],
    Field(discriminator='kind')
]
|
pyment/configurations/model_configuration.py
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel, Field
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ModelConfiguration(BaseModel):
    """Configuration of the model to finetune.

    Attributes:
        type: Model identifier resolved by the model factory
            (e.g. 'sfcn-reg').
        hyperparameters: Extra keyword arguments for the model constructor.
        weights: Optional path to pretrained weights.
    """

    type: str
    hyperparameters: Dict[str, Any] = Field(default_factory=dict)
    # Explicit '| None': pydantic v2 does not infer optionality from a
    # None default, so 'weights: str = None' rejected None at validation
    weights: str | None = None
|
pyment/configurations/training_configuration.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
from pydantic import BaseModel
|
| 4 |
+
|
| 5 |
+
from .learning_rate_schedule_configuration import (
|
| 6 |
+
LearningRateScheduleConfiguration
|
| 7 |
+
)
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class TrainingConfiguration(BaseModel):
    """Configuration of the training regime for a finetuning run."""

    # Label column used as prediction target
    target: str
    # Loss identifier resolved by loss_factory (e.g. 'mse')
    loss: str
    # Optional metric identifiers resolved by metric_factory.
    # Explicit '| None' on the three fields below: pydantic v2 does not
    # infer optionality from a None default
    metrics: List[str] | None = None
    # Optimizer identifier resolved by optimizer_factory (e.g. 'adam')
    optimizer: str
    learning_rate: float
    learning_rate_schedule: LearningRateScheduleConfiguration | None = None
    batch_size: int
    epochs: int
    # Optional output folder for checkpoints and training history
    destination: str | None = None
|
pyment/factories/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .loss_factory import loss_factory
|
| 2 |
+
from .metric_factory import metric_factory
|
| 3 |
+
from .optimizer_factory import optimizer_factory
|
pyment/factories/loss_factory.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def loss_factory(name: str) -> Callable[[tf.Tensor, tf.Tensor], tf.Tensor]:
    """Look up a Keras loss class by its (case-insensitive) name.

    Raises:
        KeyError: if no loss is registered under *name*.
    """
    key = name.lower()

    if key == 'mse':
        return tf.keras.losses.MeanSquaredError

    raise KeyError(f'Unknown loss {name}')
|
pyment/factories/metric_factory.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def metric_factory(name: str) -> tf.keras.metrics.Metric:
    """Instantiate a Keras metric by its (case-insensitive) name.

    Raises:
        KeyError: if no metric is registered under *name*.
    """
    if name.lower() == 'mae':
        return tf.keras.metrics.MeanAbsoluteError()

    # Bug fix: an unknown name previously fell off the end and silently
    # returned None, which only surfaced later inside model.compile. Fail
    # fast instead, consistent with loss_factory and optimizer_factory.
    raise KeyError(f'Unknown metric {name}')
|
pyment/factories/optimizer_factory.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import tensorflow as tf
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def optimizer_factory(name: str) -> tf.optimizers.Optimizer:
    """Look up a TensorFlow optimizer class by its (case-insensitive) name.

    Raises:
        KeyError: if no optimizer is registered under *name*.
    """
    key = name.lower()

    if key == 'adam':
        return tf.optimizers.Adam

    raise KeyError(f'Unknown optimizer {name}')
|
pyment/models/sfcn/__init__.py
CHANGED
|
@@ -1,4 +1,12 @@
|
|
| 1 |
from .sfcn import SFCN
|
| 2 |
from .sfcn_multi import MultiTaskSFCN
|
|
|
|
| 3 |
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from .sfcn import SFCN
|
| 2 |
from .sfcn_multi import MultiTaskSFCN
|
| 3 |
+
from .sfcn_reg import RegressionSFCN
|
| 4 |
|
| 5 |
+
|
| 6 |
+
def sfcn_factory(model_type: str):
    """Resolve an SFCN variant name to its model class.

    Raises:
        ValueError: if *model_type* names no known SFCN variant.
    """
    if model_type in ('sfcn-reg', 'regression'):
        return RegressionSFCN

    raise ValueError(f'Unknown SFCN type {model_type}')
|
| 11 |
+
|
| 12 |
+
__all__ = ['sfcn_factory', 'SFCN', 'MultiTaskSFCN', 'RegressionSFCN']
|
pyment/models/sfcn/sfcn.py
CHANGED
|
@@ -82,6 +82,8 @@ class SFCN(Model):
|
|
| 82 |
weights = ensure_weights(weights)
|
| 83 |
status = self.load_weights(weights)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
|
|
|
|
|
|
|
|
| 82 |
weights = ensure_weights(weights)
|
| 83 |
status = self.load_weights(weights)
|
| 84 |
|
| 85 |
+
print(weights)
|
| 86 |
+
if not weights.endswith('hdf5'):
|
| 87 |
+
# Silences warnings about optimizer-status not being loaded
|
| 88 |
+
status.expect_partial()
|
| 89 |
+
status.assert_existing_objects_matched()
|
pyment/models/sfcn/sfcn_multi.py
CHANGED
|
@@ -7,8 +7,8 @@ from .sfcn import SFCN
|
|
| 7 |
class MultiTaskSFCN(SFCN):
|
| 8 |
@classmethod
|
| 9 |
def construct_prediction_head(
|
| 10 |
-
cls,
|
| 11 |
-
bottleneck: Tensor,
|
| 12 |
name: str
|
| 13 |
) -> Tensor:
|
| 14 |
x = bottleneck
|
|
|
|
| 7 |
class MultiTaskSFCN(SFCN):
|
| 8 |
@classmethod
|
| 9 |
def construct_prediction_head(
|
| 10 |
+
cls,
|
| 11 |
+
bottleneck: Tensor,
|
| 12 |
name: str
|
| 13 |
) -> Tensor:
|
| 14 |
x = bottleneck
|
pyment/models/sfcn/sfcn_reg.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tensorflow import Tensor
|
| 2 |
+
from tensorflow.keras.layers import Dense
|
| 3 |
+
|
| 4 |
+
from .sfcn import SFCN
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RegressionSFCN(SFCN):
    """SFCN variant with a single linear output unit for regression."""

    @classmethod
    def construct_prediction_head(
        cls,
        bottleneck: Tensor,
        name: str
    ) -> Tensor:
        """Append a 1-unit linear prediction head to *bottleneck*."""
        head = Dense(1, activation=None, name=f'{name}/predictions')

        return head(bottleneck)
|
pyment/models/utils/ensure_weights.py
CHANGED
|
@@ -21,12 +21,17 @@ def ensure_weights(identifier: str) -> str:
|
|
| 21 |
------
|
| 22 |
KeyError
|
| 23 |
If the identifier is not a valid identifier and there does not
|
| 24 |
-
exist
|
| 25 |
-
<identifier>.data-00000-of-00001 on the
|
|
|
|
| 26 |
"""
|
| 27 |
if not (
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
):
|
| 31 |
raise NotImplementedError(
|
| 32 |
f'Identifier-based lookups are not supported'
|
|
|
|
| 21 |
------
|
| 22 |
KeyError
|
| 23 |
If the identifier is not a valid identifier and there does not
|
| 24 |
+
exist either a single file <identifier> or files
|
| 25 |
+
<identifier>.index and <identifier>.data-00000-of-00001 on the
|
| 26 |
+
local file system.
|
| 27 |
"""
|
| 28 |
if not (
|
| 29 |
+
(
|
| 30 |
+
os.path.isfile(f'{identifier}.index') and
|
| 31 |
+
os.path.isfile(f'{identifier}.data-00000-of-00001')
|
| 32 |
+
) or (
|
| 33 |
+
os.path.isfile(identifier)
|
| 34 |
+
)
|
| 35 |
):
|
| 36 |
raise NotImplementedError(
|
| 37 |
f'Identifier-based lookups are not supported'
|
pyment/models/utils/load_select_pretrained_weights.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import tensorflow as tf
|
| 4 |
+
|
| 5 |
+
from ..sfcn import MultiTaskSFCN
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
logging.basicConfig(
|
| 9 |
+
format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
|
| 10 |
+
level=logging.DEBUG
|
| 11 |
+
)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
def load_select_pretrained_weights(
    model: tf.keras.Model,
    weights: str,
    target: str = None
) -> tf.keras.Model:
    """Copy pretrained multi-task SFCN weights into *model* in place.

    A MultiTaskSFCN backbone is restored from the checkpoint at *weights*,
    its convolution- and normalization-layer weights are copied into
    *model*, and — when *target* is 'age' or 'sex' — the matching
    pretrained prediction head is copied as well. Returns the mutated
    *model*.
    """
    logger.info('Loading pretrained weights from %s', weights)

    # assumes the checkpoint was written for exactly this architecture and
    # (224, 192, 224) input shape — TODO confirm
    backbone = MultiTaskSFCN(input_shape=(224, 192, 224), pooling='max')
    checkpoint = tf.train.Checkpoint(backbone)

    # expect_partial(): silences warnings about checkpoint values (e.g.
    # optimizer state) with no matching object in the backbone
    checkpoint.restore(weights).expect_partial()

    # Hard-coded layer indices shared by *model* and the backbone; they
    # must match the SFCN layer layout exactly
    conv_layers = [2, 6, 10, 14, 18, 22]
    norm_layers = [3, 7, 11, 15, 19, 23]

    for idx in conv_layers + norm_layers:
        model.layers[idx].set_weights(backbone.layers[idx].get_weights())

    # Loading weights from the specific dense-layer corresponding to the
    # given prediction-task in the multi-task model
    if target == 'age':
        logger.info('Loaded age weights for the prediction head')
        model.layers[27].set_weights(backbone.layers[27].get_weights())
    elif target == 'sex':
        logger.info('Loaded sex weights for the prediction head')
        # presumably backbone layer 28 is the sex head while *model* has a
        # single head at index 27 — verify against both architectures
        model.layers[27].set_weights(backbone.layers[28].get_weights())
    else:
        logger.warning(
            'Unknown target %s. Not loading weights for prediction layer',
            target
        )

    return model
|
pyment/utils/json_serialize.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
def json_serialize(obj: Any) -> Any:
    """Recursively convert *obj* into JSON-serializable builtin types.

    Numpy scalars become Python ints/floats/bools, numpy arrays become
    (nested) lists, and dicts/lists/tuples are converted element-wise.
    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {json_serialize(k): json_serialize(v) for k, v in obj.items()}
    elif isinstance(obj, (list, tuple)):
        # Generalization: tuples are serialized as JSON arrays, like lists
        return [json_serialize(v) for v in obj]
    elif isinstance(obj, np.bool_):
        # np.bool_ is not JSON-serializable on its own
        return bool(obj)
    elif isinstance(obj, np.integer):
        return int(obj)
    elif isinstance(obj, np.floating):
        return float(obj)
    elif isinstance(obj, np.ndarray):
        return obj.tolist()
    else:
        return obj
|
scripts/finetune_from_bids_folder.py
ADDED
|
@@ -0,0 +1,216 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import json
|
| 3 |
+
import logging
|
| 4 |
+
import os
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from typing import Any, Callable, Dict, List, Tuple
|
| 7 |
+
|
| 8 |
+
import tensorflow as tf
|
| 9 |
+
from tensorflow_neuroimaging.preprocessing import center_crop_or_pad
|
| 10 |
+
from tensorflow_neuroimaging.loaders.mgh import load_mgh
|
| 11 |
+
|
| 12 |
+
from pyment.configurations import DatasetConfiguration, FinetuningConfiguration
|
| 13 |
+
from pyment.factories import loss_factory, metric_factory, optimizer_factory
|
| 14 |
+
from pyment.models.sfcn import sfcn_factory
|
| 15 |
+
from pyment.models.utils.load_select_pretrained_weights import (
|
| 16 |
+
load_select_pretrained_weights
|
| 17 |
+
)
|
| 18 |
+
from pyment.utils.json_serialize import json_serialize
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
logging.basicConfig(
|
| 22 |
+
format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
|
| 23 |
+
level=logging.DEBUG
|
| 24 |
+
)
|
| 25 |
+
logger = logging.getLogger(__name__)
|
| 26 |
+
|
| 27 |
+
def _create_tensorflow_dataset(
    df: pd.DataFrame, *,
    target: str,
    input_shape: Tuple[int, int, int],
    batch_size: int,
    shuffle: bool = False
) -> tf.data.Dataset:
    """Build a batched tf.data pipeline of (image, label) pairs from *df*.

    Images are loaded from the 'path' column, center-cropped/padded to
    *input_shape*, and paired with the *target* column as label.
    (Annotation fix only: batch_size is used as an integer, not a str.)
    """
    input_shape = tf.constant(input_shape)

    df = df.copy()
    # NOTE(review): this reorders the rows on every call, even when
    # shuffle=False — confirm the one-off shuffle of the validation data
    # is intentional
    df = df.sample(frac=1.)

    dataset = tf.data.Dataset.from_tensor_slices((df['path'], df[target]))

    if shuffle:
        # Per-epoch reshuffling; a buffer of five batches trades memory
        # for shuffle quality
        dataset = dataset.shuffle(buffer_size=5*batch_size)

    dataset = dataset.map(
        lambda path, label: (load_mgh(path), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.map(
        lambda image, label: (center_crop_or_pad(image, input_shape), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
|
| 57 |
+
|
| 58 |
+
def _create_checkpointing_callback(
    destination: str,
    metrics: List[tf.keras.metrics.Metric] = None
):
    """Create a ModelCheckpoint callback writing best-only weights.

    Creates the folder *destination* (raises FileExistsError if it already
    exists) and encodes epoch, loss and every metric value into the
    checkpoint filename via Keras' filepath formatting.
    """
    os.mkdir(destination)

    train_metrics = []
    val_metrics = []

    if metrics is not None:
        for metric in metrics:
            # Hyphenated names for the filename; Keras' underscored metric
            # names remain as the format keys ModelCheckpoint fills in
            name = metric.name.replace('_', '-')
            train_metrics.append(f'{name}={{{metric.name}:.2f}}')
            val_metrics.append(f'val-{name}={{val_{metric.name}:.2f}}')

    terms = [
        'epoch={epoch:03d}',
        'loss={loss:.2f}'
    ] + train_metrics + [
        'val-loss={val_loss:.2f}'
    ] + val_metrics
    filename = '_'.join(terms) + '.hdf5'
    filepath = os.path.join(destination, filename)

    # Keep only the weights of the best epoch as measured by val_loss
    return tf.keras.callbacks.ModelCheckpoint(
        filepath,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True
    )
|
| 88 |
+
|
| 89 |
+
def finetune(
    model_type: str,
    model_constructor_arguments: Dict[str, Any],
    weights: str,
    input_shape: Tuple[int, int, int],
    target: str,
    loss: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
    metrics: List[tf.keras.metrics.Metric],
    optimizer: tf.optimizers.Optimizer,
    learning_rate_scheduler: tf.keras.callbacks.Callback,
    training: pd.DataFrame,
    validation: pd.DataFrame,
    batch_size: int,
    epochs: int,
    destination: str
):
    """Finetune a pretrained SFCN on the given training/validation data.

    Builds the model via the SFCN factory, loads selected pretrained
    weights, wraps both DataFrames in tf.data pipelines, trains for
    *epochs*, and writes best-only checkpoints plus a history.json into
    *destination*.
    """
    if destination is not None:
        if os.path.isdir(destination):
            raise ValueError(f'Destination {destination} already exists')

        logger.info('Creating destination folder %s', destination)
        os.mkdir(destination)

    # NOTE(review): the guard above tolerates destination=None, but the
    # checkpoint path and history file below require a string — confirm
    # callers always pass a destination

    model_class = sfcn_factory(model_type)
    model = model_class(
        input_shape=input_shape,
        **model_constructor_arguments
    )
    # Mutates *model* in place with the pretrained backbone/head weights
    load_select_pretrained_weights(model, weights, target=target)

    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

    training_dataset = _create_tensorflow_dataset(
        training,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=True
    )
    validation_dataset = _create_tensorflow_dataset(
        validation,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=False
    )

    callbacks = [
        _create_checkpointing_callback(
            os.path.join(destination, 'checkpoints'),
            metrics=metrics
        ),
        learning_rate_scheduler
    ]

    history = model.fit(
        training_dataset,
        validation_data=validation_dataset,
        epochs=epochs,
        callbacks=callbacks
    )

    # Persist the per-epoch metrics; json_serialize coerces numpy scalars
    # into JSON-compatible builtins
    with open(os.path.join(destination, 'history.json'), 'w') as f:
        json.dump(json_serialize(history.history), f)
|
| 153 |
+
|
| 154 |
+
def finetune_from_configuration(configuration: str):
    """Load, validate and execute a finetuning run from a JSON file.

    Parameters
    ----------
    configuration : str
        Path to a JSON file whose contents validate against the
        FinetuningConfiguration pydantic schema.
    """
    with open(configuration, 'r', encoding='utf-8') as f:
        raw_configuration = json.load(f)

    # Validate the raw dict into a typed pydantic model. A separate name
    # is used so the path argument is not rebound to three different
    # types (str -> dict -> model) over the course of the function.
    config = FinetuningConfiguration.model_validate(raw_configuration)

    training, validation = DatasetConfiguration.parse(
        config.data,
        target=config.training.target
    )

    # The factories return classes; the loss and optimizer are
    # instantiated here, the optimizer with the configured learning rate.
    loss_cls = loss_factory(config.training.loss)
    loss = loss_cls()

    optimizer_cls = optimizer_factory(config.training.optimizer)
    optimizer = optimizer_cls(config.training.learning_rate)

    # Metrics are optional in the configuration.
    metrics = None

    if config.training.metrics is not None:
        metrics = [
            metric_factory(metric)
            for metric in config.training.metrics
        ]

    # An optional learning-rate schedule is instantiated into a callback
    # passed on to the training loop.
    learning_rate_scheduler = None

    if config.training.learning_rate_schedule:
        learning_rate_scheduler = (
            config.training.learning_rate_schedule.instantiate()
        )

    finetune(
        model_type=config.model.type,
        model_constructor_arguments=config.model.hyperparameters,
        weights=config.model.weights,
        input_shape=config.data.input_shape,
        target=config.training.target,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        learning_rate_scheduler=learning_rate_scheduler,
        training=training,
        validation=validation,
        batch_size=config.training.batch_size,
        epochs=config.training.epochs,
        destination=config.training.destination
    )
|
| 206 |
+
|
| 207 |
+
if __name__ == '__main__':
    # NOTE: `description=` must be passed by keyword — the first
    # positional argument of ArgumentParser is `prog`, which would
    # replace the program name in the usage/help output rather than
    # describe what the script does.
    parser = argparse.ArgumentParser(
        description=(
            'Finetunes a multi-task SFCN according to the given configuration'
        )
    )

    parser.add_argument('configuration', help='Path to configuration JSON')

    args = parser.parse_args()

    finetune_from_configuration(args.configuration)
|
scripts/finetune_from_fastsurfer_folder.py
ADDED
|
File without changes
|
scripts/predict_from_bids_folder.py
CHANGED
|
@@ -4,11 +4,12 @@ import os
|
|
| 4 |
import re
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
|
|
|
| 7 |
from tqdm import tqdm
|
| 8 |
|
| 9 |
import nibabel as nib
|
| 10 |
|
| 11 |
-
from pyment.models import
|
| 12 |
from pyment.preprocessing.conform import conform
|
| 13 |
|
| 14 |
|
|
@@ -29,7 +30,11 @@ def _extract_run(filename: str) -> str:
|
|
| 29 |
|
| 30 |
def predict_from_bids_folder(
|
| 31 |
source: str,
|
| 32 |
-
weights: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
destination: str = None,
|
| 34 |
per_image_normalization: bool = False
|
| 35 |
) -> pd.DataFrame:
|
|
@@ -37,7 +42,8 @@ def predict_from_bids_folder(
|
|
| 37 |
raise ValueError(f'Destination {destination} already exists')
|
| 38 |
|
| 39 |
logger.info('Loading multi-task model with weights %s', weights)
|
| 40 |
-
|
|
|
|
| 41 |
|
| 42 |
results = []
|
| 43 |
|
|
@@ -73,16 +79,13 @@ def predict_from_bids_folder(
|
|
| 73 |
)
|
| 74 |
|
| 75 |
results.append({
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
'bmi': predictions[3],
|
| 84 |
-
'fluid_intelligence': predictions[4],
|
| 85 |
-
'neuroticism': predictions[5]
|
| 86 |
})
|
| 87 |
|
| 88 |
results = pd.DataFrame(results)
|
|
@@ -108,6 +111,24 @@ if __name__ == '__main__':
|
|
| 108 |
'exist files named <path>.index and <path>.data-00000-of-00001'
|
| 109 |
)
|
| 110 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 111 |
parser.add_argument(
|
| 112 |
'-d', '--destination',
|
| 113 |
required=False,
|
|
@@ -128,6 +149,8 @@ if __name__ == '__main__':
|
|
| 128 |
predict_from_bids_folder(
|
| 129 |
source=args.bids,
|
| 130 |
weights=args.weights,
|
|
|
|
|
|
|
| 131 |
destination=args.destination,
|
| 132 |
per_image_normalization=args.per_image_normalization
|
| 133 |
)
|
|
|
|
| 4 |
import re
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
+
from typing import List
|
| 8 |
from tqdm import tqdm
|
| 9 |
|
| 10 |
import nibabel as nib
|
| 11 |
|
| 12 |
+
from pyment.models.sfcn import sfcn_factory
|
| 13 |
from pyment.preprocessing.conform import conform
|
| 14 |
|
| 15 |
|
|
|
|
| 30 |
|
| 31 |
def predict_from_bids_folder(
|
| 32 |
source: str,
|
| 33 |
+
weights: str,
|
| 34 |
+
model_name: str = 'sfcn-multi',
|
| 35 |
+
targets: List[str] = [
|
| 36 |
+
'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
|
| 37 |
+
],
|
| 38 |
destination: str = None,
|
| 39 |
per_image_normalization: bool = False
|
| 40 |
) -> pd.DataFrame:
|
|
|
|
| 42 |
raise ValueError(f'Destination {destination} already exists')
|
| 43 |
|
| 44 |
logger.info('Loading multi-task model with weights %s', weights)
|
| 45 |
+
model_class = sfcn_factory(model_name)
|
| 46 |
+
model = model_class(weights=weights)
|
| 47 |
|
| 48 |
results = []
|
| 49 |
|
|
|
|
| 79 |
)
|
| 80 |
|
| 81 |
results.append({
|
| 82 |
+
**{
|
| 83 |
+
'source': path,
|
| 84 |
+
'subject': subject,
|
| 85 |
+
'session': session,
|
| 86 |
+
'run': run
|
| 87 |
+
},
|
| 88 |
+
**{targets[i]: predictions[i] for i in range(len(targets))}
|
|
|
|
|
|
|
|
|
|
| 89 |
})
|
| 90 |
|
| 91 |
results = pd.DataFrame(results)
|
|
|
|
| 111 |
'exist files named <path>.index and <path>.data-00000-of-00001'
|
| 112 |
)
|
| 113 |
)
|
| 114 |
+
parser.add_argument(
|
| 115 |
+
'-m', '--model',
|
| 116 |
+
required=False,
|
| 117 |
+
default='sfcn-multi',
|
| 118 |
+
help=(
|
| 119 |
+
'Name of the model to use'
|
| 120 |
+
)
|
| 121 |
+
)
|
| 122 |
+
parser.add_argument(
|
| 123 |
+
'-t', '--targets',
|
| 124 |
+
required=False,
|
| 125 |
+
nargs='+',
|
| 126 |
+
default=[
|
| 127 |
+
'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
|
| 128 |
+
'neuroticism'
|
| 129 |
+
],
|
| 130 |
+
help='Name to use for each of the prediction heads in the output CSV'
|
| 131 |
+
)
|
| 132 |
parser.add_argument(
|
| 133 |
'-d', '--destination',
|
| 134 |
required=False,
|
|
|
|
| 149 |
predict_from_bids_folder(
|
| 150 |
source=args.bids,
|
| 151 |
weights=args.weights,
|
| 152 |
+
model_name=args.model,
|
| 153 |
+
targets=args.targets,
|
| 154 |
destination=args.destination,
|
| 155 |
per_image_normalization=args.per_image_normalization
|
| 156 |
)
|
scripts/predict_from_fastsurfer_folder.py
CHANGED
|
@@ -5,11 +5,11 @@ import re
|
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
from tqdm import tqdm
|
| 8 |
-
from typing import Tuple
|
| 9 |
|
| 10 |
import nibabel as nib
|
| 11 |
|
| 12 |
-
from pyment.models import
|
| 13 |
from pyment.preprocessing.conform import conform
|
| 14 |
|
| 15 |
|
|
@@ -29,14 +29,20 @@ def _parse_folder_name(name: str) -> Tuple[str, str, str]:
|
|
| 29 |
|
| 30 |
def predict_from_fastsurfer_folder(
|
| 31 |
source: str,
|
| 32 |
-
weights: str,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
destination: str = None
|
| 34 |
) -> pd.DataFrame:
|
| 35 |
if destination is not None and os.path.isfile(destination):
|
| 36 |
raise ValueError(f'Destination {destination} already exists')
|
| 37 |
|
| 38 |
logger.info('Loading multi-task model with weights %s', weights)
|
| 39 |
-
|
|
|
|
|
|
|
| 40 |
|
| 41 |
results = []
|
| 42 |
|
|
@@ -66,19 +72,18 @@ def predict_from_fastsurfer_folder(
|
|
| 66 |
image = conform(image)
|
| 67 |
|
| 68 |
predictions = model.predict(np.expand_dims(image, axis=0))[0]
|
|
|
|
|
|
|
| 69 |
logger.debug('Predictions for %s: %s', folder, str(predictions))
|
| 70 |
|
| 71 |
results.append({
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
'bmi': predictions[3],
|
| 80 |
-
'fluid_intelligence': predictions[4],
|
| 81 |
-
'neuroticism': predictions[5]
|
| 82 |
})
|
| 83 |
|
| 84 |
results = pd.DataFrame(results)
|
|
@@ -110,6 +115,24 @@ if __name__ == '__main__':
|
|
| 110 |
'exist files named <path>.index and <path>.data-00000-of-00001'
|
| 111 |
)
|
| 112 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
parser.add_argument(
|
| 114 |
'-d', '--destination',
|
| 115 |
required=False,
|
|
@@ -121,7 +144,9 @@ if __name__ == '__main__':
|
|
| 121 |
|
| 122 |
predict_from_fastsurfer_folder(
|
| 123 |
source=args.root,
|
|
|
|
| 124 |
weights=args.weights,
|
|
|
|
| 125 |
destination=args.destination
|
| 126 |
)
|
| 127 |
|
|
|
|
| 5 |
import numpy as np
|
| 6 |
import pandas as pd
|
| 7 |
from tqdm import tqdm
|
| 8 |
+
from typing import List, Tuple
|
| 9 |
|
| 10 |
import nibabel as nib
|
| 11 |
|
| 12 |
+
from pyment.models.sfcn import sfcn_factory
|
| 13 |
from pyment.preprocessing.conform import conform
|
| 14 |
|
| 15 |
|
|
|
|
| 29 |
|
| 30 |
def predict_from_fastsurfer_folder(
|
| 31 |
source: str,
|
| 32 |
+
weights: str,
|
| 33 |
+
model_name: str = 'sfcn-multi',
|
| 34 |
+
targets: List[str] = [
|
| 35 |
+
'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence', 'neuroticism'
|
| 36 |
+
],
|
| 37 |
destination: str = None
|
| 38 |
) -> pd.DataFrame:
|
| 39 |
if destination is not None and os.path.isfile(destination):
|
| 40 |
raise ValueError(f'Destination {destination} already exists')
|
| 41 |
|
| 42 |
logger.info('Loading multi-task model with weights %s', weights)
|
| 43 |
+
|
| 44 |
+
model_class = sfcn_factory(model_name)
|
| 45 |
+
model = model_class(weights=weights)
|
| 46 |
|
| 47 |
results = []
|
| 48 |
|
|
|
|
| 72 |
image = conform(image)
|
| 73 |
|
| 74 |
predictions = model.predict(np.expand_dims(image, axis=0))[0]
|
| 75 |
+
print(predictions.shape)
|
| 76 |
+
print(predictions)
|
| 77 |
logger.debug('Predictions for %s: %s', folder, str(predictions))
|
| 78 |
|
| 79 |
results.append({
|
| 80 |
+
**{
|
| 81 |
+
'source': os.path.join(source, folder),
|
| 82 |
+
'subject': subject,
|
| 83 |
+
'session': session,
|
| 84 |
+
'run': run
|
| 85 |
+
},
|
| 86 |
+
**{targets[i]: predictions[i] for i in range(len(targets))}
|
|
|
|
|
|
|
|
|
|
| 87 |
})
|
| 88 |
|
| 89 |
results = pd.DataFrame(results)
|
|
|
|
| 115 |
'exist files named <path>.index and <path>.data-00000-of-00001'
|
| 116 |
)
|
| 117 |
)
|
| 118 |
+
parser.add_argument(
|
| 119 |
+
'-m', '--model',
|
| 120 |
+
required=False,
|
| 121 |
+
default='sfcn-multi',
|
| 122 |
+
help=(
|
| 123 |
+
'Name of the model to use'
|
| 124 |
+
)
|
| 125 |
+
)
|
| 126 |
+
parser.add_argument(
|
| 127 |
+
'-t', '--targets',
|
| 128 |
+
required=False,
|
| 129 |
+
nargs='+',
|
| 130 |
+
default=[
|
| 131 |
+
'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
|
| 132 |
+
'neuroticism'
|
| 133 |
+
],
|
| 134 |
+
help='Name to use for each of the prediction heads in the output CSV'
|
| 135 |
+
)
|
| 136 |
parser.add_argument(
|
| 137 |
'-d', '--destination',
|
| 138 |
required=False,
|
|
|
|
| 144 |
|
| 145 |
predict_from_fastsurfer_folder(
|
| 146 |
source=args.root,
|
| 147 |
+
model_name=args.model,
|
| 148 |
weights=args.weights,
|
| 149 |
+
targets=args.targets,
|
| 150 |
destination=args.destination
|
| 151 |
)
|
| 152 |
|