Code changes
Browse files- README.md +45 -44
- inference_brain2vec.py +6 -7
- train_brain2vec.py +31 -42
README.md
CHANGED
|
@@ -13,23 +13,29 @@ pretty_name: 3D Brain Structure MRI Autoencoder
|
|
| 13 |
|
| 14 |
## 🧠 Model Summary
|
| 15 |
# brain2vec
|
| 16 |
-
An autoencoder model for brain structure T1 MRIs
|
| 17 |
- [L1Loss](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html)
|
| 18 |
- [KLDivergenceLoss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)
|
| 19 |
- [PatchAdversarialLoss](https://docs.monai.io/en/stable/losses.html#patchadversarialloss)
|
| 20 |
- [PerceptualLoss](https://docs.monai.io/en/stable/losses.html#perceptualloss)
|
| 21 |
|
| 22 |
|
| 23 |
-
|
| 24 |
# Training data
|
| 25 |
[Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3066 scans from 2085 individuals in the 'train' split. Mean age = 45.1 +- 24.5, including 2847 scans from cognitively normal subjects and 219 scans from individuals with an Alzheimer's disease clinical diagnosis.
|
| 26 |
|
|
|
|
| 27 |
# Example usage
|
| 28 |
```
|
| 29 |
# get brain2vec model repository
|
| 30 |
git clone https://huggingface.co/radiata-ai/brain2vec
|
| 31 |
cd brain2vec
|
| 32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# set up virtual environemt
|
| 34 |
python3 -m venv venv_brain2vec
|
| 35 |
source venv_brain2vec/bin/activate
|
|
@@ -38,54 +44,54 @@ source venv_brain2vec/bin/activate
|
|
| 38 |
pip install -r requirements.txt
|
| 39 |
|
| 40 |
# create the csv file inputs.csv listing the scan paths and other info
|
| 41 |
-
# this script loads the radiata-ai/brain-structure dataset
|
| 42 |
python create_csv.py
|
| 43 |
|
| 44 |
mkdir ae_cache
|
| 45 |
mkdir ae_output
|
| 46 |
|
| 47 |
-
# install git lfs to pull large model weights
|
| 48 |
-
sudo apt-get update
|
| 49 |
-
sudo apt install git-lfs
|
| 50 |
-
git lfs install
|
| 51 |
-
git lfs pull
|
| 52 |
-
|
| 53 |
# train the model
|
| 54 |
-
nohup python
|
| 55 |
-
--dataset_csv
|
| 56 |
--cache_dir ./ae_cache \
|
| 57 |
--output_dir ./ae_output \
|
| 58 |
--n_epochs 10 \
|
| 59 |
> train_log.txt 2>&1 &
|
| 60 |
|
| 61 |
# model inference
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
python inference_brain2vec.py \
|
| 63 |
--checkpoint_path /path/to/model.pth \
|
| 64 |
--input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
|
| 65 |
-
--output_dir ./
|
| 66 |
-
--embeddings_filename
|
| 67 |
-
--save_recons
|
| 68 |
```
|
| 69 |
|
| 70 |
# Methods
|
| 71 |
Input scan image dimensions are 113x137x113, 1.5mm^3 resolution, aligned to MNI152 space (see [radiata-ai/brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure)).
|
| 72 |
|
| 73 |
-
The image transform crops to 80 x 96 x 80, 2mm^3 resolution, and scales image intensity to range [0,1].
|
|
|
|
|
|
|
| 74 |
|
| 75 |
-
10 epochs
|
| 76 |
-
max_batch_size: int = 2,
|
| 77 |
-
batch_size: int = 16,
|
| 78 |
-
lr: float = 1e-4,
|
| 79 |
|
| 80 |
# References
|
| 81 |
-
Puglisi
|
| 82 |
-
Pinaya
|
|
|
|
| 83 |
|
| 84 |
# Citation
|
| 85 |
```
|
| 86 |
-
@misc{Radiata-
|
| 87 |
author = {Jesse Brown and Clayton Young},
|
| 88 |
-
title = {
|
| 89 |
year = {2025},
|
| 90 |
url = {https://huggingface.co/radiata-ai/brain2vec},
|
| 91 |
note = {Version 1.0},
|
|
@@ -93,25 +99,20 @@ Pinaya
|
|
| 93 |
}
|
| 94 |
```
|
| 95 |
|
|
|
|
| 96 |
# License
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
Copyright
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 113 |
-
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 114 |
-
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 115 |
-
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 116 |
-
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 117 |
-
SOFTWARE.
|
|
|
|
| 13 |
|
| 14 |
## 🧠 Model Summary
|
| 15 |
# brain2vec
|
| 16 |
+
An autoencoder model for brain structure T1 MRIs (forked from [Brain Latent Progression](https://github.com/LemuelPuglisi/BrLP/tree/main)). The autoencoder takes in a 3d MRI NIfTI file and compresses to 1200 latent dimensions before reconstructing the image. The loss functions for training the autoencoder are:
|
| 17 |
- [L1Loss](https://pytorch.org/docs/stable/generated/torch.nn.L1Loss.html)
|
| 18 |
- [KLDivergenceLoss](https://pytorch.org/docs/stable/generated/torch.nn.KLDivLoss.html)
|
| 19 |
- [PatchAdversarialLoss](https://docs.monai.io/en/stable/losses.html#patchadversarialloss)
|
| 20 |
- [PerceptualLoss](https://docs.monai.io/en/stable/losses.html#perceptualloss)
|
| 21 |
|
| 22 |
|
|
|
|
| 23 |
# Training data
|
| 24 |
[Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3066 scans from 2085 individuals in the 'train' split. Mean age = 45.1 +- 24.5, including 2847 scans from cognitively normal subjects and 219 scans from individuals with an Alzheimer's disease clinical diagnosis.
|
| 25 |
|
| 26 |
+
|
| 27 |
# Example usage
|
| 28 |
```
|
| 29 |
# get brain2vec model repository
|
| 30 |
git clone https://huggingface.co/radiata-ai/brain2vec
|
| 31 |
cd brain2vec
|
| 32 |
|
| 33 |
+
# pull pre-trained model weights
|
| 34 |
+
sudo apt-get update
|
| 35 |
+
sudo apt install git-lfs
|
| 36 |
+
git lfs install
|
| 37 |
+
git lfs pull
|
| 38 |
+
|
| 39 |
# set up virtual environemt
|
| 40 |
python3 -m venv venv_brain2vec
|
| 41 |
source venv_brain2vec/bin/activate
|
|
|
|
| 44 |
pip install -r requirements.txt
|
| 45 |
|
| 46 |
# create the csv file inputs.csv listing the scan paths and other info
|
| 47 |
+
# this script loads the radiata-ai/brain-structure dataset from Hugging Face
|
| 48 |
python create_csv.py
|
| 49 |
|
| 50 |
mkdir ae_cache
|
| 51 |
mkdir ae_output
|
| 52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
# train the model
|
| 54 |
+
nohup python train_brain2vec.py \
|
| 55 |
+
--dataset_csv inputs.csv \
|
| 56 |
--cache_dir ./ae_cache \
|
| 57 |
--output_dir ./ae_output \
|
| 58 |
--n_epochs 10 \
|
| 59 |
> train_log.txt 2>&1 &
|
| 60 |
|
| 61 |
# model inference
|
| 62 |
+
# for a set of scans in inputs.csv
|
| 63 |
+
python inference_brain2vec.py \
|
| 64 |
+
--checkpoint_path /path/to/model.pth \
|
| 65 |
+
--csv_input inputs.csv \
|
| 66 |
+
--output_dir ./ae_output \
|
| 67 |
+
--embeddings_filename ae_embeddings_all.npy
|
| 68 |
+
|
| 69 |
+
# or for individual scans
|
| 70 |
python inference_brain2vec.py \
|
| 71 |
--checkpoint_path /path/to/model.pth \
|
| 72 |
--input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
|
| 73 |
+
--output_dir ./ae_output \
|
| 74 |
+
--embeddings_filename ae_embeddings_2.npy
|
|
|
|
| 75 |
```
|
| 76 |
|
| 77 |
# Methods
|
| 78 |
Input scan image dimensions are 113x137x113, 1.5mm^3 resolution, aligned to MNI152 space (see [radiata-ai/brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure)).
|
| 79 |
|
| 80 |
+
The image transform crops to 80 x 96 x 80, 2mm^3 resolution, and scales image intensity to range [0,1].
|
| 81 |
+
|
| 82 |
+
The model was trained with an effective batch size=16, 10 epochs, learning rate=1e-4 (see references 1 and 2).
|
| 83 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
# References
|
| 86 |
+
1. Puglisi L, Alexander DC, Ravì D. Enhancing Spatiotemporal Disease Progression Models via Latent Diffusion and Prior Knowledge [Internet]. arXiv; 2024. Available from: http://arxiv.org/abs/2405.03328
|
| 87 |
+
2. Pinaya WHL, Tudosiu PD, Dafflon J, Costa PF da, Fernandez V, Nachev P, et al. Brain Imaging Generation with Latent Diffusion Models [Internet]. arXiv; 2022. Available from: http://arxiv.org/abs/2209.07162
|
| 88 |
+
|
| 89 |
|
| 90 |
# Citation
|
| 91 |
```
|
| 92 |
+
@misc{Radiata-Brain2vec,
|
| 93 |
author = {Jesse Brown and Clayton Young},
|
| 94 |
+
title = {Brain2vec: An Autoencoder Model for Brain Structure T1 MRIs},
|
| 95 |
year = {2025},
|
| 96 |
url = {https://huggingface.co/radiata-ai/brain2vec},
|
| 97 |
note = {Version 1.0},
|
|
|
|
| 99 |
}
|
| 100 |
```
|
| 101 |
|
| 102 |
+
|
| 103 |
# License
|
| 104 |
+
### Apache License 2.0
|
| 105 |
+
|
| 106 |
+
Copyright 2025 Jesse Brown
|
| 107 |
+
|
| 108 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 109 |
+
you may not use this file except in compliance with the License.
|
| 110 |
+
You may obtain a copy of the License at:
|
| 111 |
+
|
| 112 |
+
[http://www.apache.org/licenses/LICENSE-2.0](http://www.apache.org/licenses/LICENSE-2.0)
|
| 113 |
+
|
| 114 |
+
Unless required by applicable law or agreed to in writing, software
|
| 115 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 116 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 117 |
+
See the License for the specific language governing permissions and
|
| 118 |
+
limitations under the License.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
inference_brain2vec.py
CHANGED
|
@@ -143,10 +143,6 @@ def main() -> None:
|
|
| 143 |
"--output_dir", type=str, default="./vae_inference_outputs",
|
| 144 |
help="Directory to save reconstructions and latent parameters."
|
| 145 |
)
|
| 146 |
-
parser.add_argument(
|
| 147 |
-
"--device", type=str, default="cpu",
|
| 148 |
-
help="Device to run inference on ('cpu', 'cuda', etc.)."
|
| 149 |
-
)
|
| 150 |
# Two ways to supply images: multiple file paths or a CSV
|
| 151 |
parser.add_argument(
|
| 152 |
"--input_images", type=str, nargs="*",
|
|
@@ -172,10 +168,13 @@ def main() -> None:
|
|
| 172 |
|
| 173 |
os.makedirs(args.output_dir, exist_ok=True)
|
| 174 |
|
| 175 |
-
#
|
|
|
|
|
|
|
|
|
|
| 176 |
model = Brain2vec.from_pretrained(
|
| 177 |
checkpoint_path=args.checkpoint_path,
|
| 178 |
-
device=
|
| 179 |
)
|
| 180 |
|
| 181 |
# Gather image paths
|
|
@@ -199,7 +198,7 @@ def main() -> None:
|
|
| 199 |
raise FileNotFoundError(f"Image not found: {img_path}")
|
| 200 |
|
| 201 |
print(f"[INFO] Processing image {i}: {img_path}")
|
| 202 |
-
img_tensor = preprocess_mri(img_path, device=
|
| 203 |
|
| 204 |
with torch.no_grad():
|
| 205 |
recon, z_mu, z_sigma = model.forward(img_tensor)
|
|
|
|
| 143 |
"--output_dir", type=str, default="./vae_inference_outputs",
|
| 144 |
help="Directory to save reconstructions and latent parameters."
|
| 145 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 146 |
# Two ways to supply images: multiple file paths or a CSV
|
| 147 |
parser.add_argument(
|
| 148 |
"--input_images", type=str, nargs="*",
|
|
|
|
| 168 |
|
| 169 |
os.makedirs(args.output_dir, exist_ok=True)
|
| 170 |
|
| 171 |
+
# After parsing args, add:
|
| 172 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 173 |
+
|
| 174 |
+
# Then pass that device to the model:
|
| 175 |
model = Brain2vec.from_pretrained(
|
| 176 |
checkpoint_path=args.checkpoint_path,
|
| 177 |
+
device=device
|
| 178 |
)
|
| 179 |
|
| 180 |
# Gather image paths
|
|
|
|
| 198 |
raise FileNotFoundError(f"Image not found: {img_path}")
|
| 199 |
|
| 200 |
print(f"[INFO] Processing image {i}: {img_path}")
|
| 201 |
+
img_tensor = preprocess_mri(img_path, device=device)
|
| 202 |
|
| 203 |
with torch.no_grad():
|
| 204 |
recon, z_mu, z_sigma = model.forward(img_tensor)
|
train_brain2vec.py
CHANGED
|
@@ -9,10 +9,10 @@ a perceptual loss, and KL divergence regularization for robust latent
|
|
| 9 |
representations.
|
| 10 |
|
| 11 |
Example usage:
|
| 12 |
-
python train_brain2vec.py
|
| 13 |
-
--dataset_csv
|
| 14 |
-
--cache_dir
|
| 15 |
-
--output_dir
|
| 16 |
--n_epochs 10
|
| 17 |
"""
|
| 18 |
|
|
@@ -487,50 +487,39 @@ def train(
|
|
| 487 |
|
| 488 |
def main():
|
| 489 |
"""
|
| 490 |
-
Main function to parse command-line arguments and
|
| 491 |
"""
|
|
|
|
|
|
|
| 492 |
parser = argparse.ArgumentParser(description="brain2vec Training Script")
|
| 493 |
|
| 494 |
-
|
| 495 |
-
|
| 496 |
-
|
| 497 |
-
|
| 498 |
-
|
| 499 |
-
|
| 500 |
-
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
train_parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
|
| 506 |
-
train_parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
|
| 507 |
-
train_parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
|
| 508 |
-
train_parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
|
| 509 |
|
| 510 |
args = parser.parse_args()
|
| 511 |
|
| 512 |
-
|
| 513 |
-
|
| 514 |
-
|
| 515 |
-
|
| 516 |
-
|
| 517 |
-
|
| 518 |
-
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
|
| 522 |
-
|
| 523 |
-
|
| 524 |
-
|
| 525 |
-
)
|
| 526 |
-
elif args.command == 'infer':
|
| 527 |
-
inference(
|
| 528 |
-
dataset_csv=args.dataset_csv,
|
| 529 |
-
aekl_ckpt=args.aekl_ckpt,
|
| 530 |
-
output_dir=args.output_dir,
|
| 531 |
-
)
|
| 532 |
-
else:
|
| 533 |
-
parser.print_help()
|
| 534 |
|
| 535 |
|
| 536 |
if __name__ == '__main__':
|
|
|
|
| 9 |
representations.
|
| 10 |
|
| 11 |
Example usage:
|
| 12 |
+
python train_brain2vec.py \
|
| 13 |
+
--dataset_csv inputs.csv \
|
| 14 |
+
--cache_dir ./ae_cache \
|
| 15 |
+
--output_dir ./ae_output \
|
| 16 |
--n_epochs 10
|
| 17 |
"""
|
| 18 |
|
|
|
|
| 487 |
|
| 488 |
def main():
|
| 489 |
"""
|
| 490 |
+
Main function to parse command-line arguments and run train().
|
| 491 |
"""
|
| 492 |
+
import argparse
|
| 493 |
+
|
| 494 |
parser = argparse.ArgumentParser(description="brain2vec Training Script")
|
| 495 |
|
| 496 |
+
parser.add_argument('--dataset_csv', type=str, required=True, help='Path to the dataset CSV file.')
|
| 497 |
+
parser.add_argument('--cache_dir', type=str, required=True, help='Directory for caching data.')
|
| 498 |
+
parser.add_argument('--output_dir', type=str, required=True, help='Directory to save model checkpoints.')
|
| 499 |
+
parser.add_argument('--aekl_ckpt', type=str, default=None, help='Path to the autoencoder checkpoint.')
|
| 500 |
+
parser.add_argument('--disc_ckpt', type=str, default=None, help='Path to the discriminator checkpoint.')
|
| 501 |
+
parser.add_argument('--num_workers', type=int, default=8, help='Number of data loader workers.')
|
| 502 |
+
parser.add_argument('--n_epochs', type=int, default=5, help='Number of training epochs.')
|
| 503 |
+
parser.add_argument('--max_batch_size', type=int, default=2, help='Actual batch size per iteration.')
|
| 504 |
+
parser.add_argument('--batch_size', type=int, default=16, help='Expected (effective) batch size.')
|
| 505 |
+
parser.add_argument('--lr', type=float, default=1e-4, help='Learning rate.')
|
| 506 |
+
parser.add_argument('--aug_p', type=float, default=0.8, help='Augmentation probability.')
|
|
|
|
|
|
|
|
|
|
|
|
|
| 507 |
|
| 508 |
args = parser.parse_args()
|
| 509 |
|
| 510 |
+
train(
|
| 511 |
+
dataset_csv=args.dataset_csv,
|
| 512 |
+
cache_dir=args.cache_dir,
|
| 513 |
+
output_dir=args.output_dir,
|
| 514 |
+
aekl_ckpt=args.aekl_ckpt,
|
| 515 |
+
disc_ckpt=args.disc_ckpt,
|
| 516 |
+
num_workers=args.num_workers,
|
| 517 |
+
n_epochs=args.n_epochs,
|
| 518 |
+
max_batch_size=args.max_batch_size,
|
| 519 |
+
batch_size=args.batch_size,
|
| 520 |
+
lr=args.lr,
|
| 521 |
+
aug_p=args.aug_p,
|
| 522 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 523 |
|
| 524 |
|
| 525 |
if __name__ == '__main__':
|