Initial commit
Browse files- README.md +103 -3
- brain2vec_PCA.py +194 -0
- create_csv.py +39 -0
- inputs_example.csv +6 -0
- model.py +121 -0
- requirements.txt +15 -0
README.md
CHANGED
|
@@ -1,3 +1,103 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
task_categories:
|
| 6 |
+
- image-classification
|
| 7 |
+
tags:
|
| 8 |
+
- medical
|
| 9 |
+
- brain-data
|
| 10 |
+
- mri
|
| 11 |
+
pretty_name: 3D Brain Structure MRI PCA
|
| 12 |
+
---
|
| 13 |
+
|
| 14 |
+
## 🧠 Model Summary
|
| 15 |
+
# brain2vec
|
| 16 |
+
An linear PCA model for brain structure T1 MRIs. The models takes in a 3d MRI NIfTI file and compresses to 1200 latent dimensions before reconstructing the image.
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# Training data
|
| 20 |
+
[Radiata brain-structure](https://huggingface.co/datasets/radiata-ai/brain-structure): 3066 scans from 2085 individuals in the 'train' split. Mean age = 45.1 +- 24.5, including 2847 scans from cognitively normal subjects and 219 scans from individuals with an Alzheimer's disease clinical diagnosis.
|
| 21 |
+
|
| 22 |
+
# Example usage
|
| 23 |
+
```
|
| 24 |
+
# get brain2vec model repository
|
| 25 |
+
git clone https://huggingface.co/radiata-ai/brain2vec
|
| 26 |
+
cd brain2vec
|
| 27 |
+
|
| 28 |
+
# set up virtual environemt
|
| 29 |
+
python3 -m venv venv_brain2vec
|
| 30 |
+
source venv_brain2vec/bin/activate
|
| 31 |
+
|
| 32 |
+
# install Python libraries
|
| 33 |
+
pip install -r requirements.txt
|
| 34 |
+
|
| 35 |
+
# create the csv file inputs.csv listing the scan paths and other info
|
| 36 |
+
# this script loads the radiata-ai/brain-structure dataset
|
| 37 |
+
python create_csv.py
|
| 38 |
+
|
| 39 |
+
mkdir ae_cache
|
| 40 |
+
mkdir ae_output
|
| 41 |
+
|
| 42 |
+
# train the model
|
| 43 |
+
nohup python brain2vec.py train \
|
| 44 |
+
--dataset_csv /home/ubuntu/brain2vec/inputs.csv \
|
| 45 |
+
--cache_dir ./ae_cache \
|
| 46 |
+
--output_dir ./ae_output \
|
| 47 |
+
--n_epochs 10 \
|
| 48 |
+
> train_log.txt 2>&1 &
|
| 49 |
+
|
| 50 |
+
# run model inference to create *_embeddings.npz files
|
| 51 |
+
python brain2vec.py infererence \
|
| 52 |
+
--dataset_csv home/ubuntu/brain2vec/inputs.csv \
|
| 53 |
+
--aekl_ckpt /home/ubuntu/brain2vec/autoencoder_final.pth \
|
| 54 |
+
--output_dir /home/ubuntu/brain2vec
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
# Methods
|
| 58 |
+
transform:
|
| 59 |
+
(80, 96, 80)
|
| 60 |
+
pixdim=2
|
| 61 |
+
10 epochs
|
| 62 |
+
max_batch_size: int = 2,
|
| 63 |
+
batch_size: int = 16,
|
| 64 |
+
lr: float = 1e-4,
|
| 65 |
+
|
| 66 |
+
# References
|
| 67 |
+
Puglisi
|
| 68 |
+
Pinaya
|
| 69 |
+
|
| 70 |
+
# Citation
|
| 71 |
+
```
|
| 72 |
+
@dataset{Radiata-Brain-Structure,
|
| 73 |
+
author = {Jesse Brown and Clayton Young},
|
| 74 |
+
title = {Brain-Structure: Processed Structural MRI Brain Scans Across the Lifespan},
|
| 75 |
+
year = {2025},
|
| 76 |
+
url = {https://huggingface.co/datasets/radiata-ai/brain-structure},
|
| 77 |
+
note = {Version 1.0},
|
| 78 |
+
publisher = {Hugging Face}
|
| 79 |
+
}
|
| 80 |
+
```
|
| 81 |
+
|
| 82 |
+
# License
|
| 83 |
+
MIT License
|
| 84 |
+
|
| 85 |
+
Copyright (c) 2025
|
| 86 |
+
|
| 87 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 88 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 89 |
+
in the Software without restriction, including without limitation the rights
|
| 90 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 91 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 92 |
+
furnished to do so, subject to the following conditions:
|
| 93 |
+
|
| 94 |
+
The above copyright notice and this permission notice shall be included in all
|
| 95 |
+
copies or substantial portions of the Software.
|
| 96 |
+
|
| 97 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 98 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 99 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 100 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 101 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 102 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 103 |
+
SOFTWARE.
|
brain2vec_PCA.py
ADDED
|
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
"""
|
| 4 |
+
pca_autoencoder.py
|
| 5 |
+
|
| 6 |
+
This script demonstrates how to:
|
| 7 |
+
1) Load a dataset of MRI volumes using MONAI transforms (as in brain2vec_linearAE.py).
|
| 8 |
+
2) Flatten each 3D volume into a 1D vector (614,400 features if 80x96x80).
|
| 9 |
+
3) Perform IncrementalPCA to reduce dimensionality to 1200 components.
|
| 10 |
+
4) Provide a 'forward()' method that returns (reconstruction, embedding),
|
| 11 |
+
mimicking the interface of a linear autoencoder.
|
| 12 |
+
"""
|
| 13 |
+
|
| 14 |
+
import os
|
| 15 |
+
import argparse
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
|
| 19 |
+
import torch
|
| 20 |
+
from torch.utils.data import DataLoader
|
| 21 |
+
|
| 22 |
+
from monai import transforms
|
| 23 |
+
from monai.data import Dataset, PersistentDataset
|
| 24 |
+
|
| 25 |
+
from sklearn.decomposition import IncrementalPCA
|
| 26 |
+
|
| 27 |
+
###################################################################
|
| 28 |
+
# Constants for your typical config
|
| 29 |
+
###################################################################
|
| 30 |
+
RESOLUTION = 2
|
| 31 |
+
INPUT_SHAPE_AE = (80, 96, 80)
|
| 32 |
+
N_COMPONENTS = 1200
|
| 33 |
+
|
| 34 |
+
###################################################################
|
| 35 |
+
# Helper classes/functions
|
| 36 |
+
###################################################################
|
| 37 |
+
def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
|
| 38 |
+
"""
|
| 39 |
+
Returns a monai.data.Dataset or monai.data.PersistentDataset
|
| 40 |
+
if `cache_dir` is defined, to speed up loading.
|
| 41 |
+
"""
|
| 42 |
+
if cache_dir and cache_dir.strip():
|
| 43 |
+
os.makedirs(cache_dir, exist_ok=True)
|
| 44 |
+
dataset = PersistentDataset(data=df.to_dict(orient='records'),
|
| 45 |
+
transform=transforms_fn,
|
| 46 |
+
cache_dir=cache_dir)
|
| 47 |
+
else:
|
| 48 |
+
dataset = Dataset(data=df.to_dict(orient='records'),
|
| 49 |
+
transform=transforms_fn)
|
| 50 |
+
return dataset
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class PCAAutoencoder:
|
| 54 |
+
"""
|
| 55 |
+
A PCA 'autoencoder' using IncrementalPCA for memory efficiency,
|
| 56 |
+
providing:
|
| 57 |
+
- fit(X): partial fit on batches
|
| 58 |
+
- transform(X): get embeddings
|
| 59 |
+
- inverse_transform(Z): reconstruct from embeddings
|
| 60 |
+
- forward(X): returns (X_recon, Z) for a direct API
|
| 61 |
+
similar to a shallow linear AE.
|
| 62 |
+
"""
|
| 63 |
+
def __init__(self, n_components=N_COMPONENTS, batch_size=128):
|
| 64 |
+
self.n_components = n_components
|
| 65 |
+
self.batch_size = batch_size
|
| 66 |
+
self.ipca = IncrementalPCA(n_components=self.n_components)
|
| 67 |
+
|
| 68 |
+
def fit(self, X: np.ndarray):
|
| 69 |
+
"""
|
| 70 |
+
Incrementally fit the PCA model on batches of data.
|
| 71 |
+
X: shape (n_samples, n_features).
|
| 72 |
+
"""
|
| 73 |
+
n_samples = X.shape[0]
|
| 74 |
+
for start_idx in range(0, n_samples, self.batch_size):
|
| 75 |
+
end_idx = min(start_idx + self.batch_size, n_samples)
|
| 76 |
+
self.ipca.partial_fit(X[start_idx:end_idx])
|
| 77 |
+
|
| 78 |
+
def transform(self, X: np.ndarray) -> np.ndarray:
|
| 79 |
+
"""
|
| 80 |
+
Projects data into the PCA latent space in batches.
|
| 81 |
+
Returns Z: shape (n_samples, n_components).
|
| 82 |
+
"""
|
| 83 |
+
results = []
|
| 84 |
+
n_samples = X.shape[0]
|
| 85 |
+
for start_idx in range(0, n_samples, self.batch_size):
|
| 86 |
+
end_idx = min(start_idx + self.batch_size, n_samples)
|
| 87 |
+
Z_chunk = self.ipca.transform(X[start_idx:end_idx])
|
| 88 |
+
results.append(Z_chunk)
|
| 89 |
+
return np.vstack(results)
|
| 90 |
+
|
| 91 |
+
def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
|
| 92 |
+
"""
|
| 93 |
+
Reconstruct data from PCA latent space in batches.
|
| 94 |
+
Returns X_recon: shape (n_samples, n_features).
|
| 95 |
+
"""
|
| 96 |
+
results = []
|
| 97 |
+
n_samples = Z.shape[0]
|
| 98 |
+
for start_idx in range(0, n_samples, self.batch_size):
|
| 99 |
+
end_idx = min(start_idx + self.batch_size, n_samples)
|
| 100 |
+
X_chunk = self.ipca.inverse_transform(Z[start_idx:end_idx])
|
| 101 |
+
results.append(X_chunk)
|
| 102 |
+
return np.vstack(results)
|
| 103 |
+
|
| 104 |
+
def forward(self, X: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
|
| 105 |
+
"""
|
| 106 |
+
Mimics a linear AE's forward() returning (X_recon, Z).
|
| 107 |
+
"""
|
| 108 |
+
Z = self.transform(X)
|
| 109 |
+
X_recon = self.inverse_transform(Z)
|
| 110 |
+
return X_recon, Z
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
|
| 114 |
+
"""
|
| 115 |
+
Loads the dataset from csv_path, applies the monai transforms,
|
| 116 |
+
and flattens each 3D MRI into a 1D vector of shape (80*96*80).
|
| 117 |
+
Returns a numpy array X with shape (n_samples, 614400).
|
| 118 |
+
"""
|
| 119 |
+
df = pd.read_csv(csv_path)
|
| 120 |
+
dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
|
| 121 |
+
|
| 122 |
+
# We'll put the flattened data into this list, then stack.
|
| 123 |
+
X_list = []
|
| 124 |
+
|
| 125 |
+
# If memory allows, you can simply do a single-threaded loop
|
| 126 |
+
# or multi-worker DataLoader for speed.
|
| 127 |
+
# We'll demonstrate a simple single-worker here for clarity.
|
| 128 |
+
loader = DataLoader(dataset, batch_size=1, num_workers=0)
|
| 129 |
+
|
| 130 |
+
for batch in loader:
|
| 131 |
+
# batch["image"] shape: (1, 1, 80, 96, 80)
|
| 132 |
+
img = batch["image"].squeeze(0) # shape: (1, 80, 96, 80)
|
| 133 |
+
img_np = img.numpy() # convert to np array, shape: (1, D, H, W)
|
| 134 |
+
flattened = img_np.flatten() # shape: (614400,)
|
| 135 |
+
X_list.append(flattened)
|
| 136 |
+
|
| 137 |
+
X = np.vstack(X_list) # shape: (n_samples, 614400)
|
| 138 |
+
return X
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def main():
|
| 142 |
+
parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms example.")
|
| 143 |
+
parser.add_argument("--inputs_csv", type=str, required=True, help="CSV with 'image_path' column.")
|
| 144 |
+
parser.add_argument("--cache_dir", type=str, default="", help="Cache directory for MONAI PersistentDataset.")
|
| 145 |
+
parser.add_argument("--output_dir", type=str, default="./pca_outputs", help="Where to save PCA model and embeddings.")
|
| 146 |
+
parser.add_argument("--batch_size_ipca", type=int, default=128, help="Batch size for IncrementalPCA partial_fit().")
|
| 147 |
+
parser.add_argument("--n_components", type=int, default=1200, help="Number of PCA components.")
|
| 148 |
+
args = parser.parse_args()
|
| 149 |
+
|
| 150 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
| 151 |
+
|
| 152 |
+
# Same transforms as in brain2vec_linearAE.py
|
| 153 |
+
transforms_fn = transforms.Compose([
|
| 154 |
+
transforms.CopyItemsD(keys={'image_path'}, names=['image']),
|
| 155 |
+
transforms.LoadImageD(image_only=True, keys=['image']),
|
| 156 |
+
transforms.EnsureChannelFirstD(keys=['image']),
|
| 157 |
+
transforms.SpacingD(pixdim=RESOLUTION, keys=['image']),
|
| 158 |
+
transforms.ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
|
| 159 |
+
transforms.ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
| 160 |
+
])
|
| 161 |
+
|
| 162 |
+
print("Loading and flattening dataset from:", args.inputs_csv)
|
| 163 |
+
X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
|
| 164 |
+
print(f"Dataset shape after flattening: {X.shape}")
|
| 165 |
+
|
| 166 |
+
# Build PCAAutoencoder
|
| 167 |
+
model = PCAAutoencoder(n_components=args.n_components, batch_size=args.batch_size_ipca)
|
| 168 |
+
|
| 169 |
+
# Fit the PCA model
|
| 170 |
+
print("Fitting IncrementalPCA in batches...")
|
| 171 |
+
model.fit(X)
|
| 172 |
+
print("Done fitting PCA. Transforming data to embeddings...")
|
| 173 |
+
|
| 174 |
+
# Get embeddings & reconstruction
|
| 175 |
+
X_recon, Z = model.forward(X)
|
| 176 |
+
print("Embeddings shape:", Z.shape)
|
| 177 |
+
print("Reconstruction shape:", X_recon.shape)
|
| 178 |
+
|
| 179 |
+
# Optional: Save
|
| 180 |
+
embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
|
| 181 |
+
recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
|
| 182 |
+
np.save(embeddings_path, Z)
|
| 183 |
+
np.save(recons_path, X_recon)
|
| 184 |
+
print(f"Saved embeddings to {embeddings_path} and reconstructions to {recons_path}")
|
| 185 |
+
|
| 186 |
+
# If you want to store the actual PCA components for future usage:
|
| 187 |
+
# from joblib import dump
|
| 188 |
+
# ipca_model_path = os.path.join(args.output_dir, "pca_model.joblib")
|
| 189 |
+
# dump(model.ipca, ipca_model_path)
|
| 190 |
+
# print(f"Saved PCA model to {ipca_model_path}")
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
if __name__ == "__main__":
|
| 194 |
+
main()
|
create_csv.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
import os
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from datasets import load_dataset
|
| 5 |
+
|
| 6 |
+
def row_to_dict(row, split_name):
|
| 7 |
+
return {
|
| 8 |
+
"image_uid": row["id"],
|
| 9 |
+
"age": int(row["metadata"]["age"]),
|
| 10 |
+
"sex": 1 if row["metadata"]["sex"].lower() == "male" else 2,
|
| 11 |
+
"image_path": os.path.abspath(row["nii_filepath"]),
|
| 12 |
+
"split": split_name
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
# Load the datasets
|
| 17 |
+
ds_train = load_dataset("radiata-ai/brain-structure", split="train", trust_remote_code=True)
|
| 18 |
+
ds_val = load_dataset("radiata-ai/brain-structure", split="validation", trust_remote_code=True)
|
| 19 |
+
ds_test = load_dataset("radiata-ai/brain-structure", split="test", trust_remote_code=True)
|
| 20 |
+
|
| 21 |
+
rows = []
|
| 22 |
+
|
| 23 |
+
# Process each split
|
| 24 |
+
for data_row in ds_train:
|
| 25 |
+
rows.append(row_to_dict(data_row, "train"))
|
| 26 |
+
for data_row in ds_val:
|
| 27 |
+
rows.append(row_to_dict(data_row, "validation"))
|
| 28 |
+
for data_row in ds_test:
|
| 29 |
+
rows.append(row_to_dict(data_row, "test"))
|
| 30 |
+
|
| 31 |
+
# Create a DataFrame and write it to CSV
|
| 32 |
+
df = pd.DataFrame(rows)
|
| 33 |
+
output_csv = "inputs.csv"
|
| 34 |
+
df.to_csv(output_csv, index=False)
|
| 35 |
+
print(f"CSV file created: {output_csv}")
|
| 36 |
+
|
| 37 |
+
if __name__ == "__main__":
|
| 38 |
+
main()
|
| 39 |
+
|
inputs_example.csv
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
image_uid,age,sex,image_path,split
|
| 2 |
+
0,81,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-03/anat/msub-OASIS20133_ses-03_T1w_brain_affine_mni.nii.gz,train
|
| 3 |
+
1,78,2,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20133/ses-01/anat/msub-OASIS20133_ses-01_T1w_brain_affine_mni.nii.gz,train
|
| 4 |
+
2,87,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-02/anat/msub-OASIS20105_ses-02_T1w_brain_affine_mni.nii.gz,train
|
| 5 |
+
3,86,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20105/ses-01/anat/msub-OASIS20105_ses-01_T1w_brain_affine_mni.nii.gz,train
|
| 6 |
+
4,84,1,/Users/jbrown2/.cache/huggingface/datasets/downloads/extracted/6429865a89f9ae54df1c3c2db5d0f1f25cf7dd43cb87704d76ed08cf8c194aba/OASIS-2/sub-OASIS20102/ses-02/anat/msub-OASIS20102_ses-02_T1w_brain_affine_mni.nii.gz,train
|
model.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# model.py
|
| 2 |
+
import os
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
from monai.transforms import (
|
| 8 |
+
Compose,
|
| 9 |
+
CopyItemsD,
|
| 10 |
+
LoadImageD,
|
| 11 |
+
EnsureChannelFirstD,
|
| 12 |
+
SpacingD,
|
| 13 |
+
ResizeWithPadOrCropD,
|
| 14 |
+
ScaleIntensityD,
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
# Constants for your typical config
|
| 18 |
+
RESOLUTION = 2
|
| 19 |
+
INPUT_SHAPE_AE = (80, 96, 80)
|
| 20 |
+
|
| 21 |
+
# Define the exact transform pipeline for input MRI
|
| 22 |
+
transforms_fn = Compose([
|
| 23 |
+
CopyItemsD(keys={'image_path'}, names=['image']),
|
| 24 |
+
LoadImageD(image_only=True, keys=['image']),
|
| 25 |
+
EnsureChannelFirstD(keys=['image']),
|
| 26 |
+
SpacingD(pixdim=RESOLUTION, keys=['image']),
|
| 27 |
+
ResizeWithPadOrCropD(spatial_size=INPUT_SHAPE_AE, mode='minimum', keys=['image']),
|
| 28 |
+
ScaleIntensityD(minv=0, maxv=1, keys=['image']),
|
| 29 |
+
])
|
| 30 |
+
|
| 31 |
+
def preprocess_mri(image_path: str, device: str = "cpu") -> torch.Tensor:
|
| 32 |
+
"""
|
| 33 |
+
Preprocess an MRI using MONAI transforms to produce
|
| 34 |
+
a 5D tensor (batch=1, channels=1, D, H, W) for inference.
|
| 35 |
+
"""
|
| 36 |
+
data_dict = {"image_path": image_path}
|
| 37 |
+
output_dict = transforms_fn(data_dict)
|
| 38 |
+
image_tensor = output_dict["image"] # shape: (1, D, H, W)
|
| 39 |
+
image_tensor = image_tensor.unsqueeze(0) # => (batch=1, channel=1, D, H, W)
|
| 40 |
+
return image_tensor.to(device)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class ShallowLinearAutoencoder(nn.Module):
|
| 44 |
+
"""
|
| 45 |
+
A purely linear autoencoder with one hidden layer.
|
| 46 |
+
- Flatten input into a vector
|
| 47 |
+
- Linear encoder (no activation)
|
| 48 |
+
- Linear decoder (no activation)
|
| 49 |
+
- Reshape output to original volume shape
|
| 50 |
+
"""
|
| 51 |
+
def __init__(self, input_shape=(80, 96, 80), hidden_size=1200):
|
| 52 |
+
super().__init__()
|
| 53 |
+
self.input_shape = input_shape
|
| 54 |
+
self.input_dim = input_shape[0] * input_shape[1] * input_shape[2]
|
| 55 |
+
self.hidden_size = hidden_size
|
| 56 |
+
|
| 57 |
+
# Encoder (no activation for PCA-like behavior)
|
| 58 |
+
self.encoder = nn.Sequential(
|
| 59 |
+
nn.Flatten(),
|
| 60 |
+
nn.Linear(self.input_dim, self.hidden_size),
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# Decoder (no activation)
|
| 64 |
+
self.decoder = nn.Sequential(
|
| 65 |
+
nn.Linear(self.hidden_size, self.input_dim),
|
| 66 |
+
)
|
| 67 |
+
|
| 68 |
+
def encode(self, x: torch.Tensor):
|
| 69 |
+
return self.encoder(x)
|
| 70 |
+
|
| 71 |
+
def decode(self, z: torch.Tensor):
|
| 72 |
+
out = self.decoder(z)
|
| 73 |
+
# Reshape to (N, 1, D, H, W)
|
| 74 |
+
return out.view(-1, 1, *self.input_shape)
|
| 75 |
+
|
| 76 |
+
def forward(self, x: torch.Tensor):
|
| 77 |
+
"""
|
| 78 |
+
Return (reconstruction, embedding, None) to keep a similar API
|
| 79 |
+
to the old VAE-based code, though there's no σ for sampling.
|
| 80 |
+
"""
|
| 81 |
+
z = self.encode(x)
|
| 82 |
+
reconstruction = self.decode(z)
|
| 83 |
+
return reconstruction, z, None
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class Brain2vec(nn.Module):
|
| 87 |
+
"""
|
| 88 |
+
A wrapper around the ShallowLinearAutoencoder, providing a from_pretrained(...)
|
| 89 |
+
method for model loading, mirroring the old usage with AutoencoderKL.
|
| 90 |
+
"""
|
| 91 |
+
def __init__(self, device: str = "cpu"):
|
| 92 |
+
super().__init__()
|
| 93 |
+
# Instantiate the shallow linear model
|
| 94 |
+
self.model = ShallowLinearAutoencoder(input_shape=INPUT_SHAPE_AE, hidden_size=1200)
|
| 95 |
+
self.to(device)
|
| 96 |
+
|
| 97 |
+
def forward(self, x: torch.Tensor):
|
| 98 |
+
"""
|
| 99 |
+
Forward pass that returns (reconstruction, embedding, None).
|
| 100 |
+
"""
|
| 101 |
+
return self.model(x)
|
| 102 |
+
|
| 103 |
+
@staticmethod
|
| 104 |
+
def from_pretrained(
|
| 105 |
+
checkpoint_path: Optional[str] = None,
|
| 106 |
+
device: str = "cpu"
|
| 107 |
+
) -> nn.Module:
|
| 108 |
+
"""
|
| 109 |
+
Load a pretrained ShallowLinearAutoencoder if a checkpoint path is provided.
|
| 110 |
+
Args:
|
| 111 |
+
checkpoint_path (Optional[str]): path to a .pth checkpoint
|
| 112 |
+
device (str): "cpu", "cuda", etc.
|
| 113 |
+
"""
|
| 114 |
+
model = Brain2vec(device=device)
|
| 115 |
+
if checkpoint_path is not None:
|
| 116 |
+
if not os.path.exists(checkpoint_path):
|
| 117 |
+
raise FileNotFoundError(f"Checkpoint {checkpoint_path} not found.")
|
| 118 |
+
state_dict = torch.load(checkpoint_path, map_location=device)
|
| 119 |
+
model.load_state_dict(state_dict)
|
| 120 |
+
model.eval()
|
| 121 |
+
return model
|
requirements.txt
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# requirements.txt
|
| 2 |
+
|
| 3 |
+
# PyTorch (CUDA or CPU version). For GPU install, see PyTorch docs for the correct wheel.
|
| 4 |
+
torch>=1.12
|
| 5 |
+
|
| 6 |
+
# MONAI v1.2+ has the 'generative' subpackage with AutoencoderKL, PatchDiscriminator, etc.
|
| 7 |
+
monai-weekly
|
| 8 |
+
monai-generative
|
| 9 |
+
|
| 10 |
+
# Common Python libraries
|
| 11 |
+
pandas
|
| 12 |
+
numpy
|
| 13 |
+
nibabel
|
| 14 |
+
matplotlib
|
| 15 |
+
datasets
|