--- license: apache-2.0 pipeline_tag: image-segmentation tags: - medical --- # U-Net Transplant: Model Merging for 3D Medical Segmentation ![alt text](https://raw.githubusercontent.com/LucaLumetti/UNetTransplant/refs/heads/main/assets/thumbnail.png) This repository contains the implementation of **U-Net Transplant**, a framework for efficient model merging in 3D medical image segmentation. Model merging enables the combination of specialized segmentation models without requiring full retraining, offering a flexible and privacy-conscious solution for updating AI models in clinical applications. Our approach leverages **task vectors** and encourages **wide minima** during pre-training to enhance the effectiveness of model merging. We evaluate this method using the **ToothFairy2** and **BTCV Abdomen** datasets with a standard **3D U-Net** architecture, demonstrating its ability to integrate multiple specialized segmentation tasks into a single model. # Pretrain and Task Vector Checkpoints The related checkpoints and task vectors used in the paper will be available from the 23rd June 2025. # How to Run ### 1. Clone the Repository ```bash git clone git@github.com:LucaLumetti/UNetTransplant.git cd UNetTransplant ``` ### 2. Setup Environment ```bash python -m venv env source env/bin/activate pip install -r requirements.txt ``` ### 3. Downloads Ensure the datasets are downloaded and organized following the nnUNet dataset format. - **BTCV Abdomen**: [Download Here](https://www.synapse.org/Synapse:syn3193805/wiki/217753) - **ToothFairy2**: [Download Here](https://ditto.ing.unimore.it/toothfairy2/) - **AMOS**: [Download Here](https://zenodo.org/records/7262581) - **ZhimingCui**: Available upon request from the authors ([Paper](https://www.nature.com/articles/s41467-022-29637-2)) You can also download pretrained checkpoints and task vectors: ```bash #!/bin/bash BASE_ABDOMEN="https://huggingface.co/Lumett/UNetTransplant/resolve/main/Abdomen" BASE_TOOTHFAIRY="https://huggingface.co/Lumett/UNetTransplant/resolve/main/ToothFairy" abdomen_files=( Pretrain_AMOS.pth TaskVector_Kidney_Abdomen.pth TaskVector_Liver_Abdomen.pth TaskVector_Spleen_Abdomen.pth TaskVector_Stomach_Abdomen.pth ) toothfairy_files=( Pretrain_Cui.pth TaskVector_Canals_ToothFairy2.pth TaskVector_Mandible_ToothFairy2.pth TaskVector_Teeth_ToothFairy2.pth TaskVector_Pharynx_ToothFairy2.pth ) echo "🩻 Downloading Abdomen files..." for file in "${abdomen_files[@]}"; do wget -c "${BASE_ABDOMEN}/${file}" done echo "🦷 Downloading ToothFairy files..." for file in "${toothfairy_files[@]}"; do wget -c "${BASE_TOOTHFAIRY}/${file}" done ``` ### 4. Running the U-Net Transplant Framework The main script for running experiments is `main.py`. It requires specifying the type of experiment and a configuration file that defines dataset, model, optimizer, and training parameters. #### Command Structure ```bash python main.py --experiment --config [--expname ] [--override ] ``` #### Arguments - **`--experiment`**: Specifies the type of experiment to run. - `"PretrainExperiment"` → Pretrains the model from scratch. - `"TaskVectorTrainExperiment"` → Trains a task vector using a pretrained checkpoint. - **`--config`**: Path to the configuration file, which defines dataset, model, and training settings. - **`--expname`** (optional): Custom experiment name. If not provided, the config filename is used. - **`--override`** (optional): Allows overriding config values at runtime. Example: ```bash python main.py --experiment PretrainExperiment --config configs/default.yaml --override DataConfig.BATCH_SIZE=4 OptimizerConfig.LR=0.01 ``` #### Configuration File The configuration file defines: - **Dataset** (`DataConfig`): Path, batch size, patch size, and datasets used. - **Model** (`BackboneConfig` & `HeadsConfig`): Architecture, checkpoints, and initialization. - **Optimizer** (`OptimizerConfig`): Learning rates, weight decay, and momentum. - **Loss Function** (`LossConfig`): Defines the loss function used. - **Training** (`TrainConfig`): Number of epochs, checkpoint saving, and resume options. Check [the provided configs](https://github.com/LucaLumetti/UNetTransplant/tree/main/configs/miccai2025) for examples. #### Example Commands 1. **Pretraining a model**: ```bash python main.py --experiment PretrainExperiment --config configs/miccai2025/pretrain_stable.yaml ``` 2. **Training a task vector from a checkpoint**: ```bash python main.py --experiment TaskVectorTrainExperiment --config configs/miccai2025/finetune.yaml --override BackboneConfig.PRETRAIN_CHECKPOINTS="/path/to/checkpoint.pth" ``` For further details, refer to the config files used in our experiments under the `configs` folder. ### 5. Cite If you used our work, please cite it: ``` @incollection{lumetti2025u, title={U-Net Transplant: The Role of Pre-training for Model Merging in 3D Medical Segmentation}, author={Lumetti, Luca and Capitani, Giacomo and Ficarra, Elisa and Grana, Costantino and Calderara, Simone and Porrello, Angelo and Bolelli, Federico and others}, booktitle={Medical Image Computing and Computer Assisted Intervention--MICCAI 2025}, year={2025} } ```