Code changes
Browse files
- brain2vec_PCA.py  +108 −55
- requirements.txt  +2 −1
brain2vec_PCA.py
CHANGED
(The split-diff left pane, which showed the removed lines, renders truncated on this page and is omitted. What is recoverable from it: the old module docstring was minimal, the old PCAAutoencoder supported only IncrementalPCA with partial_fit, and main() had a shorter argparse block. The updated file follows as a unified diff: added lines are marked +, context lines are unmarked, and the unrecoverable deleted lines are omitted.)
@@ -3,12 +3,16 @@
 """
 pca_autoencoder.py

+Adjustments requested:
+ 1. Only fit on scans with a 'train' label in the inputs.csv 'split' column.
+ 2. An option to either run incremental PCA or standard PCA.
+
+Example usage:
+ python pca_autoencoder.py \
+     --inputs_csv /path/to/inputs.csv \
+     --output_dir ./pca_outputs \
+     --pca_type standard \
+     --n_components 100
 """

 import os
@@ -22,17 +26,20 @@ from torch.utils.data import DataLoader
 from monai import transforms
 from monai.data import Dataset, PersistentDataset

+# We'll import both PCA classes, and decide which to use based on CLI arg.
+from sklearn.decomposition import PCA, IncrementalPCA
+

 ###################################################################
 # Constants for your typical config
 ###################################################################
 RESOLUTION = 2
 INPUT_SHAPE_AE = (80, 96, 80)
+DEFAULT_N_COMPONENTS = 1200
+

 ###################################################################
+# Helper: get_dataset_from_pd (same as in brain2vec_linearAE.py)
 ###################################################################
 def get_dataset_from_pd(df: pd.DataFrame, transforms_fn, cache_dir: str):
     """
@@ -50,35 +57,57 @@
     return dataset

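Aside, not part of the commit: a quick shape check. Flattening one volume of INPUT_SHAPE_AE yields the 614,400-dimensional feature vectors referenced throughout the file:

import numpy as np

INPUT_SHAPE_AE = (80, 96, 80)
# 80 * 96 * 80 = 614,400 features per flattened volume
assert int(np.prod(INPUT_SHAPE_AE)) == 614_400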
+###################################################################
+# PCAAutoencoder
+###################################################################
 class PCAAutoencoder:
     """
+    A PCA 'autoencoder' that can use either standard PCA or IncrementalPCA:
+      - fit(X): trains the model
       - transform(X): get embeddings
+      - inverse_transform(Z): reconstruct data from embeddings
+      - forward(X): returns (X_recon, Z)
+
+    If using standard PCA, we do a single call to .fit(X).
+    If using incremental PCA, we do .partial_fit on data in batches.
     """
+    def __init__(self, n_components=DEFAULT_N_COMPONENTS, batch_size=128, pca_type='incremental'):
+        """
+        Args:
+            n_components (int): number of principal components to keep
+            batch_size (int): chunk size for either partial_fit or chunked .transform
+            pca_type (str): 'incremental' or 'standard'
+        """
         self.n_components = n_components
         self.batch_size = batch_size
+        self.pca_type = pca_type.lower()
+
+        if self.pca_type == 'standard':
+            self.ipca = PCA(n_components=self.n_components, svd_solver='randomized')
+        else:
+            # default to incremental
+            self.ipca = IncrementalPCA(n_components=self.n_components)
     def fit(self, X: np.ndarray):
         """
+        Fit the PCA model. If incremental, calls partial_fit in batches.
+        If standard, calls .fit once on the entire data matrix.
+        X: shape (n_samples, n_features)
         """
+        if self.pca_type == 'standard':
+            # Potentially large memory usage, so be sure your system can handle it.
+            self.ipca.fit(X)
+        else:
+            # IncrementalPCA
+            n_samples = X.shape[0]
+            for start_idx in range(0, n_samples, self.batch_size):
+                end_idx = min(start_idx + self.batch_size, n_samples)
+                self.ipca.partial_fit(X[start_idx:end_idx])
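Aside, not in the commit: once fit() returns, both PCA and IncrementalPCA expose explained_variance_ratio_, so a two-line sketch like this, assuming a fitted model instance, shows how much signal the kept components retain:

import numpy as np

# Cumulative fraction of variance captured by the kept components;
# the attribute exists on both fitted PCA and IncrementalPCA estimators.
cum_var = np.cumsum(model.ipca.explained_variance_ratio_)
print(f"{model.n_components} components retain {100 * cum_var[-1]:.1f}% of the variance")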
     def transform(self, X: np.ndarray) -> np.ndarray:
         """
+        Project data into the PCA latent space in batches for memory efficiency.
+        Returns Z with shape (n_samples, n_components)
         """
         results = []
         n_samples = X.shape[0]
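The rest of the transform body is unchanged by this commit and collapsed in the diff. A hypothetical sketch of the chunked pattern it implements, reusing the names visible above:

# Sketch only: project X in batch_size slices, then stack the chunks.
for start_idx in range(0, n_samples, self.batch_size):
    end_idx = min(start_idx + self.batch_size, n_samples)
    results.append(self.ipca.transform(X[start_idx:end_idx]))
return np.vstack(results)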
@@ -91,7 +120,7 @@
     def inverse_transform(self, Z: np.ndarray) -> np.ndarray:
         """
         Reconstruct data from PCA latent space in batches.
+        Returns X_recon with shape (n_samples, n_features).
         """
         results = []
         n_samples = Z.shape[0]
@@ -110,46 +139,65 @@
         return X_recon, Z

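Taken together, the class behaves like an encoder/decoder pair around one fitted scikit-learn estimator. A minimal usage sketch with toy data (the real feature dimension would be 614,400; note that IncrementalPCA requires each partial_fit batch to contain at least n_components samples):

import numpy as np

X = np.random.rand(64, 500).astype(np.float32)  # toy stand-in for (n_samples, 614400)

model = PCAAutoencoder(n_components=8, batch_size=16, pca_type='incremental')
model.fit(X)                   # four partial_fit calls on 16-sample batches
X_recon, Z = model.forward(X)  # Z: (64, 8); X_recon: (64, 500)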
+###################################################################
+# Load and Flatten Data
+###################################################################
 def load_and_flatten_dataset(csv_path: str, cache_dir: str, transforms_fn) -> np.ndarray:
     """
+    1) Reads CSV.
+    2) Filters rows if 'split' in columns => only keep 'split' == 'train'.
+    3) Applies transforms to each image, flattening them into a 1D vector (614,400).
+    4) Returns a NumPy array X: shape (n_samples, 614400).
     """
     df = pd.read_csv(csv_path)

+    # Filter only 'train' if the split column exists
+    if 'split' in df.columns:
+        df = df[df['split'] == 'train']
+    # If there is no 'split' column, we assume the entire CSV is for training.

+    dataset = get_dataset_from_pd(df, transforms_fn, cache_dir)
     loader = DataLoader(dataset, batch_size=1, num_workers=0)

+    # We'll store each flattened volume in a list, then stack
+    X_list = []
     for batch in loader:
+        # batch["image"] shape => (1, 1, 80, 96, 80)
+        img = batch["image"].squeeze(0)  # => (1, 80, 96, 80)
+        img_np = img.numpy()
+        flattened = img_np.flatten()  # => (614400,)
         X_list.append(flattened)

+    if len(X_list) == 0:
+        raise ValueError("No training samples found (split='train'). Check your CSV or 'split' values.")
+
+    X = np.vstack(X_list)
     return X

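For reference, a hypothetical inputs.csv that this function accepts (column names match the argparse help below; the paths are placeholders):

image_path,split
/data/sub-0001_T1w.nii.gz,train
/data/sub-0002_T1w.nii.gz,train
/data/sub-0003_T1w.nii.gz,val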
+###################################################################
+# Main
+###################################################################
 def main():
+    parser = argparse.ArgumentParser(description="PCA Autoencoder with MONAI transforms and 'split' filtering.")
+    parser.add_argument("--inputs_csv", type=str, required=True,
+                        help="Path to CSV with at least 'image_path' column, optional 'split' column.")
+    parser.add_argument("--cache_dir", type=str, default="",
+                        help="Cache directory for MONAI PersistentDataset (optional).")
+    parser.add_argument("--output_dir", type=str, default="./pca_outputs",
+                        help="Where to save PCA model and embeddings.")
+    parser.add_argument("--batch_size_ipca", type=int, default=128,
+                        help="Batch size for partial_fit or chunked transform.")
+    parser.add_argument("--n_components", type=int, default=1200,
+                        help="Number of PCA components to keep.")
+    parser.add_argument("--pca_type", type=str, default="incremental",
+                        choices=["incremental", "standard"],
+                        help="Which PCA algorithm to use: 'incremental' or 'standard'.")
     args = parser.parse_args()

     os.makedirs(args.output_dir, exist_ok=True)

+    # define transforms as in brain2vec_linearAE.py
     transforms_fn = transforms.Compose([
         transforms.CopyItemsD(keys={'image_path'}, names=['image']),
         transforms.LoadImageD(image_only=True, keys=['image']),
@@ -163,27 +211,32 @@
     X = load_and_flatten_dataset(args.inputs_csv, args.cache_dir, transforms_fn)
     print(f"Dataset shape after flattening: {X.shape}")

+    # Build the PCAAutoencoder with chosen type
+    model = PCAAutoencoder(
+        n_components=args.n_components,
+        batch_size=args.batch_size_ipca,
+        pca_type=args.pca_type
+    )

     # Fit the PCA model
+    print(f"Fitting {args.pca_type.capitalize()}PCA in batches...")
     model.fit(X)
     print("Done fitting PCA. Transforming data to embeddings...")

     # Get embeddings & reconstruction
     X_recon, Z = model.forward(X)
+    print("Embeddings shape:", Z.shape)  # (n_samples, n_components)
+    print("Reconstruction shape:", X_recon.shape)  # (n_samples, 614400)

+    # Save
     embeddings_path = os.path.join(args.output_dir, "pca_embeddings.npy")
     recons_path = os.path.join(args.output_dir, "pca_reconstructions.npy")
     np.save(embeddings_path, Z)
     np.save(recons_path, X_recon)
+    print(f"Saved embeddings to {embeddings_path}")
+    print(f"Saved reconstructions to {recons_path}")

+    # Optionally save the actual PCA model with joblib
     # from joblib import dump
     # ipca_model_path = os.path.join(args.output_dir, "pca_model.joblib")
     # dump(model.ipca, ipca_model_path)
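After a run, the saved arrays load back in for downstream analysis. A short sketch, assuming the default --output_dir and that the optional joblib dump above was uncommented:

import numpy as np
from joblib import load

Z = np.load("./pca_outputs/pca_embeddings.npy")             # (n_samples, n_components)
X_recon = np.load("./pca_outputs/pca_reconstructions.npy")  # (n_samples, 614400)
ipca = load("./pca_outputs/pca_model.joblib")               # the fitted estimator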
requirements.txt
CHANGED
@@ -12,4 +12,5 @@ pandas
 numpy
 nibabel
 matplotlib
-datasets
+datasets
+scikit-learn
|