# pyment-public/scripts/finetune_from_bids_folder.py
# Author: Esten Leonardsen
# Commit: feeef85 - Small changes to BIDS scripts
import argparse
import json
import logging
import os
import pandas as pd
from typing import Any, Callable, Dict, List, Tuple
import tensorflow as tf
from tensorflow_neuroimaging.preprocessing import center_crop_or_pad
from tensorflow_neuroimaging.loaders.mgh import load_mgh
from pyment.configurations import DatasetConfiguration, FinetuningConfiguration
from pyment.factories import loss_factory, metric_factory, optimizer_factory
from pyment.models.sfcn import sfcn_factory
from pyment.models.utils.load_select_pretrained_weights import (
load_select_pretrained_weights
)
from pyment.utils.json_serialize import json_serialize
# Configure the root logger at import time: every record is emitted at DEBUG
# level or above, prefixed with timestamp, level name and logger name.
logging.basicConfig(
format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
level=logging.DEBUG
)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
def _create_tensorflow_dataset(
    df: pd.DataFrame, *,
    target: str,
    input_shape: Tuple[int, int, int],
    batch_size: int,
    shuffle: bool = False
) -> tf.data.Dataset:
    """Build a batched, prefetched ``tf.data`` pipeline of (image, label).

    Args:
        df: Table with a ``'path'`` column and a column named ``target``.
            The caller's frame is never modified.
        target: Name of the column in ``df`` that holds the labels.
        input_shape: Spatial shape passed to ``center_crop_or_pad`` for
            every loaded image.
        batch_size: Number of samples per batch. Also determines the
            shuffle buffer size (``5 * batch_size``).
        shuffle: If True, the rows are permuted once up front and the
            dataset is additionally shuffled with a buffer. If False,
            the row order of ``df`` is preserved.

    Returns:
        A dataset yielding batches of ``(image, label)``.
    """
    input_shape = tf.constant(input_shape)

    if shuffle:
        # The buffered shuffle below only mixes samples locally, so permute
        # the whole table first. ``sample`` returns a new frame, leaving the
        # caller's DataFrame untouched.
        df = df.sample(frac=1.)

    dataset = tf.data.Dataset.from_tensor_slices((df['path'], df[target]))

    if shuffle:
        dataset = dataset.shuffle(buffer_size=5*batch_size)

    # NOTE(review): load_mgh presumably reads the image volume stored at
    # ``path`` -- confirm against tensorflow_neuroimaging.
    dataset = dataset.map(
        lambda path, label: (load_mgh(path), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.map(
        lambda image, label: (center_crop_or_pad(image, input_shape), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset
def _create_checkpointing_callback(
    destination: str,
    metrics: List[tf.keras.metrics.Metric] = None
):
    """Create a ``ModelCheckpoint`` callback writing into ``destination``.

    The checkpoint filename is a template containing the epoch, the
    training/validation loss and, for every given metric, its training and
    validation value. Keras fills in the placeholders when saving.

    Args:
        destination: Folder where checkpoints are written. It is created
            (including missing parents) if it does not exist yet.
        metrics: Metrics whose values should appear in the filename. Each
            must expose a ``name`` attribute. ``None`` means no metrics.

    Returns:
        A ``tf.keras.callbacks.ModelCheckpoint`` that monitors ``val_loss``
        and saves only the weights of the best model so far.
    """
    # makedirs instead of mkdir: tolerate nested destinations and a folder
    # that is already present.
    os.makedirs(destination, exist_ok=True)

    train_metrics = []
    val_metrics = []

    for metric in metrics or []:
        # Underscores separate the filename terms, so use dashes in the
        # displayed name; the placeholder keeps the metric's real name.
        name = metric.name.replace('_', '-')
        train_metrics.append(f'{name}={{{metric.name}:.2f}}')
        val_metrics.append(f'val-{name}={{val_{metric.name}:.2f}}')

    terms = [
        'epoch={epoch:03d}',
        'loss={loss:.2f}',
        *train_metrics,
        'val-loss={val_loss:.2f}',
        *val_metrics,
    ]
    filename = '_'.join(terms) + '.hdf5'
    filepath = os.path.join(destination, filename)

    return tf.keras.callbacks.ModelCheckpoint(
        filepath,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True
    )
def finetune(
    model_type: str,
    model_constructor_arguments: Dict[str, Any],
    weights: str,
    input_shape: Tuple[int, int, int],
    target: str,
    loss: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
    metrics: List[tf.keras.metrics.Metric],
    optimizer: tf.optimizers.Optimizer,
    learning_rate_scheduler: tf.keras.callbacks.Callback,
    training: pd.DataFrame,
    validation: pd.DataFrame,
    batch_size: int,
    epochs: int,
    destination: str
):
    """Finetune a pretrained SFCN model and optionally store the results.

    Args:
        model_type: Identifier passed to ``sfcn_factory`` to select the
            model class.
        model_constructor_arguments: Extra keyword arguments for the model
            constructor. ``None`` is treated as no extra arguments.
        weights: Pretrained weights identifier, handed to
            ``load_select_pretrained_weights``.
        input_shape: Shape of the input images.
        target: Name of the label column in ``training``/``validation``.
        loss: Loss used when compiling the model.
        metrics: Metrics used when compiling the model, may be ``None``.
        optimizer: Optimizer used when compiling the model.
        learning_rate_scheduler: Optional Keras callback. Skipped when
            ``None``.
        training: Table describing the training samples.
        validation: Table describing the validation samples.
        batch_size: Batch size of both datasets.
        epochs: Number of training epochs.
        destination: Output folder for checkpoints and ``history.json``.
            When ``None``, nothing is written to disk.

    Raises:
        ValueError: If ``destination`` already exists as a folder.
    """
    if destination is not None:
        if os.path.isdir(destination):
            raise ValueError(f'Destination {destination} already exists')

        logger.info('Creating destination folder %s', destination)
        # makedirs so that nested output folders work as well.
        os.makedirs(destination)

    model_class = sfcn_factory(model_type)
    model = model_class(
        input_shape=input_shape,
        **(model_constructor_arguments or {})
    )
    load_select_pretrained_weights(model, weights, target=target)
    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

    training_dataset = _create_tensorflow_dataset(
        training,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=True
    )
    validation_dataset = _create_tensorflow_dataset(
        validation,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=False
    )

    # Only register callbacks that actually exist: checkpointing needs an
    # output folder, and the scheduler is optional. Passing None entries
    # to Keras would fail.
    callbacks = []
    if destination is not None:
        callbacks.append(
            _create_checkpointing_callback(
                os.path.join(destination, 'checkpoints'),
                metrics=metrics
            )
        )
    if learning_rate_scheduler is not None:
        callbacks.append(learning_rate_scheduler)

    history = model.fit(
        training_dataset,
        validation_data=validation_dataset,
        epochs=epochs,
        callbacks=callbacks
    )

    if destination is not None:
        with open(os.path.join(destination, 'history.json'), 'w') as f:
            json.dump(json_serialize(history.history), f)
def finetune_from_configuration(configuration: str):
    """Run a finetuning job described by a JSON configuration file.

    The file is parsed and validated as a ``FinetuningConfiguration``.
    The datasets, loss, optimizer, metrics and optional learning rate
    scheduler are built from it, and everything is handed to ``finetune``.

    Args:
        configuration: Path to the configuration JSON.
    """
    with open(configuration, 'r') as handle:
        raw = json.load(handle)

    config = FinetuningConfiguration.model_validate(raw)
    settings = config.training

    training, validation = DatasetConfiguration.parse(
        config.data,
        target=settings.target
    )

    # The factories return classes, which are instantiated here.
    loss = loss_factory(settings.loss)()
    optimizer = optimizer_factory(settings.optimizer)(settings.learning_rate)

    metric_names = settings.metrics
    metrics = (
        None if metric_names is None
        else [metric_factory(name) for name in metric_names]
    )

    schedule = settings.learning_rate_schedule
    learning_rate_scheduler = schedule.instantiate() if schedule else None

    finetune(
        model_type=config.model.type,
        model_constructor_arguments=config.model.hyperparameters,
        weights=config.model.weights,
        input_shape=config.data.input_shape,
        target=settings.target,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        learning_rate_scheduler=learning_rate_scheduler,
        training=training,
        validation=validation,
        batch_size=settings.batch_size,
        epochs=settings.epochs,
        destination=settings.destination
    )
if __name__ == '__main__':
    # The first positional parameter of ArgumentParser is ``prog``, so the
    # help text must be passed explicitly as ``description``. Otherwise the
    # whole sentence ends up as the program name in the usage line.
    parser = argparse.ArgumentParser(
        description=(
            'Finetunes a multi-task SFCN according to the given configuration'
        )
    )
    parser.add_argument('configuration', help='Path to configuration JSON')

    args = parser.parse_args()

    finetune_from_configuration(args.configuration)