"""Finetunes a pretrained multi-task SFCN model from a JSON configuration.

Usage: python <script> <path/to/configuration.json>
"""

import argparse
import json
import logging
import os
from typing import Any, Callable, Dict, List, Optional, Tuple

import pandas as pd
import tensorflow as tf
from tensorflow_neuroimaging.loaders.mgh import load_mgh
from tensorflow_neuroimaging.preprocessing import center_crop_or_pad

from pyment.configurations import DatasetConfiguration, FinetuningConfiguration
from pyment.factories import loss_factory, metric_factory, optimizer_factory
from pyment.models.sfcn import sfcn_factory
from pyment.models.utils.load_select_pretrained_weights import (
    load_select_pretrained_weights
)
from pyment.utils.json_serialize import json_serialize

logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
    level=logging.DEBUG
)

logger = logging.getLogger(__name__)


def _create_tensorflow_dataset(
    df: pd.DataFrame,
    *,
    target: str,
    input_shape: Tuple[int, int, int],
    batch_size: int,
    shuffle: bool = False
) -> tf.data.Dataset:
    """Builds a batched tf.data pipeline from a dataframe of image paths.

    Args:
        df: Dataframe with a 'path' column (paths loadable by ``load_mgh``)
            and a label column named by ``target``.
        target: Name of the label column in ``df``.
        input_shape: Spatial shape that each image is center-cropped or
            padded to.
        batch_size: Number of samples per batch.
        shuffle: If True, rows are globally shuffled up front and the
            dataset is additionally shuffled with a ``5 * batch_size``
            buffer each epoch.

    Returns:
        A batched, prefetching ``tf.data.Dataset`` of (image, label) pairs.
    """
    input_shape = tf.constant(input_shape)

    df = df.copy()

    if shuffle:
        # Global row shuffle. The buffered Dataset.shuffle below is only
        # local (5 * batch_size window), so without this a dataframe sorted
        # by e.g. site or label would never be fully mixed. The original
        # code ran this unconditionally, which also scrambled the
        # validation set despite shuffle=False — now gated on the flag.
        df = df.sample(frac=1.)

    dataset = tf.data.Dataset.from_tensor_slices((df['path'], df[target]))

    if shuffle:
        dataset = dataset.shuffle(buffer_size=5 * batch_size)

    # Load images lazily and in parallel; labels pass through untouched.
    dataset = dataset.map(
        lambda path, label: (load_mgh(path), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )
    dataset = dataset.map(
        lambda image, label: (center_crop_or_pad(image, input_shape), label),
        num_parallel_calls=tf.data.AUTOTUNE
    )

    dataset = dataset.batch(batch_size)
    dataset = dataset.prefetch(tf.data.AUTOTUNE)

    return dataset


def _create_checkpointing_callback(
    destination: str,
    metrics: Optional[List[tf.keras.metrics.Metric]] = None
) -> tf.keras.callbacks.ModelCheckpoint:
    """Creates a best-only (by val_loss) checkpointing callback.

    The checkpoint filename embeds epoch, loss and every train/validation
    metric so individual checkpoints are identifiable at a glance.

    Args:
        destination: Directory to create and write checkpoints into. Its
            parent must already exist.
        metrics: Metrics whose values should appear in the filename.

    Returns:
        A ``ModelCheckpoint`` saving weights only, best only, on val_loss.

    Raises:
        FileExistsError: If ``destination`` already exists (``os.mkdir``).
    """
    os.mkdir(destination)

    train_metrics = []
    val_metrics = []

    if metrics is not None:
        for metric in metrics:
            name = metric.name.replace('_', '-')
            # e.g. 'mean-abs={mean_abs:.2f}': Keras substitutes the braced
            # keys from the epoch logs dict when writing the file.
            train_metrics.append(f'{name}={{{metric.name}:.2f}}')
            val_metrics.append(f'val-{name}={{val_{metric.name}:.2f}}')

    terms = [
        'epoch={epoch:03d}',
        'loss={loss:.2f}'
    ] + train_metrics + [
        'val-loss={val_loss:.2f}'
    ] + val_metrics

    filename = '_'.join(terms) + '.hdf5'
    filepath = os.path.join(destination, filename)

    return tf.keras.callbacks.ModelCheckpoint(
        filepath,
        monitor='val_loss',
        save_best_only=True,
        save_weights_only=True
    )


def finetune(
    model_type: str,
    model_constructor_arguments: Dict[str, Any],
    weights: str,
    input_shape: Tuple[int, int, int],
    target: str,
    loss: Callable[[tf.Tensor, tf.Tensor], tf.Tensor],
    metrics: List[tf.keras.metrics.Metric],
    optimizer: tf.optimizers.Optimizer,
    learning_rate_scheduler: tf.keras.callbacks.Callback,
    training: pd.DataFrame,
    validation: pd.DataFrame,
    batch_size: int,
    epochs: int,
    destination: str
) -> None:
    """Finetunes a pretrained SFCN and writes checkpoints and history.

    Args:
        model_type: Key resolved by ``sfcn_factory`` to a model class.
        model_constructor_arguments: Extra kwargs for the model constructor.
        weights: Pretrained weights passed to
            ``load_select_pretrained_weights``.
        input_shape: Spatial input shape of the model and images.
        target: Label column name in the training/validation dataframes.
        loss: Compiled loss function.
        metrics: Compiled metrics (also embedded in checkpoint filenames).
        optimizer: Compiled optimizer.
        learning_rate_scheduler: Callback appended to the callback list.
        training: Training dataframe (shuffled each epoch).
        validation: Validation dataframe (kept in order).
        batch_size: Batch size for both datasets.
        epochs: Number of training epochs.
        destination: Fresh output directory for checkpoints and
            ``history.json``.

    Raises:
        ValueError: If ``destination`` already exists.
    """
    if destination is not None:
        if os.path.isdir(destination):
            raise ValueError(f'Destination {destination} already exists')
        logger.info('Creating destination folder %s', destination)
        os.mkdir(destination)
    # NOTE(review): the checkpoint and history writes below require a
    # non-None destination (os.path.join would raise otherwise) — confirm
    # whether destination=None was ever meant to be supported.

    model_class = sfcn_factory(model_type)
    model = model_class(
        input_shape=input_shape,
        **model_constructor_arguments
    )
    load_select_pretrained_weights(model, weights, target=target)

    model.compile(loss=loss, optimizer=optimizer, metrics=metrics)

    training_dataset = _create_tensorflow_dataset(
        training,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=True
    )

    validation_dataset = _create_tensorflow_dataset(
        validation,
        input_shape=input_shape,
        target=target,
        batch_size=batch_size,
        shuffle=False
    )

    callbacks = [
        _create_checkpointing_callback(
            os.path.join(destination, 'checkpoints'),
            metrics=metrics
        ),
        learning_rate_scheduler
    ]

    history = model.fit(
        training_dataset,
        validation_data=validation_dataset,
        epochs=epochs,
        callbacks=callbacks
    )

    with open(os.path.join(destination, 'history.json'), 'w') as f:
        json.dump(json_serialize(history.history), f)


def finetune_from_configuration(configuration: str) -> None:
    """Parses a finetuning configuration JSON file and runs ``finetune``.

    Args:
        configuration: Path to a JSON file validating against
            ``FinetuningConfiguration``.
    """
    with open(configuration, 'r') as f:
        raw_configuration = json.load(f)

    config = FinetuningConfiguration.model_validate(raw_configuration)

    training, validation = DatasetConfiguration.parse(
        config.data,
        target=config.training.target
    )

    loss_cls = loss_factory(config.training.loss)
    loss = loss_cls()

    optimizer_cls = optimizer_factory(config.training.optimizer)
    optimizer = optimizer_cls(config.training.learning_rate)

    metrics = None
    if config.training.metrics is not None:
        metrics = [
            metric_factory(metric) for metric in config.training.metrics
        ]

    learning_rate_scheduler = None
    if config.training.learning_rate_schedule:
        learning_rate_scheduler = (
            config.training.learning_rate_schedule.instantiate()
        )

    finetune(
        model_type=config.model.type,
        model_constructor_arguments=config.model.hyperparameters,
        weights=config.model.weights,
        input_shape=config.data.input_shape,
        target=config.training.target,
        loss=loss,
        metrics=metrics,
        optimizer=optimizer,
        learning_rate_scheduler=learning_rate_scheduler,
        training=training,
        validation=validation,
        batch_size=config.training.batch_size,
        epochs=config.training.epochs,
        destination=config.training.destination
    )


if __name__ == '__main__':
    parser = argparse.ArgumentParser(
        'Finetunes a multi-task SFCN according to the given configuration'
    )

    parser.add_argument('configuration', help='Path to configuration JSON')

    args = parser.parse_args()

    finetune_from_configuration(args.configuration)