License change to apache-2.0
Browse files- README.md +8 -6
- inference_brain2vec.py +19 -1
README.md
CHANGED
|
@@ -1,5 +1,5 @@
|
|
| 1 |
---
|
| 2 |
-
license:
|
| 3 |
language:
|
| 4 |
- en
|
| 5 |
task_categories:
|
|
@@ -52,11 +52,13 @@ nohup python brain2vec.py train \
|
|
| 52 |
--n_epochs 10 \
|
| 53 |
> train_log.txt 2>&1 &
|
| 54 |
|
| 55 |
-
#
|
| 56 |
-
python
|
| 57 |
-
--
|
| 58 |
-
--
|
| 59 |
-
--output_dir
|
|
|
|
|
|
|
| 60 |
```
|
| 61 |
|
| 62 |
# Methods
|
|
|
|
| 1 |
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
language:
|
| 4 |
- en
|
| 5 |
task_categories:
|
|
|
|
| 52 |
--n_epochs 10 \
|
| 53 |
> train_log.txt 2>&1 &
|
| 54 |
|
| 55 |
+
# model inference
|
| 56 |
+
python inference_brain2vec.py \
|
| 57 |
+
--checkpoint_path /path/to/model.pth \
|
| 58 |
+
--input_images /path/to/img1.nii.gz /path/to/img2.nii.gz \
|
| 59 |
+
--output_dir ./vae_inference_outputs \
|
| 60 |
+
--embeddings_filename pca_output/pca_embeddings_2.npy \
|
| 61 |
+
--save_recons
|
| 62 |
```
|
| 63 |
|
| 64 |
# Methods
|
inference_brain2vec.py
CHANGED
|
@@ -156,6 +156,18 @@ def main() -> None:
|
|
| 156 |
"--csv_input", type=str,
|
| 157 |
help="Path to a CSV file with an 'image_path' column."
|
| 158 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 159 |
args = parser.parse_args()
|
| 160 |
|
| 161 |
os.makedirs(args.output_dir, exist_ok=True)
|
|
@@ -198,6 +210,7 @@ def main() -> None:
|
|
| 198 |
z_sigma_np = z_sigma.detach().cpu().numpy()
|
| 199 |
|
| 200 |
# Save each reconstruction (per image) as .npy
|
|
|
|
| 201 |
recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
|
| 202 |
np.save(recon_path, recon_np)
|
| 203 |
print(f"[INFO] Saved reconstruction to {recon_path}")
|
|
@@ -210,8 +223,13 @@ def main() -> None:
|
|
| 210 |
stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
|
| 211 |
stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
|
| 212 |
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
|
|
|
|
| 215 |
np.save(mu_path, stacked_mu)
|
| 216 |
np.save(sigma_path, stacked_sigma)
|
| 217 |
|
|
|
|
| 156 |
"--csv_input", type=str,
|
| 157 |
help="Path to a CSV file with an 'image_path' column."
|
| 158 |
)
|
| 159 |
+
parser.add_argument(
|
| 160 |
+
"--embeddings_filename",
|
| 161 |
+
type=str,
|
| 162 |
+
required=True,
|
| 163 |
+
help="Filename (in output_dir) to save the stacked z_mu embeddings (e.g. 'all_z_mu.npy')."
|
| 164 |
+
)
|
| 165 |
+
parser.add_argument(
|
| 166 |
+
"--save_recons",
|
| 167 |
+
action="store_true",
|
| 168 |
+
help="If set, saves each reconstruction as .npy. Default is not to save."
|
| 169 |
+
)
|
| 170 |
+
|
| 171 |
args = parser.parse_args()
|
| 172 |
|
| 173 |
os.makedirs(args.output_dir, exist_ok=True)
|
|
|
|
| 210 |
z_sigma_np = z_sigma.detach().cpu().numpy()
|
| 211 |
|
| 212 |
# Save each reconstruction (per image) as .npy
|
| 213 |
+
if args.save_recons:
|
| 214 |
recon_path = os.path.join(args.output_dir, f"reconstruction_{i}.npy")
|
| 215 |
np.save(recon_path, recon_np)
|
| 216 |
print(f"[INFO] Saved reconstruction to {recon_path}")
|
|
|
|
| 223 |
stacked_mu = np.concatenate(all_z_mu, axis=0) # e.g., shape (N, latent_channels, ...)
|
| 224 |
stacked_sigma = np.concatenate(all_z_sigma, axis=0) # e.g., shape (N, latent_channels, ...)
|
| 225 |
|
| 226 |
+
mu_filename = args.embeddings_filename
|
| 227 |
+
if not mu_filename.lower().endswith(".npy"):
|
| 228 |
+
mu_filename += ".npy"
|
| 229 |
+
|
| 230 |
+
mu_path = os.path.join(args.output_dir, mu_filename)
|
| 231 |
sigma_path = os.path.join(args.output_dir, "all_z_sigma.npy")
|
| 232 |
+
|
| 233 |
np.save(mu_path, stacked_mu)
|
| 234 |
np.save(sigma_path, stacked_sigma)
|
| 235 |
|