Upload 26 files
Browse filesAll files as in
https://github.com/lsprietog/dMRI-IVIM-ML-Toolkit
- .gitattributes +2 -0
- README.md +96 -59
- data/CFIN/T1.nii +3 -0
- data/CFIN/__DTI_AX_ep2d_2_5_iso_33d_20141015095334_4.bval +1 -0
- data/CFIN/__DTI_AX_ep2d_2_5_iso_33d_20141015095334_4.bvec +3 -0
- data/CFIN/__DTI_AX_ep2d_2_5_iso_33d_20141015095334_4.nii +3 -0
- debug_pixel.py +77 -0
- demo.py +99 -0
- demo_cfin.py +169 -0
- demo_colab.ipynb +160 -0
- demo_dipy.py +100 -0
- demo_fit.png +0 -0
- inference_colab.ipynb +206 -0
- ivim_dki_extratrees.joblib +3 -0
- models/b_values_config.npy +3 -0
- models/ivim_dki_extratrees.joblib +3 -0
- requirements.txt +13 -0
- setup.py +32 -0
- src/__init__.py +0 -0
- src/__pycache__/ivim_model.cpython-313.pyc +0 -0
- src/__pycache__/ml_models.cpython-313.pyc +0 -0
- src/__pycache__/utils.cpython-313.pyc +0 -0
- src/ivim_model.py +280 -0
- src/ml_models.py +132 -0
- src/utils.py +48 -0
- train_ml_demo.py +96 -0
- train_pretrained.py +102 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
data/CFIN/__DTI_AX_ep2d_2_5_iso_33d_20141015095334_4.nii filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
data/CFIN/T1.nii filter=lfs diff=lfs merge=lfs -text
|
README.md
CHANGED
|
@@ -1,59 +1,96 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
```
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# dMRI-IVIM-ML-Toolkit: Machine Learning for Diffusion MRI Analysis
|
| 2 |
+
|
| 3 |
+
[](https://www.python.org/downloads/)
|
| 4 |
+
[](https://opensource.org/licenses/MIT)
|
| 5 |
+
[](https://colab.research.google.com/github/lsprietog/public_release/blob/main/demo_colab.ipynb)
|
| 6 |
+
[](https://colab.research.google.com/github/lsprietog/public_release/blob/main/inference_colab.ipynb)
|
| 7 |
+
|
| 8 |
+
This repository hosts the source code and implementation details for the paper:
|
| 9 |
+
|
| 10 |
+
**"Exploring the Potential of Machine Learning Algorithms to Improve Diffusion Nuclear Magnetic Resonance Imaging Models Analysis"**
|
| 11 |
+
*Prieto-González LS, Agulles-Pedrós L.*
|
| 12 |
+
*Journal of Medical Physics, 2024.*
|
| 13 |
+
|
| 14 |
+
[Read the full paper here](https://pmc.ncbi.nlm.nih.gov/articles/PMC11309135/)
|
| 15 |
+
|
| 16 |
+
## Project Overview
|
| 17 |
+
|
| 18 |
+
Diffusion MRI analysis, particularly Intravoxel Incoherent Motion (IVIM) and Diffusion Kurtosis Imaging (DKI), often relies on non-linear least squares fitting. While standard, this approach can be computationally expensive and sensitive to noise, especially for parameters like pseudo-diffusion ($D^*$) and perfusion fraction ($f$).
|
| 19 |
+
|
| 20 |
+
This project introduces a **Python-based Machine Learning framework** (utilizing Random Forest, Extra Trees, and MLP) to estimate these parameters directly from the signal attenuation curve. It serves as a comprehensive toolkit for researchers looking to implement **IVIM Machine Learning** workflows in their **Python** pipelines.
|
| 21 |
+
|
| 22 |
+
### Key Findings
|
| 23 |
+
|
| 24 |
+
* **Computational Efficiency:** The ML approach reduces processing time significantly—from approximately 75 minutes to under 20 seconds for 100,000 voxels (a ~230x speedup).
|
| 25 |
+
* **Noise Robustness:** The algorithms demonstrate improved stability in parameter estimation across varying SNR levels compared to conventional fitting.
|
| 26 |
+
* **Versatility:** Validated across different anatomical regions (Prostate, Brain, Head & Neck) and species.
|
| 27 |
+
|
| 28 |
+

|
| 29 |
+
|
| 30 |
+
*Figure: Visual comparison of parameter maps derived from standard fitting versus the proposed ML estimation.*
|
| 31 |
+
|
| 32 |
+
## Installation
|
| 33 |
+
|
| 34 |
+
To set up the environment, clone this repository and install the required dependencies:
|
| 35 |
+
|
| 36 |
+
```bash
|
| 37 |
+
git clone https://github.com/lsprietog/public_release.git
|
| 38 |
+
cd public_release
|
| 39 |
+
pip install -r requirements.txt
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
## Usage Example
|
| 43 |
+
|
| 44 |
+
We provide a demonstration script `demo.py` that generates synthetic IVIM-DKI data to train and validate the model.
|
| 45 |
+
|
| 46 |
+
To run the demo locally:
|
| 47 |
+
```bash
|
| 48 |
+
python demo.py
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
### Integration with NIfTI Data
|
| 52 |
+
|
| 53 |
+
The `IVIMRegressor` class is designed to work with standard numpy arrays derived from NIfTI files.
|
| 54 |
+
|
| 55 |
+
```python
|
| 56 |
+
from src.ml_models import IVIMRegressor
|
| 57 |
+
import nibabel as nib
|
| 58 |
+
import numpy as np
|
| 59 |
+
|
| 60 |
+
# Example workflow
|
| 61 |
+
# 1. Prepare your data: Shape should be (n_voxels, n_b_values)
|
| 62 |
+
# X_train, y_train = load_your_data(...)
|
| 63 |
+
|
| 64 |
+
# 2. Initialize the regressor (e.g., Extra Trees)
|
| 65 |
+
model = IVIMRegressor(model_type='extra_trees')
|
| 66 |
+
|
| 67 |
+
# 3. Train the model
|
| 68 |
+
model.train(X_train, y_train)
|
| 69 |
+
|
| 70 |
+
# 4. Predict on new data
|
| 71 |
+
# X_new shape: (n_voxels_in_volume, n_b_values)
|
| 72 |
+
estimated_parameters = model.predict(X_new)
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
## Citation
|
| 76 |
+
|
| 77 |
+
If you find this repository useful for your research, please consider citing our work:
|
| 78 |
+
|
| 79 |
+
```bibtex
|
| 80 |
+
@article{PrietoGonzalez2024,
|
| 81 |
+
title={Exploring the Potential of Machine Learning Algorithms to Improve Diffusion Nuclear Magnetic Resonance Imaging Models Analysis},
|
| 82 |
+
author={Prieto-González, Leonar Steven and Agulles-Pedrós, Luis},
|
| 83 |
+
journal={Journal of Medical Physics},
|
| 84 |
+
volume={49},
|
| 85 |
+
issue={2},
|
| 86 |
+
pages={189--202},
|
| 87 |
+
year={2024},
|
| 88 |
+
publisher={Wolters Kluwer - Medknow},
|
| 89 |
+
doi={10.4103/jmp.jmp_10_24},
|
| 90 |
+
url={https://pmc.ncbi.nlm.nih.gov/articles/PMC11309135/}
|
| 91 |
+
}
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
## Contact
|
| 95 |
+
|
| 96 |
+
For inquiries regarding the code or the paper, please open an issue in this repository.
|
data/CFIN/T1.nii
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8f1f3561a745ab4a92b3437a794518c7c750823add2c622da3263ec9fa825158
|
| 3 |
+
size 23069024
|
data/CFIN/__DTI_AX_ep2d_2_5_iso_33d_20141015095334_4.bval
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
0 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 200 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 400 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 600 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 800 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1000 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1200 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1400 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1600 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 1800 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2000 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2200 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2400 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2600 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 2800 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000 3000
|
data/CFIN/__DTI_AX_ep2d_2_5_iso_33d_20141015095334_4.bvec
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
0 0.869998 0.862033 0.886199 0.845792 -0.842594 -0.883964 -0.874659 -0.862782 0.274032 -0.228233 -0.406856 0.451853 -0.304319 0.258413 0.386145 -0.438481 0.402851 -0.460287 0.225337 -0.283958 0.445394 -0.386941 0.308325 -0.249216 0.999031 -0.00301848 0.0426293 0.0281464 0.686742 0.733299 -0.0329733 0.722998 0.682772 0.868609 0.862082 0.886767 0.84412 -0.843462 -0.88419 -0.873846 -0.860751 0.277475 -0.227707 -0.406385 0.452442 -0.302973 0.256404 0.387865 -0.436367 0.4031 -0.459668 0.226396 -0.284435 0.445449 -0.387672 0.306581 -0.251255 0.999063 -0.0032341 0.0426443 0.0279856 0.685897 0.732337 -0.0328328 0.723756 0.684843 0.869043 0.862209 0.886762 0.84444 -0.843832 -0.884912 -0.873285 -0.860649 0.277058 -0.228109 -0.405898 0.453508 -0.30252 0.255814 0.388255 -0.436928 0.402642 -0.458564 0.227076 -0.283684 0.444405 -0.387382 0.306567 -0.250513 0.999026 -0.00323519 0.0426492 0.0278783 0.685009 0.731455 -0.0328325 0.723945 0.68484 0.86879 0.861894 0.886678 0.844687 -0.843767 -0.884847 -0.873076 -0.86003 0.276576 -0.228562 -0.40632 0.453606 -0.304136 0.256152 0.388526 -0.436484 0.402144 -0.45892 0.226444 -0.283991 0.444432 -0.387457 0.306338 -0.250892 0.999005 -0.00328869 0.042654 0.027865 0.685053 0.731432 -0.0328324 0.724515 0.684903 0.868812 0.862226 0.886728 0.844533 -0.843926 -0.884783 -0.873031 -0.860055 0.276504 -0.228222 -0.406552 0.453468 -0.303632 0.255631 0.388021 -0.436647 0.402408 -0.458632 0.226643 -0.283631 0.443994 -0.386756 0.306568 -0.250521 0.999054 -0.0033213 0.0426587 0.0278245 0.685033 0.731083 -0.0328044 0.724472 0.685303 0.868521 0.862181 0.886677 0.844604 -0.84405 -0.88501 -0.872778 -0.85997 0.276513 -0.228176 -0.406744 0.453394 -0.303433 0.256291 0.388374 -0.436755 0.402442 -0.45873 0.226468 -0.283897 0.444081 -0.386501 0.30609 -0.24997 0.999062 -0.00334263 0.0426587 0.0278246 0.684753 0.730831 -0.0328089 0.724731 0.685528 0.868707 0.862171 0.886571 0.844403 -0.844073 -0.885146 -0.872722 -0.859958 0.276315 -0.228166 -0.406732 0.45345 -0.303749 0.256318 0.388261 -0.436695 0.401949 -0.458533 0.226786 -0.283298 0.444255 -0.386544 0.306465 -0.249712 0.999044 -0.00335817 0.0426601 0.0278014 0.684814 0.73091 -0.0327923 0.724946 0.685465 0.868597 0.862248 0.886553 0.844487 -0.84428 -0.885194 -0.872614 -0.859621 0.276457 -0.22826 -0.406723 0.453595 -0.303481 0.256193 0.388197 -0.436908 0.402117 -0.458278 0.226867 -0.283383 0.443796 -0.386713 0.306364 -0.250044 0.999069 -0.00336957 0.0426622 0.0278043 0.684603 0.730743 -0.0327799 0.724866 0.685645 0.868652 0.86216 0.886705 0.84441 -0.844323 -0.885217 -0.872679 -0.859842 0.276503 -0.228135 -0.406716 0.453708 -0.303565 0.25602 0.388731 -0.43676 0.402081 -0.458467 0.226501 -0.28345 0.443876 -0.38694 0.306488 -0.249933 0.999054 -0.00337865 0.0426629 0.0277886 0.684463 0.730625 -0.0327856 0.724995 0.685783 0.868589 0.862248 0.886637 0.844405 -0.844348 -0.885146 -0.872668 -0.859694 0.276334 -0.228007 -0.406711 0.453716 -0.303471 0.256066 0.388704 -0.43689 0.402135 -0.458302 0.226594 -0.283325 0.443845 -0.386819 0.306587 -0.250033 0.999058 -0.00338572 0.0426643 0.0277922 0.684534 0.73052 -0.0327763 0.724926 0.6857 0.868683 0.862221 0.886527 0.84446 -0.844276 -0.885238 -0.872461 -0.859611 0.276224 -0.228344 -0.406631 0.453901 -0.303729 0.255921 0.388762 -0.43677 0.402104 -0.458303 0.226503 -0.283413 0.443787 -0.386772 0.306411 -0.249944 0.999075 -0.00340143 0.0426647 0.0277804 0.684593 0.730591 -0.0327687 0.725044 0.685799 0.868548 0.862282 0.886584 0.84438 -0.844369 -0.885224 -0.872577 -0.859778 0.276681 -0.22822 -0.406633 0.453892 -0.303846 0.256071 0.38873 -0.436814 0.401983 -0.4583 0.226581 -0.283416 0.443913 -0.386789 0.306422 -0.250041 0.999063 -0.00340543 0.0426642 0.0277705 0.684326 0.730513 -0.0327739 0.725126 0.685891 0.86859 0.862296 0.886597 0.844487 -0.844392 -0.885316 -0.872494 -0.859669 0.276537 -0.22837 -0.406564 0.454014 -0.303846 0.256026 0.388777 -0.436909 0.402032 -0.45837 0.226602 -0.283459 0.443826 -0.386641 0.306474 -0.249964 0.999065 -0.00340894 0.0426652 0.0277747 0.684384 0.730441 -0.0327677 0.725062 0.685829 0.868554 0.862282 0.88664 0.844308 -0.844404 -0.885239 -0.872478 -0.859644 0.276561 -0.228336 -0.406571 0.453935 -0.303745 0.255915 0.388875 -0.436814 0.402015 -0.458372 0.22668 -0.283369 0.443732 -0.386602 0.306474 -0.250123 0.999067 -0.00341194 0.0426654 0.0277667 0.68431 0.730386 -0.0327624 0.725131 0.685906 0.868567 0.862255 0.886626 0.844405 -0.844472 -0.885317 -0.872507 -0.859559 0.276548 -0.228255 -0.406522 0.454039 -0.303791 0.256272 0.388788 -0.43692 0.402047 -0.458131 0.226541 -0.283431 0.443656 -0.38657 0.30642 -0.249979 0.999069 -0.00341445 0.0426662 0.0277597 0.684245 0.730332 -0.032767 0.725205 0.685971
|
| 2 |
+
0 0.341143 -0.399649 0.324281 -0.380884 0.503055 -0.445372 0.21603 -0.157769 0.768125 0.942991 0.81819 0.889336 -0.925614 -0.782487 -0.917988 -0.789843 0.0867398 -0.426497 -0.602671 0.265468 -0.113151 0.454092 0.582292 -0.245856 -0.041406 0.566266 0.837505 0.983158 -0.171119 0.123454 -0.204304 0.675742 -0.722869 0.340777 -0.39728 0.321034 -0.379686 0.50076 -0.443126 0.214949 -0.157922 0.7637 0.942725 0.816926 0.888715 -0.92493 -0.781602 -0.916851 -0.788332 0.0847424 -0.423712 -0.598072 0.262329 -0.114041 0.450737 0.5802 -0.243463 -0.0394278 0.562101 0.835455 0.982252 -0.167679 0.120106 -0.199286 0.674226 -0.719927 0.33893 -0.396422 0.319276 -0.378614 0.49929 -0.441362 0.213629 -0.154977 0.762988 0.941989 0.815686 0.888047 -0.924697 -0.780322 -0.916546 -0.786533 0.0838246 -0.421525 -0.597007 0.260313 -0.112892 0.44923 0.579487 -0.2418 -0.039094 0.56208 0.834769 0.981637 -0.16664 0.119142 -0.199277 0.673268 -0.71993 0.337944 -0.396053 0.318957 -0.37694 0.499022 -0.441212 0.212986 -0.155296 0.762239 0.941791 0.814583 0.887901 -0.924051 -0.779552 -0.916279 -0.78681 0.0834505 -0.420498 -0.596564 0.260039 -0.111318 0.448965 0.578358 -0.241654 -0.0389205 0.561045 0.834081 0.98156 -0.165178 0.118547 -0.199273 0.672676 -0.719599 0.337499 -0.395514 0.318297 -0.377104 0.498483 -0.440875 0.212708 -0.154312 0.762102 0.941594 0.814454 0.887926 -0.924086 -0.779456 -0.916426 -0.786294 0.0830519 -0.420443 -0.595948 0.259291 -0.111793 0.448633 0.577796 -0.24102 -0.037643 0.560412 0.833398 0.981325 -0.165695 0.118276 -0.198275 0.672414 -0.719065 0.33757 -0.395095 0.318249 -0.376051 0.498026 -0.44029 0.211806 -0.154083 0.761373 0.941529 0.814115 0.887951 -0.923914 -0.779075 -0.916174 -0.78595 0.0822424 -0.419482 -0.595419 0.258967 -0.110929 0.448148 0.577525 -0.240921 -0.0369534 0.559999 0.833397 0.981325 -0.164929 0.117513 -0.198436 0.671936 -0.718934 0.336743 -0.394665 0.317822 -0.375912 0.498173 -0.440037 0.212144 -0.153924 0.761175 0.941383 0.813888 0.887873 -0.923676 -0.778734 -0.916179 -0.785625 0.0820889 -0.419759 -0.595118 0.258682 -0.1106 0.447871 0.577207 -0.240997 -0.0371612 0.559698 0.833199 0.98119 -0.164339 0.11741 -0.197847 0.671787 -0.718885 0.336792 -0.394239 0.317846 -0.375984 0.497725 -0.439748 0.211769 -0.153707 0.761182 0.941272 0.813818 0.887785 -0.923744 -0.778678 -0.916191 -0.785618 0.082389 -0.419284 -0.594914 0.258521 -0.110603 0.447275 0.576927 -0.240553 -0.0365826 0.559477 0.832878 0.981207 -0.164802 0.116945 -0.197404 0.671731 -0.71864 0.336433 -0.394508 0.317453 -0.375548 0.497505 -0.439587 0.211612 -0.15336 0.760918 0.941266 0.813674 0.887716 -0.923574 -0.778575 -0.915921 -0.785549 0.0821574 -0.419077 -0.594796 0.258397 -0.11016 0.44699 0.576773 -0.240324 -0.0367873 0.559301 0.832781 0.981115 -0.16439 0.116974 -0.197608 0.671488 -0.71845 0.336288 -0.394222 0.317296 -0.375374 0.497385 -0.439656 0.211115 -0.15317 0.760717 0.941219 0.813559 0.887685 -0.923561 -0.778387 -0.91594 -0.785284 0.0817623 -0.418765 -0.594665 0.258187 -0.11001 0.447259 0.576651 -0.240238 -0.0364577 0.559163 0.832568 0.981137 -0.16403 0.116646 -0.197277 0.671472 -0.718533 0.336074 -0.39411 0.317483 -0.375261 0.497462 -0.439379 0.211312 -0.153218 0.760651 0.941152 0.813409 0.8876 -0.923463 -0.77823 -0.915897 -0.785258 0.0816302 -0.418717 -0.594493 0.257946 -0.110159 0.446805 0.576377 -0.24008 -0.0361007 0.558859 0.832517 0.981067 -0.163737 0.116343 -0.197007 0.671421 -0.718447 0.336228 -0.393797 0.317085 -0.37529 0.49723 -0.439408 0.211206 -0.152998 0.760577 0.941122 0.813335 0.887583 -0.923397 -0.778082 -0.915876 -0.785172 0.0816879 -0.418646 -0.594409 0.257806 -0.110032 0.446658 0.576519 -0.239872 -0.0362944 0.558781 0.83259 0.981009 -0.163829 0.116413 -0.197193 0.671264 -0.71832 0.336023 -0.393865 0.316983 -0.374938 0.497115 -0.439185 0.210905 -0.15288 0.760448 0.941051 0.813432 0.887543 -0.923372 -0.778181 -0.915844 -0.784997 0.0814204 -0.418342 -0.59427 0.257775 -0.109772 0.446557 0.57626 -0.239766 -0.0360789 0.558713 0.832439 0.981034 -0.163865 0.116205 -0.196972 0.671269 -0.718341 0.335884 -0.393739 0.3168 -0.375113 0.497057 -0.439271 0.210755 -0.152662 0.760331 0.94107 0.813367 0.887578 -0.923349 -0.778072 -0.915777 -0.784997 0.0813411 -0.418224 -0.594103 0.25767 -0.109738 0.446614 0.576114 -0.239735 -0.0358951 0.558654 0.83241 0.980987 -0.163666 0.116276 -0.196782 0.671145 -0.718239 0.335704 -0.393791 0.316881 -0.374819 0.496986 -0.439086 0.210604 -0.152582 0.760287 0.94108 0.81327 0.887508 -0.923299 -0.777868 -0.915805 -0.785035 0.0812193 -0.418443 -0.593922 0.257527 -0.109618 0.446401 0.576037 -0.239535 -0.0357358 0.558606 0.832291 0.980946 -0.163495 0.116104 -0.196946 0.671129 -0.718152
|
| 3 |
+
0 -0.355986 0.311738 0.330897 -0.373582 0.192276 -0.142307 -0.43394 0.48033 -0.578697 0.242233 -0.406244 0.0700702 -0.225009 0.566514 -0.090495 0.428815 -0.911146 -0.778612 -0.765514 -0.921355 0.888156 0.802544 0.752244 0.936721 -0.0149107 -0.824217 0.544764 -0.180579 -0.706473 0.668605 -0.978352 -0.143692 0.106225 -0.359708 0.314616 0.332537 -0.37855 0.19445 -0.147809 -0.436108 0.48391 -0.582898 0.243761 -0.409246 0.0740399 -0.229592 0.568643 -0.0945752 0.433724 -0.911224 -0.780496 -0.768801 -0.922106 0.888015 0.804081 0.75457 0.936801 -0.0178329 -0.827062 0.547902 -0.185466 -0.708116 0.670266 -0.979391 -0.146962 0.112675 -0.360405 0.315348 0.334241 -0.37891 0.196617 -0.148764 -0.437877 0.485042 -0.584027 0.246217 -0.412193 0.07552 -0.231122 0.570663 -0.0959281 0.436418 -0.911511 -0.782327 -0.769428 -0.922909 0.888684 0.805064 0.755124 0.93743 -0.0204492 -0.827076 0.548946 -0.188711 -0.70922 0.6714 -0.979393 -0.150381 0.112669 -0.361936 0.316671 0.334767 -0.380025 0.197572 -0.149593 -0.438606 0.486037 -0.585234 0.246553 -0.413955 0.0766337 -0.231585 0.571563 -0.0973683 0.436363 -0.911765 -0.782671 -0.769958 -0.922892 0.888869 0.805176 0.756081 0.937367 -0.0217651 -0.827779 0.54999 -0.189113 -0.70952 0.671531 -0.979394 -0.15028 0.114388 -0.362298 0.31644 0.335263 -0.380207 0.198251 -0.150962 -0.438831 0.486305 -0.585446 0.247619 -0.41398 0.0771662 -0.232105 0.571928 -0.0979983 0.437128 -0.911685 -0.782869 -0.770376 -0.923213 0.889028 0.805698 0.756418 0.937629 -0.0217622 -0.828207 0.551025 -0.190335 -0.709419 0.671959 -0.979597 -0.151657 0.115346 -0.36293 0.317088 0.335442 -0.381089 0.198868 -0.151333 -0.439769 0.486528 -0.586388 0.247909 -0.41446 0.0773119 -0.233047 0.572152 -0.0989466 0.437639 -0.911744 -0.783327 -0.770836 -0.923222 0.889093 0.80609 0.756818 0.937802 -0.0225799 -0.828487 0.551026 -0.190333 -0.709868 0.672366 -0.979565 -0.152533 0.114831 -0.363252 0.317649 0.336127 -0.381672 0.198405 -0.151272 -0.439717 0.4866 -0.586739 0.24847 -0.414916 0.0778734 -0.233581 0.572603 -0.0993457 0.438282 -0.911975 -0.783294 -0.770975 -0.923486 0.889047 0.806223 0.756909 0.937851 -0.0230269 -0.82869 0.551325 -0.191032 -0.709945 0.672299 -0.979684 -0.152168 0.115507 -0.36347 0.317968 0.336153 -0.381415 0.198647 -0.151834 -0.440112 0.487264 -0.586662 0.248804 -0.415063 0.0780367 -0.233656 0.572735 -0.0994888 0.438081 -0.911874 -0.783698 -0.771109 -0.923505 0.889276 0.806473 0.757164 0.937877 -0.0228675 -0.828839 0.55181 -0.190943 -0.710042 0.672562 -0.979774 -0.152797 0.115966 -0.36367 0.317875 0.336122 -0.382015 0.199015 -0.152161 -0.440058 0.486983 -0.586984 0.248943 -0.415352 0.0781637 -0.234222 0.572953 -0.099883 0.438353 -0.911911 -0.783698 -0.771307 -0.923519 0.889291 0.806522 0.75723 0.937965 -0.0231811 -0.828958 0.551956 -0.191417 -0.710272 0.672684 -0.979733 -0.153254 0.116322 -0.363956 0.317989 0.336449 -0.382198 0.199211 -0.152374 -0.440319 0.487304 -0.587324 0.249239 -0.415582 0.078472 -0.234391 0.573188 -0.0998183 0.438698 -0.911923 -0.783961 -0.771381 -0.923616 0.889325 0.806431 0.757284 0.93796 -0.0235317 -0.829051 0.552277 -0.191308 -0.710287 0.672856 -0.9798 -0.153649 0.1163 -0.363929 0.318202 0.336564 -0.382185 0.199321 -0.152644 -0.440635 0.487435 -0.587462 0.249184 -0.415953 0.0783589 -0.234446 0.573466 -0.0999905 0.438864 -0.911948 -0.783987 -0.77154 -0.923656 0.889336 0.806705 0.757563 0.938025 -0.02337 -0.829256 0.552354 -0.191664 -0.710298 0.672831 -0.979854 -0.153313 0.116247 -0.364108 0.318425 0.336787 -0.382334 0.199509 -0.152639 -0.440456 0.48721 -0.587342 0.24941 -0.416095 0.0785998 -0.234556 0.573599 -0.100303 0.438975 -0.911996 -0.784026 -0.771582 -0.923695 0.889289 0.806778 0.757451 0.938052 -0.0235634 -0.829308 0.552244 -0.191961 -0.710533 0.672903 -0.979817 -0.153614 0.116492 -0.364197 0.318302 0.33685 -0.382444 0.199697 -0.152751 -0.440764 0.487439 -0.587576 0.249539 -0.415973 0.0783514 -0.234651 0.573485 -0.100411 0.439193 -0.911999 -0.784147 -0.771683 -0.92369 0.889364 0.806905 0.757627 0.938099 -0.0238036 -0.829354 0.552471 -0.191834 -0.710469 0.673017 -0.979861 -0.15389 0.116729 -0.364412 0.318496 0.336908 -0.382668 0.199789 -0.15295 -0.440868 0.487551 -0.587717 0.249498 -0.416094 0.0784196 -0.234874 0.573682 -0.100646 0.439288 -0.912013 -0.784209 -0.771789 -0.923747 0.889415 0.806893 0.757737 0.938065 -0.0240084 -0.829394 0.552516 -0.192077 -0.710587 0.673065 -0.9799 -0.154106 0.116904 -0.364547 0.318505 0.336869 -0.382741 0.199677 -0.153026 -0.440883 0.487727 -0.58778 0.249536 -0.416331 0.0786 -0.235009 0.573799 -0.10072 0.439114 -0.91201 -0.784233 -0.771969 -0.923768 0.889468 0.807025 0.757818 0.938154 -0.0241859 -0.829426 0.552695 -0.192287 -0.710688 0.673153 -0.979867 -0.15383 0.117056
|
data/CFIN/__DTI_AX_ep2d_2_5_iso_33d_20141015095334_4.nii
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1081b5b586a5429a6754078653ffe4ab430d348eb712bc7231c8f99919ac3eff
|
| 3 |
+
size 173703520
|
debug_pixel.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nibabel as nib
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from scipy.optimize import curve_fit
|
| 7 |
+
|
| 8 |
+
# Add src to path
|
| 9 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
| 10 |
+
from ivim_model import calculate_ivim_params, _fit_pixel
|
| 11 |
+
|
| 12 |
+
def func_qua(x, a0, a1, a2):
|
| 13 |
+
return a2 * (x ** 2) + a1 * x + a0
|
| 14 |
+
|
| 15 |
+
def main():
|
| 16 |
+
print("=== Debugging Single Pixel Fit ===")
|
| 17 |
+
|
| 18 |
+
# Load data
|
| 19 |
+
data_dir = os.path.join(os.path.dirname(__file__), 'data', 'CFIN')
|
| 20 |
+
nii_file = [f for f in os.listdir(data_dir) if f.endswith('.nii') and 'DTI' in f][0]
|
| 21 |
+
bval_file = [f for f in os.listdir(data_dir) if f.endswith('.bval')][0]
|
| 22 |
+
|
| 23 |
+
img = nib.load(os.path.join(data_dir, nii_file))
|
| 24 |
+
data = img.get_fdata()
|
| 25 |
+
bvals = np.loadtxt(os.path.join(data_dir, bval_file))
|
| 26 |
+
|
| 27 |
+
# Preprocess (Average)
|
| 28 |
+
unique_b = np.unique(np.round(bvals, -1))
|
| 29 |
+
avg_data_list = []
|
| 30 |
+
for b_val in unique_b:
|
| 31 |
+
idxs = np.where(np.abs(bvals - b_val) < 10)[0]
|
| 32 |
+
if len(idxs) > 0:
|
| 33 |
+
avg_data_list.append(np.mean(data[:, :, :, idxs], axis=3))
|
| 34 |
+
avg_data = np.stack(avg_data_list, axis=3)
|
| 35 |
+
|
| 36 |
+
# Pick a pixel in the brain
|
| 37 |
+
# Slice 9, x=50, y=50 (approx center)
|
| 38 |
+
slice_idx = 9
|
| 39 |
+
x, y = 48, 48
|
| 40 |
+
|
| 41 |
+
signal = avg_data[x, y, slice_idx, :]
|
| 42 |
+
print(f"Pixel ({x}, {y}, {slice_idx})")
|
| 43 |
+
print(f"b-values: {unique_b}")
|
| 44 |
+
print(f"Signal: {signal}")
|
| 45 |
+
|
| 46 |
+
# Run IVIM calculation
|
| 47 |
+
print("\n--- Running calculate_ivim_params (Quadratic) ---")
|
| 48 |
+
r2, D, f, D_star, K = calculate_ivim_params(unique_b, signal, model_type='quadratic', gof=0.9)
|
| 49 |
+
print(f"Result: D={D}, f={f}, D*={D_star}, K={K}, R2={r2}")
|
| 50 |
+
|
| 51 |
+
# Manual check of the fit
|
| 52 |
+
vec_b = unique_b
|
| 53 |
+
vec_S = signal / np.max(signal)
|
| 54 |
+
vec_S_log = np.log(vec_S + 1e-10)
|
| 55 |
+
|
| 56 |
+
# Fit high b
|
| 57 |
+
limit_dif = 180
|
| 58 |
+
idx = np.where(vec_b >= limit_dif)[0][0]
|
| 59 |
+
b_high = vec_b[idx:]
|
| 60 |
+
S_high = vec_S_log[idx:]
|
| 61 |
+
|
| 62 |
+
print(f"\nHigh-b data (b >= {limit_dif}):")
|
| 63 |
+
print(f"b: {b_high}")
|
| 64 |
+
print(f"log(S): {S_high}")
|
| 65 |
+
|
| 66 |
+
# Try fit
|
| 67 |
+
bounds = ([-np.inf, -np.inf, 0], [np.inf, 0, np.inf])
|
| 68 |
+
try:
|
| 69 |
+
popt, _ = curve_fit(func_qua, b_high, S_high, bounds=bounds)
|
| 70 |
+
print(f"Manual Curve Fit params: a0={popt[0]}, a1={popt[1]}, a2={popt[2]}")
|
| 71 |
+
print(f"Implied D = {-popt[1]}")
|
| 72 |
+
print(f"Implied K = {popt[2] / popt[1]**2 * 6 if popt[1]!=0 else 0}")
|
| 73 |
+
except Exception as e:
|
| 74 |
+
print(f"Fit failed: {e}")
|
| 75 |
+
|
| 76 |
+
if __name__ == "__main__":
|
| 77 |
+
main()
|
demo.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import matplotlib.pyplot as plt
|
| 3 |
+
import sys
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
# Add src to path
|
| 7 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
| 8 |
+
|
| 9 |
+
from ivim_model import calculate_ivim_params
|
| 10 |
+
|
| 11 |
+
def generate_synthetic_signal(b_values, D, f, D_star, K=0, noise_level=0.02):
|
| 12 |
+
"""
|
| 13 |
+
Generates synthetic IVIM-Kurtosis signal.
|
| 14 |
+
S(b) = S0 * [ f * exp(-b*D*) + (1-f) * exp(-b*D + 1/6 * b^2 * D^2 * K) ]
|
| 15 |
+
"""
|
| 16 |
+
S0 = 1000
|
| 17 |
+
b = np.array(b_values)
|
| 18 |
+
|
| 19 |
+
# Diffusion part with Kurtosis
|
| 20 |
+
# Note: Taylor expansion for Kurtosis is usually valid for b*D < 1 or similar.
|
| 21 |
+
# Standard DKI model: exp(-b*D + 1/6 * b^2 * D^2 * K)
|
| 22 |
+
|
| 23 |
+
term_diff = np.exp(-b * D + (1/6) * (b**2) * (D**2) * K)
|
| 24 |
+
term_perf = np.exp(-b * D_star)
|
| 25 |
+
|
| 26 |
+
signal = S0 * (f * term_perf + (1 - f) * term_diff)
|
| 27 |
+
|
| 28 |
+
# Add Rician noise (approximated as Gaussian for high SNR)
|
| 29 |
+
noise = np.random.normal(0, noise_level * S0, size=len(b))
|
| 30 |
+
signal_noisy = signal + noise
|
| 31 |
+
signal_noisy[signal_noisy < 0] = 0 # Magnitude signal
|
| 32 |
+
|
| 33 |
+
return signal_noisy
|
| 34 |
+
|
| 35 |
+
def main():
|
| 36 |
+
print("=== IVIM-DKI Estimation Demo ===")
|
| 37 |
+
|
| 38 |
+
# 1. Define Ground Truth Parameters
|
| 39 |
+
D_true = 0.001 # mm^2/s
|
| 40 |
+
f_true = 0.15 # fraction
|
| 41 |
+
D_star_true = 0.01 # mm^2/s
|
| 42 |
+
K_true = 0.8 # dimensionless
|
| 43 |
+
|
| 44 |
+
print(f"Ground Truth: D={D_true}, f={f_true}, D*={D_star_true}, K={K_true}")
|
| 45 |
+
|
| 46 |
+
# 2. Define b-values (typical clinical protocol)
|
| 47 |
+
b_values = [0, 10, 20, 30, 50, 80, 100, 200, 400, 800, 1000, 1500, 2000]
|
| 48 |
+
print(f"b-values: {b_values}")
|
| 49 |
+
|
| 50 |
+
# 3. Generate Synthetic Data
|
| 51 |
+
signal = generate_synthetic_signal(b_values, D_true, f_true, D_star_true, K_true)
|
| 52 |
+
|
| 53 |
+
# 4. Fit Model
|
| 54 |
+
print("\nFitting model...")
|
| 55 |
+
# We use gof=0.999 to force the segmented fit (since synthetic data is very clean)
|
| 56 |
+
# We use model_type='quadratic' to enable Kurtosis estimation
|
| 57 |
+
r2, D_est, f_est, D_star_est, K_est = calculate_ivim_params(b_values, signal, gof=0.999, model_type='quadratic')
|
| 58 |
+
|
| 59 |
+
# 5. Show Results
|
| 60 |
+
print("\n--- Results ---")
|
| 61 |
+
print(f"Estimated D: {D_est:.6f} (Error: {abs(D_est - D_true)/D_true*100:.2f}%)")
|
| 62 |
+
print(f"Estimated f: {f_est:.6f} (Error: {abs(f_est - f_true)/f_true*100:.2f}%)")
|
| 63 |
+
print(f"Estimated D*: {D_star_est:.6f} (Error: {abs(D_star_est - D_star_true)/D_star_true*100:.2f}%)")
|
| 64 |
+
print(f"Estimated K: {K_est:.6f} (Error: {abs(K_est - K_true)/K_true*100:.2f}%)")
|
| 65 |
+
print(f"Goodness of fit (R2): {r2:.4f}")
|
| 66 |
+
|
| 67 |
+
# 6. Plot
|
| 68 |
+
plt.figure(figsize=(10, 6))
|
| 69 |
+
plt.plot(b_values, signal, 'o', label='Noisy Data')
|
| 70 |
+
|
| 71 |
+
# Reconstruct fitted curve
|
| 72 |
+
# Note: The fitting function returns parameters, we need to reconstruct the curve to plot
|
| 73 |
+
# Ideally we would have a 'predict' function in ivim_model.py
|
| 74 |
+
# For now, we manually reconstruct using the same logic as generation
|
| 75 |
+
|
| 76 |
+
# Reconstruct using estimated parameters
|
| 77 |
+
# Note: The fitting logic uses segmented approach, so the reconstruction might be slightly different
|
| 78 |
+
# if we strictly follow the fitting steps, but for visualization, the full model is best.
|
| 79 |
+
|
| 80 |
+
b_smooth = np.linspace(0, max(b_values), 100)
|
| 81 |
+
term_diff_est = np.exp(-b_smooth * D_est + (1/6) * (b_smooth**2) * (D_est**2) * K_est)
|
| 82 |
+
term_perf_est = np.exp(-b_smooth * D_star_est)
|
| 83 |
+
S0_est = np.max(signal) # Approximation used in fitting
|
| 84 |
+
signal_est = S0_est * (f_est * term_perf_est + (1 - f_est) * term_diff_est)
|
| 85 |
+
|
| 86 |
+
plt.plot(b_smooth, signal_est, '-', label='Fitted Curve')
|
| 87 |
+
plt.xlabel('b-value (s/mm^2)')
|
| 88 |
+
plt.ylabel('Signal Intensity')
|
| 89 |
+
plt.title('IVIM-DKI Fit Demo')
|
| 90 |
+
plt.legend()
|
| 91 |
+
plt.grid(True)
|
| 92 |
+
|
| 93 |
+
# Save plot
|
| 94 |
+
output_plot = 'demo_fit.png'
|
| 95 |
+
plt.savefig(output_plot)
|
| 96 |
+
print(f"\nPlot saved to {output_plot}")
|
| 97 |
+
|
| 98 |
+
if __name__ == "__main__":
|
| 99 |
+
main()
|
demo_cfin.py
ADDED
|
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import nibabel as nib
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
|
| 8 |
+
# Add src to path
|
| 9 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
| 10 |
+
|
| 11 |
+
from ivim_model import calculate_ivim_params, process_slice_parallel
|
| 12 |
+
from utils import filtro_2D
|
| 13 |
+
|
| 14 |
+
def load_cfin_data(data_dir):
|
| 15 |
+
"""
|
| 16 |
+
Loads CFIN dataset files.
|
| 17 |
+
"""
|
| 18 |
+
# Find files
|
| 19 |
+
nii_file = None
|
| 20 |
+
bval_file = None
|
| 21 |
+
|
| 22 |
+
for f in os.listdir(data_dir):
|
| 23 |
+
if f.endswith('.nii') and 'DTI' in f:
|
| 24 |
+
nii_file = os.path.join(data_dir, f)
|
| 25 |
+
elif f.endswith('.bval'):
|
| 26 |
+
bval_file = os.path.join(data_dir, f)
|
| 27 |
+
|
| 28 |
+
if not nii_file or not bval_file:
|
| 29 |
+
raise FileNotFoundError("Could not find required .nii or .bval files in data directory.")
|
| 30 |
+
|
| 31 |
+
print(f"Loading NIfTI: {os.path.basename(nii_file)}")
|
| 32 |
+
img = nib.load(nii_file)
|
| 33 |
+
data = img.get_fdata()
|
| 34 |
+
|
| 35 |
+
print(f"Loading b-values: {os.path.basename(bval_file)}")
|
| 36 |
+
bvals = np.loadtxt(bval_file)
|
| 37 |
+
|
| 38 |
+
return data, bvals
|
| 39 |
+
|
| 40 |
+
def preprocess_dwi_data(data, bvals):
|
| 41 |
+
"""
|
| 42 |
+
Averages signals for unique b-values (Trace-weighted image).
|
| 43 |
+
This is necessary for DTI/DKI datasets with multiple directions per shell.
|
| 44 |
+
"""
|
| 45 |
+
unique_b, inverse_indices = np.unique(np.round(bvals, -1), return_inverse=True) # Round to nearest 10 to group
|
| 46 |
+
|
| 47 |
+
# Sort unique b-values
|
| 48 |
+
sorted_indices = np.argsort(unique_b)
|
| 49 |
+
unique_b = unique_b[sorted_indices]
|
| 50 |
+
|
| 51 |
+
# Initialize averaged data
|
| 52 |
+
rows, cols, slices, _ = data.shape
|
| 53 |
+
avg_data = np.zeros((rows, cols, slices, len(unique_b)))
|
| 54 |
+
|
| 55 |
+
print(f"Averaging {len(bvals)} volumes into {len(unique_b)} unique b-values: {unique_b}")
|
| 56 |
+
|
| 57 |
+
for i, b_idx in enumerate(sorted_indices):
|
| 58 |
+
# Find all original indices corresponding to this unique b-value
|
| 59 |
+
# Note: inverse_indices maps original -> unique.
|
| 60 |
+
# We need to find where inverse_indices == b_idx (before sorting)
|
| 61 |
+
# Actually, let's do it simpler.
|
| 62 |
+
pass
|
| 63 |
+
|
| 64 |
+
# Re-implementing loop for clarity
|
| 65 |
+
avg_data_list = []
|
| 66 |
+
for b_val in unique_b:
|
| 67 |
+
# Find indices in original bvals that are close to this b_val
|
| 68 |
+
# Using a tolerance of 10 s/mm2
|
| 69 |
+
idxs = np.where(np.abs(bvals - b_val) < 10)[0]
|
| 70 |
+
|
| 71 |
+
if len(idxs) > 0:
|
| 72 |
+
# Average across the 4th dimension (volumes)
|
| 73 |
+
vol_avg = np.mean(data[:, :, :, idxs], axis=3)
|
| 74 |
+
avg_data_list.append(vol_avg)
|
| 75 |
+
else:
|
| 76 |
+
print(f"Warning: No volumes found for b={b_val}")
|
| 77 |
+
|
| 78 |
+
# Stack along the 4th dimension
|
| 79 |
+
avg_data = np.stack(avg_data_list, axis=3)
|
| 80 |
+
|
| 81 |
+
return avg_data, unique_b
|
| 82 |
+
|
| 83 |
+
def main():
|
| 84 |
+
print("=== CFIN Dataset IVIM Demo ===")
|
| 85 |
+
|
| 86 |
+
data_dir = os.path.join(os.path.dirname(__file__), 'data', 'CFIN')
|
| 87 |
+
|
| 88 |
+
try:
|
| 89 |
+
data, bvals = load_cfin_data(data_dir)
|
| 90 |
+
except FileNotFoundError as e:
|
| 91 |
+
print(f"Error: {e}")
|
| 92 |
+
print("Please run 'python download_data.py' first.")
|
| 93 |
+
return
|
| 94 |
+
|
| 95 |
+
print(f"Original Data shape: {data.shape}")
|
| 96 |
+
print(f"Original Number of b-values: {len(bvals)}")
|
| 97 |
+
|
| 98 |
+
# Preprocess: Average shells
|
| 99 |
+
data, bvals = preprocess_dwi_data(data, bvals)
|
| 100 |
+
print(f"Processed Data shape: {data.shape}")
|
| 101 |
+
|
| 102 |
+
# Select a middle slice for demonstration
|
| 103 |
+
slice_idx = data.shape[2] // 2
|
| 104 |
+
print(f"Processing slice {slice_idx}...")
|
| 105 |
+
|
| 106 |
+
slice_data = data[:, :, slice_idx, :]
|
| 107 |
+
rows, cols, _ = slice_data.shape
|
| 108 |
+
|
| 109 |
+
# Initialize parameter maps
|
| 110 |
+
map_D = np.zeros((rows, cols))
|
| 111 |
+
map_f = np.zeros((rows, cols))
|
| 112 |
+
map_D_star = np.zeros((rows, cols))
|
| 113 |
+
map_K = np.zeros((rows, cols))
|
| 114 |
+
map_R2 = np.zeros((rows, cols))
|
| 115 |
+
|
| 116 |
+
# Simple mask to avoid background (threshold on b0)
|
| 117 |
+
b0_idx = np.argmin(bvals)
|
| 118 |
+
b0_img = slice_data[:, :, b0_idx]
|
| 119 |
+
mask = b0_img > np.mean(b0_img) * 0.2 # Simple threshold
|
| 120 |
+
|
| 121 |
+
# DEMO OPTIMIZATION:
|
| 122 |
+
# We now use parallel processing and optimized outlier search.
|
| 123 |
+
# We can process the full slice much faster.
|
| 124 |
+
|
| 125 |
+
print(f"\nProcessing full slice with parallel execution.")
|
| 126 |
+
|
| 127 |
+
# Use parallel processing on the full mask
|
| 128 |
+
map_R2, map_D, map_f, map_D_star, map_K = process_slice_parallel(
|
| 129 |
+
bvals, slice_data, mask=mask, gof=0.90, model_type='quadratic', n_jobs=-1
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
# Apply filter to smooth results (optional, as in the paper)
|
| 133 |
+
# Note: The provided filtro_2D in utils.py appears to be a binary mask filter,
|
| 134 |
+
# not a smoothing filter for continuous values. We skip it for parameter maps.
|
| 135 |
+
# print("Applying spatial filter...")
|
| 136 |
+
# map_D = filtro_2D(1, cols, rows, map_D)
|
| 137 |
+
|
| 138 |
+
# Print statistics
|
| 139 |
+
print("\n--- Parameter Statistics ---")
|
| 140 |
+
print(f"D : Mean={np.mean(map_D):.6f}, Max={np.max(map_D):.6f}, Min={np.min(map_D):.6f}")
|
| 141 |
+
print(f"f : Mean={np.mean(map_f):.6f}, Max={np.max(map_f):.6f}, Min={np.min(map_f):.6f}")
|
| 142 |
+
print(f"D* : Mean={np.mean(map_D_star):.6f}, Max={np.max(map_D_star):.6f}, Min={np.min(map_D_star):.6f}")
|
| 143 |
+
print(f"K : Mean={np.mean(map_K):.6f}, Max={np.max(map_K):.6f}, Min={np.min(map_K):.6f}")
|
| 144 |
+
|
| 145 |
+
# Plot results
|
| 146 |
+
print("Plotting results...")
|
| 147 |
+
fig, axes = plt.subplots(1, 5, figsize=(20, 4))
|
| 148 |
+
|
| 149 |
+
# Helper to rotate for better visualization (anatomical orientation)
|
| 150 |
+
def show_map(ax, data, title, vmin, vmax, cmap):
|
| 151 |
+
# Rotate 90 degrees to match anatomical view usually
|
| 152 |
+
im = ax.imshow(np.rot90(data), cmap=cmap, vmin=vmin, vmax=vmax)
|
| 153 |
+
ax.set_title(title)
|
| 154 |
+
plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 155 |
+
ax.axis('off')
|
| 156 |
+
|
| 157 |
+
# Adjust vmin/vmax based on typical physiological values
|
| 158 |
+
show_map(axes[0], map_R2, "Goodness of Fit (R2)", 0.5, 1.0, 'gray')
|
| 159 |
+
show_map(axes[1], map_D, "Diffusion (D)", 0, 0.003, 'viridis')
|
| 160 |
+
show_map(axes[2], map_f, "Perfusion Fraction (f)", 0, 0.3, 'viridis') # f is usually 0-0.3
|
| 161 |
+
show_map(axes[3], map_D_star, "Pseudo-Diffusion (D*)", 0, 0.1, 'viridis') # D* is usually high
|
| 162 |
+
show_map(axes[4], map_K, "Kurtosis (K)", 0, 2.0, 'viridis')
|
| 163 |
+
|
| 164 |
+
plt.tight_layout()
|
| 165 |
+
plt.savefig('cfin_demo_results.png')
|
| 166 |
+
print("Results saved to cfin_demo_results.png")
|
| 167 |
+
|
| 168 |
+
if __name__ == "__main__":
|
| 169 |
+
main()
|
demo_colab.ipynb
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"<a href=\"https://colab.research.google.com/github/lsprietog/public_release/blob/main/demo_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"# \ud83c\udf93 IVIM-DKI Machine Learning Demo\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook demonstrates how to train and evaluate Machine Learning models for Diffusion MRI analysis, as described in the paper **\"Exploring the Potential of Machine Learning Algorithms to Improve Diffusion Nuclear Magnetic Resonance Imaging Models Analysis\"**.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"**What you will learn:**\n",
|
| 14 |
+
"1. How to generate synthetic IVIM-DKI signal data.\n",
|
| 15 |
+
"2. How to train Random Forest and Extra Trees models.\n",
|
| 16 |
+
"3. How to compare ML performance against standard Least Squares fitting.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"---"
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": null,
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"outputs": [],
|
| 26 |
+
"source": [
|
| 27 |
+
"# @title 1. Setup & Install\n",
|
| 28 |
+
"!git clone https://github.com/lsprietog/public_release.git\n",
|
| 29 |
+
"%cd public_release\n",
|
| 30 |
+
"!pip install -r requirements.txt\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"import numpy as np\n",
|
| 33 |
+
"import matplotlib.pyplot as plt\n",
|
| 34 |
+
"from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor\n",
|
| 35 |
+
"from sklearn.metrics import mean_squared_error, r2_score\n",
|
| 36 |
+
"from scipy.optimize import curve_fit\n",
|
| 37 |
+
"import joblib\n",
|
| 38 |
+
"\n",
|
| 39 |
+
"# Define the IVIM-DKI model function\n",
|
| 40 |
+
"def ivim_dki_model(b, D, f, Dstar, K):\n",
|
| 41 |
+
" return f * np.exp(-b * Dstar) + (1 - f) * np.exp(-b * D + (1/6) * (b**2) * (D**2) * K)"
|
| 42 |
+
]
|
| 43 |
+
},
|
| 44 |
+
{
|
| 45 |
+
"cell_type": "code",
|
| 46 |
+
"execution_count": null,
|
| 47 |
+
"metadata": {},
|
| 48 |
+
"outputs": [],
|
| 49 |
+
"source": [
|
| 50 |
+
"# @title 2. Generate Synthetic Data\n",
|
| 51 |
+
"# Simulation parameters\n",
|
| 52 |
+
"n_samples = 5000\n",
|
| 53 |
+
"b_values = np.array([0, 10, 20, 30, 50, 80, 100, 150, 200, 400, 600, 800, 1000, 1500, 2000])\n",
|
| 54 |
+
"snr = 30\n",
|
| 55 |
+
"\n",
|
| 56 |
+
"# Random ground truth parameters\n",
|
| 57 |
+
"np.random.seed(42)\n",
|
| 58 |
+
"D_true = np.random.uniform(0.0005, 0.003, n_samples)\n",
|
| 59 |
+
"f_true = np.random.uniform(0.1, 0.4, n_samples)\n",
|
| 60 |
+
"Dstar_true = np.random.uniform(0.005, 0.1, n_samples)\n",
|
| 61 |
+
"K_true = np.random.uniform(0, 2.0, n_samples)\n",
|
| 62 |
+
"\n",
|
| 63 |
+
"Y_true = np.stack([D_true, f_true, Dstar_true, K_true], axis=1)\n",
|
| 64 |
+
"\n",
|
| 65 |
+
"# Generate signals\n",
|
| 66 |
+
"signals = []\n",
|
| 67 |
+
"for i in range(n_samples):\n",
|
| 68 |
+
" sig = ivim_dki_model(b_values, D_true[i], f_true[i], Dstar_true[i], K_true[i])\n",
|
| 69 |
+
" # Add Rician noise\n",
|
| 70 |
+
" noise_r = np.random.normal(0, 1/snr, len(b_values))\n",
|
| 71 |
+
" noise_i = np.random.normal(0, 1/snr, len(b_values))\n",
|
| 72 |
+
" noisy_sig = np.sqrt((sig + noise_r)**2 + noise_i**2)\n",
|
| 73 |
+
" signals.append(noisy_sig)\n",
|
| 74 |
+
"\n",
|
| 75 |
+
"X = np.array(signals)\n",
|
| 76 |
+
"print(f\"Generated {n_samples} synthetic signals with SNR={snr}\")"
|
| 77 |
+
]
|
| 78 |
+
},
|
| 79 |
+
{
|
| 80 |
+
"cell_type": "code",
|
| 81 |
+
"execution_count": null,
|
| 82 |
+
"metadata": {},
|
| 83 |
+
"outputs": [],
|
| 84 |
+
"source": [
|
| 85 |
+
"# @title 3. Train Machine Learning Model\n",
|
| 86 |
+
"print(\"Training Extra Trees Regressor...\")\n",
|
| 87 |
+
"# Using optimized hyperparameters for speed/size balance\n",
|
| 88 |
+
"model = ExtraTreesRegressor(\n",
|
| 89 |
+
" n_estimators=50,\n",
|
| 90 |
+
" max_depth=15,\n",
|
| 91 |
+
" min_samples_split=5,\n",
|
| 92 |
+
" n_jobs=-1,\n",
|
| 93 |
+
" random_state=42\n",
|
| 94 |
+
")\n",
|
| 95 |
+
"\n",
|
| 96 |
+
"model.fit(X, Y_true)\n",
|
| 97 |
+
"print(\"Training complete!\")\n",
|
| 98 |
+
"\n",
|
| 99 |
+
"# Evaluate\n",
|
| 100 |
+
"Y_pred = model.predict(X)\n",
|
| 101 |
+
"r2 = r2_score(Y_true, Y_pred)\n",
|
| 102 |
+
"print(f\"Model R2 Score: {r2:.4f}\")"
|
| 103 |
+
]
|
| 104 |
+
},
|
| 105 |
+
{
|
| 106 |
+
"cell_type": "code",
|
| 107 |
+
"execution_count": null,
|
| 108 |
+
"metadata": {},
|
| 109 |
+
"outputs": [],
|
| 110 |
+
"source": [
|
| 111 |
+
"# @title 4. Visualize Results\n",
|
| 112 |
+
"param_names = ['D', 'f', 'D*', 'K']\n",
|
| 113 |
+
"fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n",
|
| 114 |
+
"\n",
|
| 115 |
+
"for i in range(4):\n",
|
| 116 |
+
" axes[i].scatter(Y_true[:, i], Y_pred[:, i], alpha=0.1, s=5)\n",
|
| 117 |
+
" axes[i].plot([Y_true[:, i].min(), Y_true[:, i].max()], \n",
|
| 118 |
+
" [Y_true[:, i].min(), Y_true[:, i].max()], 'r--')\n",
|
| 119 |
+
" axes[i].set_xlabel('Ground Truth')\n",
|
| 120 |
+
" axes[i].set_ylabel('Prediction')\n",
|
| 121 |
+
" axes[i].set_title(f'{param_names[i]} (R2={r2_score(Y_true[:, i], Y_pred[:, i]):.2f})')\n",
|
| 122 |
+
"\n",
|
| 123 |
+
"plt.show()"
|
| 124 |
+
]
|
| 125 |
+
},
|
| 126 |
+
{
|
| 127 |
+
"cell_type": "code",
|
| 128 |
+
"execution_count": null,
|
| 129 |
+
"metadata": {},
|
| 130 |
+
"outputs": [],
|
| 131 |
+
"source": [
|
| 132 |
+
"# @title 5. Save Model\n",
|
| 133 |
+
"joblib.dump(model, 'my_trained_model.joblib')\n",
|
| 134 |
+
"print(\"Model saved as 'my_trained_model.joblib'\")\n",
|
| 135 |
+
"files.download('my_trained_model.joblib')"
|
| 136 |
+
]
|
| 137 |
+
}
|
| 138 |
+
],
|
| 139 |
+
"metadata": {
|
| 140 |
+
"kernelspec": {
|
| 141 |
+
"display_name": "Python 3",
|
| 142 |
+
"language": "python",
|
| 143 |
+
"name": "python3"
|
| 144 |
+
},
|
| 145 |
+
"language_info": {
|
| 146 |
+
"codemirror_mode": {
|
| 147 |
+
"name": "ipython",
|
| 148 |
+
"version": 3
|
| 149 |
+
},
|
| 150 |
+
"file_extension": ".py",
|
| 151 |
+
"mimetype": "text/x-python",
|
| 152 |
+
"name": "python",
|
| 153 |
+
"nbconvert_exporter": "python",
|
| 154 |
+
"pygments_lexer": "ipython3",
|
| 155 |
+
"version": "3.8.5"
|
| 156 |
+
}
|
| 157 |
+
},
|
| 158 |
+
"nbformat": 4,
|
| 159 |
+
"nbformat_minor": 4
|
| 160 |
+
}
|
demo_dipy.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import numpy as np
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
from dipy.data import fetch_ivim
|
| 6 |
+
from dipy.core.gradients import gradient_table
|
| 7 |
+
import nibabel as nib
|
| 8 |
+
from tempfile import TemporaryDirectory
|
| 9 |
+
|
| 10 |
+
# Add src to path
|
| 11 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
| 12 |
+
|
| 13 |
+
from ivim_model import process_slice_parallel
|
| 14 |
+
from utils import filtro_2D
|
| 15 |
+
|
| 16 |
+
def main():
|
| 17 |
+
print("=== DIPY IVIM Dataset Demo ===")
|
| 18 |
+
|
| 19 |
+
# Fetch IVIM data
|
| 20 |
+
print("Fetching IVIM data from DIPY...")
|
| 21 |
+
# This downloads the data to the user's home directory by default (.dipy/ivim)
|
| 22 |
+
files, folder = fetch_ivim()
|
| 23 |
+
|
| 24 |
+
# The fetch_ivim returns a dictionary or list of files depending on version,
|
| 25 |
+
# but usually it downloads 'ivim_data.nii.gz' and 'ivim_bvals.txt'
|
| 26 |
+
# Let's find them in the folder
|
| 27 |
+
print(f"Data downloaded to: {folder}")
|
| 28 |
+
|
| 29 |
+
nii_path = os.path.join(folder, 'ivim.nii.gz')
|
| 30 |
+
bval_path = os.path.join(folder, 'ivim.bval')
|
| 31 |
+
|
| 32 |
+
if not os.path.exists(nii_path):
|
| 33 |
+
# Fallback for different dipy versions
|
| 34 |
+
nii_path = os.path.join(folder, 'ivim_data.nii.gz')
|
| 35 |
+
|
| 36 |
+
print(f"Loading NIfTI: {nii_path}")
|
| 37 |
+
img = nib.load(nii_path)
|
| 38 |
+
data = img.get_fdata()
|
| 39 |
+
|
| 40 |
+
print(f"Loading b-values: {bval_path}")
|
| 41 |
+
bvals = np.loadtxt(bval_path)
|
| 42 |
+
|
| 43 |
+
print(f"Data shape: {data.shape}")
|
| 44 |
+
print(f"b-values: {bvals}")
|
| 45 |
+
|
| 46 |
+
# Select slice 15 as requested
|
| 47 |
+
slice_idx = 15
|
| 48 |
+
if slice_idx >= data.shape[2]:
|
| 49 |
+
slice_idx = data.shape[2] // 2
|
| 50 |
+
|
| 51 |
+
print(f"Processing slice {slice_idx}...")
|
| 52 |
+
|
| 53 |
+
slice_data = data[:, :, slice_idx, :]
|
| 54 |
+
rows, cols, _ = slice_data.shape
|
| 55 |
+
|
| 56 |
+
# Create a mask (simple thresholding on b0)
|
| 57 |
+
b0_idx = np.argmin(bvals)
|
| 58 |
+
b0_img = slice_data[:, :, b0_idx]
|
| 59 |
+
mask = b0_img > np.mean(b0_img) * 0.2
|
| 60 |
+
|
| 61 |
+
print(f"Processing full slice with parallel execution.")
|
| 62 |
+
|
| 63 |
+
# Use parallel processing
|
| 64 |
+
# Note: The original paper uses 'quadratic' (Kurtosis) model often, but for standard IVIM data 'linear' might be safer?
|
| 65 |
+
# Let's stick to 'quadratic' as it's the paper's main contribution.
|
| 66 |
+
map_R2, map_D, map_f, map_D_star, map_K = process_slice_parallel(
|
| 67 |
+
bvals, slice_data, mask=mask, gof=0.90, model_type='quadratic', n_jobs=-1
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
# Apply filter
|
| 71 |
+
# Note: filtro_2D is binary, skipping for parameter maps
|
| 72 |
+
# print("Applying spatial filter...")
|
| 73 |
+
# map_D = filtro_2D(1, cols, rows, map_D)
|
| 74 |
+
# map_f = filtro_2D(1, cols, rows, map_f)
|
| 75 |
+
# map_D_star = filtro_2D(1, cols, rows, map_D_star)
|
| 76 |
+
|
| 77 |
+
# Plot results matching the style of Principal_IVIM.ipynb
|
| 78 |
+
print("Plotting results...")
|
| 79 |
+
fig, axes = plt.subplots(1, 5, figsize=(20, 4))
|
| 80 |
+
|
| 81 |
+
def show_map(ax, data, title, vmin=None, vmax=None, cmap='gray'):
|
| 82 |
+
# Rotate to match typical orientation if needed
|
| 83 |
+
im = ax.imshow(np.rot90(data), cmap=cmap, vmin=vmin, vmax=vmax, interpolation='none')
|
| 84 |
+
ax.set_title(title)
|
| 85 |
+
plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
|
| 86 |
+
ax.axis('off')
|
| 87 |
+
|
| 88 |
+
# Ranges based on typical values or auto-scaled
|
| 89 |
+
show_map(axes[0], map_R2, "R2", 0.5, 1.0, 'gray')
|
| 90 |
+
show_map(axes[1], map_D, "D (Diffusion)", 0, 0.002, 'viridis') # Typical D in brain ~0.0007 - 0.001 mm2/s
|
| 91 |
+
show_map(axes[2], map_f, "f (Perfusion)", 0, 0.3, 'viridis')
|
| 92 |
+
show_map(axes[3], map_D_star, "D* (Pseudo-Diff)", 0, 0.05, 'viridis')
|
| 93 |
+
show_map(axes[4], map_K, "K (Kurtosis)", 0, 1.5, 'viridis')
|
| 94 |
+
|
| 95 |
+
plt.tight_layout()
|
| 96 |
+
plt.savefig('dipy_ivim_results.png')
|
| 97 |
+
print("Results saved to dipy_ivim_results.png")
|
| 98 |
+
|
| 99 |
+
if __name__ == "__main__":
|
| 100 |
+
main()
|
demo_fit.png
ADDED
|
inference_colab.ipynb
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"cell_type": "markdown",
|
| 5 |
+
"metadata": {},
|
| 6 |
+
"source": [
|
| 7 |
+
"<a href=\"https://colab.research.google.com/github/lsprietog/public_release/blob/main/inference_colab.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>\n",
|
| 8 |
+
"\n",
|
| 9 |
+
"# \ud83e\udde0 Apply IVIM-DKI Machine Learning to Your Data\n",
|
| 10 |
+
"\n",
|
| 11 |
+
"This notebook allows you to use the pre-trained Machine Learning models from the paper **\"Exploring the Potential of Machine Learning Algorithms to Improve Diffusion Nuclear Magnetic Resonance Imaging Models Analysis\"** on your own NIfTI data.\n",
|
| 12 |
+
"\n",
|
| 13 |
+
"**What this tool does:**\n",
|
| 14 |
+
"1. Loads your Diffusion MRI data (NIfTI format).\n",
|
| 15 |
+
"2. Applies a pre-trained Extra Trees Regressor to estimate $D$, $f$, $D^*$, and $K$.\n",
|
| 16 |
+
"3. Generates and saves the parameter maps.\n",
|
| 17 |
+
"\n",
|
| 18 |
+
"**Note:** The pre-trained model included here was trained on a standard set of b-values: `[0, 10, 20, 30, 50, 80, 100, 150, 200, 400, 600, 800, 1000, 1500, 2000]`. For best results on your specific protocol, we recommend retraining the model using the `demo_colab.ipynb` notebook."
|
| 19 |
+
]
|
| 20 |
+
},
|
| 21 |
+
{
|
| 22 |
+
"cell_type": "code",
|
| 23 |
+
"execution_count": null,
|
| 24 |
+
"metadata": {},
|
| 25 |
+
"outputs": [],
|
| 26 |
+
"source": [
|
| 27 |
+
"# @title 1. Setup & Install\n",
|
| 28 |
+
"!git clone https://github.com/lsprietog/public_release.git\n",
|
| 29 |
+
"%cd public_release\n",
|
| 30 |
+
"!pip install -r requirements.txt\n",
|
| 31 |
+
"\n",
|
| 32 |
+
"import numpy as np\n",
|
| 33 |
+
"import nibabel as nib\n",
|
| 34 |
+
"import joblib\n",
|
| 35 |
+
"import matplotlib.pyplot as plt\n",
|
| 36 |
+
"import os\n",
|
| 37 |
+
"from google.colab import files"
|
| 38 |
+
]
|
| 39 |
+
},
|
| 40 |
+
{
|
| 41 |
+
"cell_type": "code",
|
| 42 |
+
"execution_count": null,
|
| 43 |
+
"metadata": {},
|
| 44 |
+
"outputs": [],
|
| 45 |
+
"source": [
|
| 46 |
+
"# @title 2. Upload Your Data\n",
|
| 47 |
+
"print(\"Please upload your 4D DWI NIfTI file (.nii or .nii.gz)...\")\n",
|
| 48 |
+
"uploaded = files.upload()\n",
|
| 49 |
+
"if uploaded:\n",
|
| 50 |
+
" dwi_filename = list(uploaded.keys())[0]\n",
|
| 51 |
+
" print(f\"Uploaded: {dwi_filename}\")\n",
|
| 52 |
+
"else:\n",
|
| 53 |
+
" print(\"No file uploaded.\")\n",
|
| 54 |
+
"\n",
|
| 55 |
+
"print(\"\\n(Optional) Upload a Mask file (binary mask). If none, press Cancel or skip.\")\n",
|
| 56 |
+
"try:\n",
|
| 57 |
+
" uploaded_mask = files.upload()\n",
|
| 58 |
+
" if uploaded_mask:\n",
|
| 59 |
+
" mask_filename = list(uploaded_mask.keys())[0]\n",
|
| 60 |
+
" else:\n",
|
| 61 |
+
" mask_filename = None\n",
|
| 62 |
+
"except:\n",
|
| 63 |
+
" mask_filename = None\n",
|
| 64 |
+
" print(\"No mask provided. Processing entire volume (this might take longer).\")"
|
| 65 |
+
]
|
| 66 |
+
},
|
| 67 |
+
{
|
| 68 |
+
"cell_type": "code",
|
| 69 |
+
"execution_count": null,
|
| 70 |
+
"metadata": {},
|
| 71 |
+
"outputs": [],
|
| 72 |
+
"source": [
|
| 73 |
+
"# @title 3. Load Model & Data\n",
|
| 74 |
+
"# Load the pre-trained model\n",
|
| 75 |
+
"model_path = 'models/ivim_dki_extratrees.joblib'\n",
|
| 76 |
+
"if not os.path.exists(model_path):\n",
|
| 77 |
+
" print(\"Downloading pre-trained model...\")\n",
|
| 78 |
+
" # In a real scenario, you might fetch this from a release URL if it's not in the repo\n",
|
| 79 |
+
" # But here we assume it's in the repo we cloned\n",
|
| 80 |
+
" pass\n",
|
| 81 |
+
"\n",
|
| 82 |
+
"print(f\"Loading model from {model_path}...\")\n",
|
| 83 |
+
"model = joblib.load(model_path)\n",
|
| 84 |
+
"\n",
|
| 85 |
+
"# Load NIfTI\n",
|
| 86 |
+
"img = nib.load(dwi_filename)\n",
|
| 87 |
+
"data = img.get_fdata()\n",
|
| 88 |
+
"affine = img.affine\n",
|
| 89 |
+
"\n",
|
| 90 |
+
"if mask_filename:\n",
|
| 91 |
+
" mask = nib.load(mask_filename).get_fdata() > 0\n",
|
| 92 |
+
"else:\n",
|
| 93 |
+
" # Create a simple threshold mask to avoid background noise\n",
|
| 94 |
+
" # Assuming b0 is the first volume\n",
|
| 95 |
+
" mask = data[..., 0] > (np.mean(data[..., 0]) * 0.1)\n",
|
| 96 |
+
"\n",
|
| 97 |
+
"print(f\"Data shape: {data.shape}\")\n",
|
| 98 |
+
"print(f\"Mask voxels: {np.sum(mask)}\")"
|
| 99 |
+
]
|
| 100 |
+
},
|
| 101 |
+
{
|
| 102 |
+
"cell_type": "code",
|
| 103 |
+
"execution_count": null,
|
| 104 |
+
"metadata": {},
|
| 105 |
+
"outputs": [],
|
| 106 |
+
"source": [
|
| 107 |
+
"# @title 4. Run Prediction\n",
|
| 108 |
+
"print(\"Preprocessing data...\")\n",
|
| 109 |
+
"\n",
|
| 110 |
+
"# Reshape for prediction\n",
|
| 111 |
+
"n_x, n_y, n_z, n_b = data.shape\n",
|
| 112 |
+
"flat_data = data[mask]\n",
|
| 113 |
+
"\n",
|
| 114 |
+
"# Normalize signal (S/S0)\n",
|
| 115 |
+
"# Assuming the first volume is b=0\n",
|
| 116 |
+
"S0 = flat_data[:, 0][:, np.newaxis]\n",
|
| 117 |
+
"S0[S0 == 0] = 1 # Avoid division by zero\n",
|
| 118 |
+
"X_input = flat_data / S0\n",
|
| 119 |
+
"\n",
|
| 120 |
+
"# Check if b-values match (Basic check)\n",
|
| 121 |
+
"expected_b_len = 15 # Based on our training script\n",
|
| 122 |
+
"if n_b != expected_b_len:\n",
|
| 123 |
+
" print(f\"WARNING: Your data has {n_b} volumes, but the model expects {expected_b_len}.\")\n",
|
| 124 |
+
" print(\"We will try to interpolate or truncate, but results may be inaccurate.\")\n",
|
| 125 |
+
" # Simple truncation or padding for demo purposes\n",
|
| 126 |
+
" if n_b > expected_b_len:\n",
|
| 127 |
+
" X_input = X_input[:, :expected_b_len]\n",
|
| 128 |
+
" else:\n",
|
| 129 |
+
" # Pad with zeros (Not ideal, but prevents crash)\n",
|
| 130 |
+
" padding = np.zeros((X_input.shape[0], expected_b_len - n_b))\n",
|
| 131 |
+
" X_input = np.hstack([X_input, padding])\n",
|
| 132 |
+
"\n",
|
| 133 |
+
"print(\"Predicting parameters (this is the fast part!)...\")\n",
|
| 134 |
+
"predictions = model.predict(X_input)\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"# Reconstruct 3D maps\n",
|
| 137 |
+
"param_maps = np.zeros((n_x, n_y, n_z, 4)) # D, f, D*, K\n",
|
| 138 |
+
"param_maps[mask] = predictions\n",
|
| 139 |
+
"\n",
|
| 140 |
+
"print(\"Prediction complete!\")"
|
| 141 |
+
]
|
| 142 |
+
},
|
| 143 |
+
{
|
| 144 |
+
"cell_type": "code",
|
| 145 |
+
"execution_count": null,
|
| 146 |
+
"metadata": {},
|
| 147 |
+
"outputs": [],
|
| 148 |
+
"source": [
|
| 149 |
+
"# @title 5. Visualize & Download\n",
|
| 150 |
+
"# Extract maps\n",
|
| 151 |
+
"D_map = param_maps[..., 0]\n",
|
| 152 |
+
"f_map = param_maps[..., 1]\n",
|
| 153 |
+
"Dstar_map = param_maps[..., 2]\n",
|
| 154 |
+
"K_map = param_maps[..., 3]\n",
|
| 155 |
+
"\n",
|
| 156 |
+
"# Plot middle slice\n",
|
| 157 |
+
"z_slice = n_z // 2\n",
|
| 158 |
+
"\n",
|
| 159 |
+
"fig, axes = plt.subplots(1, 4, figsize=(20, 5))\n",
|
| 160 |
+
"axes[0].imshow(np.rot90(D_map[..., z_slice]), cmap='gray', vmin=0, vmax=0.003)\n",
|
| 161 |
+
"axes[0].set_title('Diffusion (D)')\n",
|
| 162 |
+
"axes[1].imshow(np.rot90(f_map[..., z_slice]), cmap='jet', vmin=0, vmax=0.5)\n",
|
| 163 |
+
"axes[1].set_title('Perfusion Fraction (f)')\n",
|
| 164 |
+
"axes[2].imshow(np.rot90(Dstar_map[..., z_slice]), cmap='hot', vmin=0, vmax=0.1)\n",
|
| 165 |
+
"axes[2].set_title('Pseudo-Diffusion (D*)')\n",
|
| 166 |
+
"axes[3].imshow(np.rot90(K_map[..., z_slice]), cmap='magma', vmin=0, vmax=2.0)\n",
|
| 167 |
+
"axes[3].set_title('Kurtosis (K)')\n",
|
| 168 |
+
"plt.show()\n",
|
| 169 |
+
"\n",
|
| 170 |
+
"# Save NIfTI files\n",
|
| 171 |
+
"print(\"Saving NIfTI files...\")\n",
|
| 172 |
+
"nib.save(nib.Nifti1Image(D_map, affine), 'map_D.nii.gz')\n",
|
| 173 |
+
"nib.save(nib.Nifti1Image(f_map, affine), 'map_f.nii.gz')\n",
|
| 174 |
+
"nib.save(nib.Nifti1Image(Dstar_map, affine), 'map_Dstar.nii.gz')\n",
|
| 175 |
+
"nib.save(nib.Nifti1Image(K_map, affine), 'map_K.nii.gz')\n",
|
| 176 |
+
"\n",
|
| 177 |
+
"print(\"Downloading maps...\")\n",
|
| 178 |
+
"files.download('map_D.nii.gz')\n",
|
| 179 |
+
"files.download('map_f.nii.gz')\n",
|
| 180 |
+
"files.download('map_Dstar.nii.gz')\n",
|
| 181 |
+
"files.download('map_K.nii.gz')"
|
| 182 |
+
]
|
| 183 |
+
}
|
| 184 |
+
],
|
| 185 |
+
"metadata": {
|
| 186 |
+
"kernelspec": {
|
| 187 |
+
"display_name": "Python 3",
|
| 188 |
+
"language": "python",
|
| 189 |
+
"name": "python3"
|
| 190 |
+
},
|
| 191 |
+
"language_info": {
|
| 192 |
+
"codemirror_mode": {
|
| 193 |
+
"name": "ipython",
|
| 194 |
+
"version": 3
|
| 195 |
+
},
|
| 196 |
+
"file_extension": ".py",
|
| 197 |
+
"mimetype": "text/x-python",
|
| 198 |
+
"name": "python",
|
| 199 |
+
"nbconvert_exporter": "python",
|
| 200 |
+
"pygments_lexer": "ipython3",
|
| 201 |
+
"version": "3.8.5"
|
| 202 |
+
}
|
| 203 |
+
},
|
| 204 |
+
"nbformat": 4,
|
| 205 |
+
"nbformat_minor": 4
|
| 206 |
+
}
|
ivim_dki_extratrees.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:84f9234b46f168f5d5ed05f11467d9966823ea743c1145adf57653b9fa4d8791
|
| 3 |
+
size 35322471
|
models/b_values_config.npy
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:78c891241f48d197a0b978e4223ffb645c19330ea70f1251dfec7586e7025ab1
|
| 3 |
+
size 248
|
models/ivim_dki_extratrees.joblib
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:84f9234b46f168f5d5ed05f11467d9966823ea743c1145adf57653b9fa4d8791
|
| 3 |
+
size 35322471
|
requirements.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy
|
| 2 |
+
scipy
|
| 3 |
+
matplotlib
|
| 4 |
+
pydicom
|
| 5 |
+
nibabel
|
| 6 |
+
tqdm
|
| 7 |
+
ipywidgets
|
| 8 |
+
pandas
|
| 9 |
+
seaborn
|
| 10 |
+
scikit-learn
|
| 11 |
+
xgboost
|
| 12 |
+
joblib
|
| 13 |
+
joblib
|
setup.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from setuptools import setup, find_packages
|
| 2 |
+
|
| 3 |
+
setup(
|
| 4 |
+
name="ivim-dwi",
|
| 5 |
+
version="1.0.0",
|
| 6 |
+
author="Leonar Steven Prieto-Gonzalez",
|
| 7 |
+
description="Robust IVIM and Kurtosis parameter estimation for DWI",
|
| 8 |
+
long_description=open("README.md").read(),
|
| 9 |
+
long_description_content_type="text/markdown",
|
| 10 |
+
url="https://github.com/yourusername/ivim-robust",
|
| 11 |
+
packages=find_packages(where="src"),
|
| 12 |
+
package_dir={"": "src"},
|
| 13 |
+
classifiers=[
|
| 14 |
+
"Programming Language :: Python :: 3",
|
| 15 |
+
"License :: OSI Approved :: MIT License",
|
| 16 |
+
"Operating System :: OS Independent",
|
| 17 |
+
"Topic :: Scientific/Engineering :: Medical Science Apps.",
|
| 18 |
+
],
|
| 19 |
+
python_requires='>=3.8',
|
| 20 |
+
install_requires=[
|
| 21 |
+
"numpy",
|
| 22 |
+
"scipy",
|
| 23 |
+
"matplotlib",
|
| 24 |
+
"pydicom",
|
| 25 |
+
"nibabel",
|
| 26 |
+
"tqdm",
|
| 27 |
+
"joblib",
|
| 28 |
+
"pandas",
|
| 29 |
+
"scikit-learn",
|
| 30 |
+
"xgboost"
|
| 31 |
+
],
|
| 32 |
+
)
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/ivim_model.cpython-313.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
src/__pycache__/ml_models.cpython-313.pyc
ADDED
|
Binary file (6.41 kB). View file
|
|
|
src/__pycache__/utils.cpython-313.pyc
ADDED
|
Binary file (4.14 kB). View file
|
|
|
src/ivim_model.py
ADDED
|
@@ -0,0 +1,280 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from scipy.optimize import curve_fit
|
| 3 |
+
from joblib import Parallel, delayed
|
| 4 |
+
import multiprocessing
|
| 5 |
+
|
| 6 |
+
def func_lin(x, a0, a1):
|
| 7 |
+
"""Linear function: a1*x + a0"""
|
| 8 |
+
return a1 * x + a0
|
| 9 |
+
|
| 10 |
+
def func_qua(x, a0, a1, a2):
|
| 11 |
+
"""Quadratic function: a2*x^2 + a1*x + a0"""
|
| 12 |
+
return a2 * (x ** 2) + a1 * x + a0
|
| 13 |
+
|
| 14 |
+
def _fit_step(x_data, y_data, model_type):
|
| 15 |
+
"""
|
| 16 |
+
Performs a single fitting step (linear or quadratic).
|
| 17 |
+
"""
|
| 18 |
+
if model_type == 'linear':
|
| 19 |
+
# Constrain slope to be negative (decay)
|
| 20 |
+
bounds = (-np.inf, [np.inf, 0])
|
| 21 |
+
try:
|
| 22 |
+
popt, _ = curve_fit(func_lin, x_data, y_data, bounds=bounds)
|
| 23 |
+
residuals = np.sum((y_data - func_lin(x_data, *popt)) ** 2)
|
| 24 |
+
a2 = 0
|
| 25 |
+
a0, a1 = popt
|
| 26 |
+
except RuntimeError:
|
| 27 |
+
return 0, 0, 0, 0
|
| 28 |
+
else:
|
| 29 |
+
# Constrain quadratic term
|
| 30 |
+
bounds = ([-np.inf, -np.inf, 0], [np.inf, 0, np.inf])
|
| 31 |
+
try:
|
| 32 |
+
popt, _ = curve_fit(func_qua, x_data, y_data, bounds=bounds)
|
| 33 |
+
residuals = np.sum((y_data - func_qua(x_data, *popt)) ** 2)
|
| 34 |
+
a0, a1, a2 = popt
|
| 35 |
+
except RuntimeError:
|
| 36 |
+
return 0, 0, 0, 0
|
| 37 |
+
|
| 38 |
+
y_mean = np.mean(y_data)
|
| 39 |
+
ss_tot = np.sum((y_data - y_mean) ** 2)
|
| 40 |
+
|
| 41 |
+
r2 = 1 - residuals / ss_tot if ss_tot != 0 else 0
|
| 42 |
+
|
| 43 |
+
return a0, a1, a2, r2
|
| 44 |
+
|
| 45 |
+
def _fit_pixel(x, y, model_type='linear', gof_threshold=0.9):
|
| 46 |
+
"""
|
| 47 |
+
Iterative fitting process. Removes outliers if R2 is below threshold.
|
| 48 |
+
"""
|
| 49 |
+
a0, a1, a2, r2 = _fit_step(x, y, model_type)
|
| 50 |
+
|
| 51 |
+
if r2 < gof_threshold:
|
| 52 |
+
best_r2 = 0
|
| 53 |
+
best_params = (a0, a1, a2)
|
| 54 |
+
best_data = (x, y)
|
| 55 |
+
|
| 56 |
+
# Optimization: Only check points with high residuals if N is large
|
| 57 |
+
if len(x) > 20:
|
| 58 |
+
if model_type == 'linear':
|
| 59 |
+
y_pred = func_lin(x, a0, a1)
|
| 60 |
+
else:
|
| 61 |
+
y_pred = func_qua(x, a0, a1, a2)
|
| 62 |
+
|
| 63 |
+
residuals = (y - y_pred)**2
|
| 64 |
+
# Check only the top 5 worst outliers for speed
|
| 65 |
+
n_check = min(len(x), 5)
|
| 66 |
+
indices_to_check = np.argsort(residuals)[-n_check:]
|
| 67 |
+
else:
|
| 68 |
+
indices_to_check = range(len(x))
|
| 69 |
+
|
| 70 |
+
# Leave-one-out strategy on candidate points
|
| 71 |
+
for i in indices_to_check:
|
| 72 |
+
x_subset = np.delete(x, i)
|
| 73 |
+
y_subset = np.delete(y, i)
|
| 74 |
+
|
| 75 |
+
curr_a0, curr_a1, curr_a2, curr_r2 = _fit_step(x_subset, y_subset, model_type)
|
| 76 |
+
|
| 77 |
+
if curr_r2 > best_r2:
|
| 78 |
+
best_r2 = curr_r2
|
| 79 |
+
best_params = (curr_a0, curr_a1, curr_a2)
|
| 80 |
+
best_data = (x_subset, y_subset)
|
| 81 |
+
|
| 82 |
+
return best_params, best_r2, best_data
|
| 83 |
+
|
| 84 |
+
return (a0, a1, a2), r2, (x, y)
|
| 85 |
+
|
| 86 |
+
def media(r1, r2, z=1):
|
| 87 |
+
"""Calculates mean (geometric, harmonic, or arithmetic)."""
|
| 88 |
+
if r1 <= 0.01: r1 = 1E-3
|
| 89 |
+
if r2 <= 0.01: r2 = 1E-3
|
| 90 |
+
|
| 91 |
+
if z == 1: # Geometric
|
| 92 |
+
m = np.sqrt(r1 * r2)
|
| 93 |
+
elif z == 2: # Harmonic
|
| 94 |
+
m = 2 / (1/r1 + 1/r2)
|
| 95 |
+
else: # Arithmetic
|
| 96 |
+
m = (r1 + r2) / 2
|
| 97 |
+
return m
|
| 98 |
+
|
| 99 |
+
def calculate_ivim_params(b_values, signal_values, gof=0.9, limit_dif=180, model_type='linear'):
|
| 100 |
+
"""
|
| 101 |
+
Calculates IVIM and Kurtosis parameters (D, f, D*, K) for a single voxel using a segmented fitting approach.
|
| 102 |
+
|
| 103 |
+
Args:
|
| 104 |
+
b_values: Array of b-values.
|
| 105 |
+
signal_values: Array of signal intensities corresponding to b-values.
|
| 106 |
+
gof: Goodness of fit threshold (R2) to trigger outlier removal.
|
| 107 |
+
limit_dif: b-value threshold separating perfusion (low b) and diffusion (high b) regimes.
|
| 108 |
+
model_type: 'linear' for Mono-exponential (IVIM), 'quadratic' for Kurtosis (IVIM-DKI).
|
| 109 |
+
|
| 110 |
+
Returns:
|
| 111 |
+
r2_new: Combined R2 score.
|
| 112 |
+
D: Diffusion coefficient.
|
| 113 |
+
f: Perfusion fraction.
|
| 114 |
+
D_pse: Pseudo-diffusion coefficient (D*).
|
| 115 |
+
K: Kurtosis (0 if model_type is 'linear').
|
| 116 |
+
"""
|
| 117 |
+
vec_b = np.array(b_values)
|
| 118 |
+
vec_S = np.array(signal_values)
|
| 119 |
+
|
| 120 |
+
# Normalize signal
|
| 121 |
+
vec_S[np.isnan(vec_S)] = 0
|
| 122 |
+
S0 = np.max(vec_S)
|
| 123 |
+
if S0 == 0:
|
| 124 |
+
return 0, 0, 0, 0, 0
|
| 125 |
+
|
| 126 |
+
vec_S = vec_S / S0
|
| 127 |
+
|
| 128 |
+
# Avoid log(0) issues
|
| 129 |
+
vec_S[vec_S <= 0] = 1e-10
|
| 130 |
+
vec_S_log = np.log(vec_S)
|
| 131 |
+
|
| 132 |
+
# Split data into vascular (low b) and cellular (high b) regimes
|
| 133 |
+
try:
|
| 134 |
+
limite_dif_idx = np.where(vec_b >= limit_dif)[0][0]
|
| 135 |
+
except IndexError:
|
| 136 |
+
limite_dif_idx = len(vec_b)
|
| 137 |
+
|
| 138 |
+
vec_b_vas = vec_b[:limite_dif_idx + 1]
|
| 139 |
+
vec_b_cel = vec_b[limite_dif_idx:]
|
| 140 |
+
|
| 141 |
+
vec_S_vas = vec_S[:limite_dif_idx + 1]
|
| 142 |
+
vec_S_cel = vec_S[limite_dif_idx:]
|
| 143 |
+
vec_S_cel_log = vec_S_log[limite_dif_idx:]
|
| 144 |
+
|
| 145 |
+
# 1. Fit Diffusion/Kurtosis (High b-values)
|
| 146 |
+
# Perform initial fit on the full dataset to check overall quality
|
| 147 |
+
a, r, data = _fit_pixel(vec_b, vec_S_log, model_type, gof_threshold=gof)
|
| 148 |
+
|
| 149 |
+
if r < gof:
|
| 150 |
+
# If global fit is poor, proceed with Segmented Fitting
|
| 151 |
+
|
| 152 |
+
# Step 1: Fit High-b values (Diffusion regime)
|
| 153 |
+
a, r, data = _fit_pixel(vec_b_cel, vec_S_cel_log, model_type, gof_threshold=gof)
|
| 154 |
+
|
| 155 |
+
# Calculate Diffusion (D) and Kurtosis (K) from high-b fit
|
| 156 |
+
D = -a[1]
|
| 157 |
+
if model_type == 'quadratic' and D != 0:
|
| 158 |
+
K = (a[2] / D**2) * 6
|
| 159 |
+
else:
|
| 160 |
+
K = 0
|
| 161 |
+
|
| 162 |
+
# Calculate Perfusion Fraction (f) from intercept
|
| 163 |
+
# Intercept a[0] corresponds to ln(1-f)
|
| 164 |
+
uno_menos_f = np.exp(a[0])
|
| 165 |
+
f = 1 - uno_menos_f
|
| 166 |
+
|
| 167 |
+
# Step 2: Extrapolate diffusion contribution to low-b values
|
| 168 |
+
# S_diff_extrapolated = exp(a2*b^2 + a1*b + a0)
|
| 169 |
+
if model_type == 'linear':
|
| 170 |
+
diffusion_contribution = np.exp(a[1]*vec_b_vas + a[0])
|
| 171 |
+
else:
|
| 172 |
+
diffusion_contribution = np.exp(a[2]*(vec_b_vas**2) + a[1]*vec_b_vas + a[0])
|
| 173 |
+
|
| 174 |
+
# Step 3: Subtract Diffusion from Total Signal to isolate Perfusion
|
| 175 |
+
# S_perfusion = S_total - S_diffusion
|
| 176 |
+
y4 = vec_S_vas - diffusion_contribution
|
| 177 |
+
|
| 178 |
+
# Ensure residuals are positive for logarithmic fitting
|
| 179 |
+
if np.min(y4) < 0:
|
| 180 |
+
y4 = y4 + abs(np.min(y4))
|
| 181 |
+
|
| 182 |
+
# Prepare for Perfusion fit (log of residuals)
|
| 183 |
+
y5 = np.zeros(len(y4))
|
| 184 |
+
valid_indices = []
|
| 185 |
+
for i in range(len(y4)):
|
| 186 |
+
if y4[i] > 0:
|
| 187 |
+
y5[i] = np.log(y4[i])
|
| 188 |
+
valid_indices.append(i)
|
| 189 |
+
|
| 190 |
+
if len(valid_indices) > 2:
|
| 191 |
+
y5_clean = y5[valid_indices]
|
| 192 |
+
vec_b_vas_clean = vec_b_vas[valid_indices]
|
| 193 |
+
|
| 194 |
+
# Step 4: Fit Perfusion (Pseudo-diffusion) using a linear model
|
| 195 |
+
A_param, R_param, _ = _fit_pixel(vec_b_vas_clean, y5_clean, 'linear', gof_threshold=gof)
|
| 196 |
+
|
| 197 |
+
D_pse = -A_param[1]
|
| 198 |
+
r2_new = media(R_param, r, 1)
|
| 199 |
+
else:
|
| 200 |
+
# Fallback: Failed to isolate perfusion component (D*)
|
| 201 |
+
# But we keep D, f, K from the high-b fit
|
| 202 |
+
D_pse = 0
|
| 203 |
+
r2_new = r
|
| 204 |
+
|
| 205 |
+
else:
|
| 206 |
+
# Good global fit: Assume Mono-exponential / Kurtosis model fits entire range
|
| 207 |
+
D = -a[1]
|
| 208 |
+
f = 0
|
| 209 |
+
if model_type == 'quadratic' and D != 0:
|
| 210 |
+
K = (a[2] / D**2) * 6
|
| 211 |
+
else:
|
| 212 |
+
K = 0
|
| 213 |
+
D_pse = 0
|
| 214 |
+
r2_new = r
|
| 215 |
+
|
| 216 |
+
return r2_new, D, f, D_pse, K
|
| 217 |
+
|
| 218 |
+
def process_slice_parallel(b_values, slice_data, mask=None, gof=0.9, limit_dif=180, model_type='linear', n_jobs=-1):
|
| 219 |
+
"""
|
| 220 |
+
Processes an entire 2D slice (Rows x Cols x b-values) in parallel to extract IVIM/DKI parameters.
|
| 221 |
+
|
| 222 |
+
Args:
|
| 223 |
+
b_values: Array of b-values.
|
| 224 |
+
slice_data: 3D array (Rows, Cols, b-values) containing signal intensities.
|
| 225 |
+
mask: Binary mask (Rows, Cols) indicating pixels to process. If None, processes all.
|
| 226 |
+
gof: Goodness of fit threshold.
|
| 227 |
+
limit_dif: b-value threshold for segmentation.
|
| 228 |
+
model_type: 'linear' or 'quadratic'.
|
| 229 |
+
n_jobs: Number of parallel jobs. -1 uses all available CPUs.
|
| 230 |
+
|
| 231 |
+
Returns:
|
| 232 |
+
Tuple of 2D maps: (R2, D, f, D*, K)
|
| 233 |
+
"""
|
| 234 |
+
rows, cols, n_b = slice_data.shape
|
| 235 |
+
|
| 236 |
+
if mask is None:
|
| 237 |
+
mask = np.ones((rows, cols), dtype=bool)
|
| 238 |
+
|
| 239 |
+
# Prepare list of tasks for parallel execution
|
| 240 |
+
tasks = []
|
| 241 |
+
coords = []
|
| 242 |
+
|
| 243 |
+
for i in range(rows):
|
| 244 |
+
for j in range(cols):
|
| 245 |
+
if mask[i, j]:
|
| 246 |
+
signal = slice_data[i, j, :]
|
| 247 |
+
tasks.append((b_values, signal, gof, limit_dif, model_type))
|
| 248 |
+
coords.append((i, j))
|
| 249 |
+
|
| 250 |
+
if not tasks:
|
| 251 |
+
# Return empty maps if no pixels to process
|
| 252 |
+
return (np.zeros((rows, cols)) for _ in range(5))
|
| 253 |
+
|
| 254 |
+
# Determine number of jobs
|
| 255 |
+
if n_jobs == -1:
|
| 256 |
+
n_jobs = multiprocessing.cpu_count()
|
| 257 |
+
|
| 258 |
+
print(f"Processing {len(tasks)} pixels using {n_jobs} threads...")
|
| 259 |
+
|
| 260 |
+
# Execute parallel processing
|
| 261 |
+
# Note: 'threading' backend is preferred on Windows to avoid pickling overhead and issues with local functions
|
| 262 |
+
results = Parallel(n_jobs=n_jobs, backend="threading", verbose=1)(
|
| 263 |
+
delayed(calculate_ivim_params)(*t) for t in tasks
|
| 264 |
+
)
|
| 265 |
+
|
| 266 |
+
# Reconstruct 2D parameter maps from flat results
|
| 267 |
+
map_R2 = np.zeros((rows, cols))
|
| 268 |
+
map_D = np.zeros((rows, cols))
|
| 269 |
+
map_f = np.zeros((rows, cols))
|
| 270 |
+
map_D_star = np.zeros((rows, cols))
|
| 271 |
+
map_K = np.zeros((rows, cols))
|
| 272 |
+
|
| 273 |
+
for (i, j), (r2, D, f, D_star, K) in zip(coords, results):
|
| 274 |
+
map_R2[i, j] = r2
|
| 275 |
+
map_D[i, j] = D
|
| 276 |
+
map_f[i, j] = f
|
| 277 |
+
map_D_star[i, j] = D_star
|
| 278 |
+
map_K[i, j] = K
|
| 279 |
+
|
| 280 |
+
return map_R2, map_D, map_f, map_D_star, map_K
|
src/ml_models.py
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from sklearn.ensemble import RandomForestRegressor, ExtraTreesRegressor
|
| 4 |
+
from sklearn.neural_network import MLPRegressor
|
| 5 |
+
from sklearn.linear_model import LinearRegression, Ridge, Lasso, ElasticNet
|
| 6 |
+
from sklearn.svm import SVR
|
| 7 |
+
from sklearn.model_selection import train_test_split, cross_val_score
|
| 8 |
+
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
|
| 9 |
+
import joblib
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
class IVIMRegressor:
|
| 13 |
+
"""
|
| 14 |
+
Machine Learning wrapper for estimating IVIM/DKI parameters from diffusion MRI signals.
|
| 15 |
+
|
| 16 |
+
This class provides a unified interface for training and applying various regression models
|
| 17 |
+
(Random Forest, Extra Trees, MLP, etc.) to map signal attenuation curves directly to
|
| 18 |
+
tissue parameters (D, f, D*, K), bypassing iterative non-linear least squares fitting.
|
| 19 |
+
|
| 20 |
+
Supported architectures:
|
| 21 |
+
- 'random_forest': Robust baseline, handles noise well.
|
| 22 |
+
- 'extra_trees': Often faster and slightly more accurate than RF. In our experiments, this model showed superior robustness to noise.
|
| 23 |
+
- 'mlp': Multi-layer Perceptron for capturing complex non-linear mappings.
|
| 24 |
+
- 'xgboost': Gradient boosting (requires xgboost package).
|
| 25 |
+
- 'svr': Support Vector Regression.
|
| 26 |
+
"""
|
| 27 |
+
|
| 28 |
+
def __init__(self, model_type='extra_trees', params=None):
|
| 29 |
+
self.model_type = model_type
|
| 30 |
+
self.params = params if params else {}
|
| 31 |
+
self.model = self._build_model()
|
| 32 |
+
|
| 33 |
+
def _build_model(self):
|
| 34 |
+
if self.model_type == 'random_forest':
|
| 35 |
+
# Default params from paper/notebook
|
| 36 |
+
n_estimators = self.params.get('n_estimators', 100)
|
| 37 |
+
return RandomForestRegressor(n_estimators=n_estimators, random_state=42, n_jobs=-1)
|
| 38 |
+
|
| 39 |
+
elif self.model_type == 'extra_trees':
|
| 40 |
+
n_estimators = self.params.get('n_estimators', 100)
|
| 41 |
+
return ExtraTreesRegressor(n_estimators=n_estimators, random_state=42, n_jobs=-1)
|
| 42 |
+
|
| 43 |
+
elif self.model_type == 'mlp':
|
| 44 |
+
hidden_layer_sizes = self.params.get('hidden_layer_sizes', (100, 50))
|
| 45 |
+
return MLPRegressor(hidden_layer_sizes=hidden_layer_sizes, max_iter=500, random_state=42)
|
| 46 |
+
|
| 47 |
+
elif self.model_type == 'xgboost':
|
| 48 |
+
try:
|
| 49 |
+
from xgboost import XGBRegressor
|
| 50 |
+
return XGBRegressor(n_estimators=1000, learning_rate=0.01, n_jobs=-1, random_state=42)
|
| 51 |
+
except ImportError:
|
| 52 |
+
print("XGBoost not installed. Falling back to Random Forest.")
|
| 53 |
+
return RandomForestRegressor(n_estimators=100, random_state=42)
|
| 54 |
+
|
| 55 |
+
elif self.model_type == 'svr':
|
| 56 |
+
C = self.params.get('C', 100)
|
| 57 |
+
return SVR(C=C)
|
| 58 |
+
|
| 59 |
+
else:
|
| 60 |
+
raise ValueError(f"Unknown model type: {self.model_type}")
|
| 61 |
+
|
| 62 |
+
def train(self, X, y, test_size=0.2, verbose=True):
|
| 63 |
+
"""
|
| 64 |
+
Trains the regression model using the provided signal-parameter pairs.
|
| 65 |
+
|
| 66 |
+
Args:
|
| 67 |
+
X: Input feature matrix (Normalized Signal vs b-values). Shape: [n_samples, n_b_values]
|
| 68 |
+
y: Target parameter vector (e.g., Diffusion Coefficient D). Shape: [n_samples]
|
| 69 |
+
test_size: Fraction of data to reserve for validation (default: 0.2).
|
| 70 |
+
verbose: If True, prints training progress and validation metrics.
|
| 71 |
+
|
| 72 |
+
Returns:
|
| 73 |
+
Dictionary containing validation metrics (MAE, MSE, RMSE, R2).
|
| 74 |
+
"""
|
| 75 |
+
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)
|
| 76 |
+
|
| 77 |
+
if verbose:
|
| 78 |
+
print(f"Training {self.model_type} on {len(X_train)} samples...")
|
| 79 |
+
|
| 80 |
+
self.model.fit(X_train, y_train)
|
| 81 |
+
|
| 82 |
+
# Evaluate
|
| 83 |
+
predictions = self.model.predict(X_test)
|
| 84 |
+
metrics = self._evaluate(y_test, predictions)
|
| 85 |
+
|
| 86 |
+
if verbose:
|
| 87 |
+
print("--- Validation Metrics ---")
|
| 88 |
+
for k, v in metrics.items():
|
| 89 |
+
print(f"{k}: {v:.6f}")
|
| 90 |
+
|
| 91 |
+
return metrics
|
| 92 |
+
|
| 93 |
+
def predict(self, X):
|
| 94 |
+
"""Predicts parameters for new data."""
|
| 95 |
+
return self.model.predict(X)
|
| 96 |
+
|
| 97 |
+
def _evaluate(self, y_true, y_pred):
|
| 98 |
+
return {
|
| 99 |
+
'MAE': mean_absolute_error(y_true, y_pred),
|
| 100 |
+
'MSE': mean_squared_error(y_true, y_pred),
|
| 101 |
+
'RMSE': np.sqrt(mean_squared_error(y_true, y_pred)),
|
| 102 |
+
'R2': r2_score(y_true, y_pred)
|
| 103 |
+
}
|
| 104 |
+
|
| 105 |
+
def save(self, filepath):
|
| 106 |
+
"""Saves the trained model to disk."""
|
| 107 |
+
joblib.dump(self.model, filepath)
|
| 108 |
+
print(f"Model saved to {filepath}")
|
| 109 |
+
|
| 110 |
+
def load(self, filepath):
|
| 111 |
+
"""Loads a trained model from disk."""
|
| 112 |
+
if os.path.exists(filepath):
|
| 113 |
+
self.model = joblib.load(filepath)
|
| 114 |
+
print(f"Model loaded from {filepath}")
|
| 115 |
+
else:
|
| 116 |
+
raise FileNotFoundError(f"Model file not found: {filepath}")
|
| 117 |
+
|
| 118 |
+
def load_training_data(data_dir, dataset_name='MR701'):
|
| 119 |
+
"""
|
| 120 |
+
Helper to load X and Y CSV files from the data directory.
|
| 121 |
+
Expected format: Data_X2_{dataset}.csv and Data_Y_{dataset}.csv
|
| 122 |
+
"""
|
| 123 |
+
x_path = os.path.join(data_dir, f'Data_X2_{dataset_name}.csv')
|
| 124 |
+
y_path = os.path.join(data_dir, f'Data_Y_{dataset_name}.csv')
|
| 125 |
+
|
| 126 |
+
if not os.path.exists(x_path) or not os.path.exists(y_path):
|
| 127 |
+
raise FileNotFoundError(f"Data files not found for {dataset_name} in {data_dir}")
|
| 128 |
+
|
| 129 |
+
X = np.loadtxt(x_path)
|
| 130 |
+
Y = np.loadtxt(y_path) # Assuming Y contains [D, f, D*, K] columns or similar
|
| 131 |
+
|
| 132 |
+
return X, Y
|
src/utils.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
|
| 3 |
+
def filtro_2D(veces, col, row, Matriz_R):
|
| 4 |
+
"""
|
| 5 |
+
Applies a 2D filter to the R2 matrix to smooth it.
|
| 6 |
+
"""
|
| 7 |
+
R_old = np.copy(Matriz_R)
|
| 8 |
+
R_new = np.copy(Matriz_R)
|
| 9 |
+
|
| 10 |
+
for k in range(veces):
|
| 11 |
+
for i in range(1, col-1):
|
| 12 |
+
for j in range(1, row-1):
|
| 13 |
+
w_A = (R_old[i-1, j] + R_old[i, j-1] + R_old[i+1, j] + R_old[i, j+1]) / 4
|
| 14 |
+
w_B = (R_old[i-1, j-1] + R_old[i+1, j-1] + R_old[i-1, j+1] + R_old[i+1, j+1]) / (4 * np.sqrt(2))
|
| 15 |
+
|
| 16 |
+
if (w_A + w_B) > 1:
|
| 17 |
+
R_new[i, j] = 1
|
| 18 |
+
else:
|
| 19 |
+
R_new[i, j] = 0
|
| 20 |
+
R_old = np.copy(R_new)
|
| 21 |
+
|
| 22 |
+
return R_new
|
| 23 |
+
|
| 24 |
+
def filtro_3D(veces, slices, col, row, Tensor_R):
|
| 25 |
+
"""
|
| 26 |
+
Applies a 3D filter to the R2 tensor.
|
| 27 |
+
"""
|
| 28 |
+
R_old = np.copy(Tensor_R)
|
| 29 |
+
R_new = np.copy(Tensor_R)
|
| 30 |
+
|
| 31 |
+
for u in range(veces):
|
| 32 |
+
for k in range(1, slices-1):
|
| 33 |
+
for i in range(1, col-1):
|
| 34 |
+
for j in range(1, row-1):
|
| 35 |
+
w_A = (R_old[i-1, j, k] + R_old[i, j-1, k] + R_old[i+1, j, k] + R_old[i, j+1, k] + R_old[i, j, k-1] + R_old[i, j, k+1]) / 6
|
| 36 |
+
w_B1 = (R_old[i-1, j, k-1] + R_old[i+1, j, k-1] + R_old[i-1, j, k+1] + R_old[i+1, j, k+1]) / (12 * np.sqrt(2))
|
| 37 |
+
w_B2 = (R_old[i-1, j-1, k] + R_old[i-1, j+1, k] + R_old[i+1, j-1, k] + R_old[i+1, j+1, k]) / (12 * np.sqrt(2))
|
| 38 |
+
w_B3 = (R_old[i, j-1, k-1] + R_old[i, j+1, k-1] + R_old[i, j-1, k+1] + R_old[i, j+1, k+1]) / (12 * np.sqrt(2))
|
| 39 |
+
w_C = (R_old[i-1, j-1, k+1] + R_old[i-1, j+1, k+1] + R_old[i+1, j-1, k+1] + R_old[i+1, j+1, k+1] + R_old[i-1, j-1, k-1] + R_old[i-1, j+1, k-1] + R_old[i+1, j-1 ,k-1] + R_old[i+1, j+1 ,k-1]) / (8 * np.sqrt(3))
|
| 40 |
+
w_B = w_B1 + w_B2 + w_B3
|
| 41 |
+
|
| 42 |
+
if (w_A + w_B + w_C) >= 1:
|
| 43 |
+
R_new[i, j, k] = 1
|
| 44 |
+
else:
|
| 45 |
+
R_new[i, j, k] = 0
|
| 46 |
+
R_old = np.copy(R_new)
|
| 47 |
+
|
| 48 |
+
return R_new
|
train_ml_demo.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import sys
|
| 4 |
+
import matplotlib.pyplot as plt
|
| 5 |
+
|
| 6 |
+
# Add src to path
|
| 7 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
| 8 |
+
|
| 9 |
+
from ml_models import IVIMRegressor
|
| 10 |
+
from ivim_model import calculate_ivim_params
|
| 11 |
+
|
| 12 |
+
def generate_synthetic_training_data(n_samples=1000):
|
| 13 |
+
"""
|
| 14 |
+
Generates synthetic data for training ML models.
|
| 15 |
+
"""
|
| 16 |
+
print(f"Generating {n_samples} synthetic samples...")
|
| 17 |
+
|
| 18 |
+
# b-values
|
| 19 |
+
b_values = np.array([0, 10, 20, 30, 50, 80, 100, 200, 400, 800, 1000])
|
| 20 |
+
|
| 21 |
+
# Random parameters
|
| 22 |
+
# D: 0.0005 to 0.003
|
| 23 |
+
D = np.random.uniform(0.0005, 0.003, n_samples)
|
| 24 |
+
# f: 0.05 to 0.3
|
| 25 |
+
f = np.random.uniform(0.05, 0.3, n_samples)
|
| 26 |
+
# D*: 0.005 to 0.05
|
| 27 |
+
D_star = np.random.uniform(0.005, 0.05, n_samples)
|
| 28 |
+
# K: 0 to 1.5
|
| 29 |
+
K = np.random.uniform(0, 1.5, n_samples)
|
| 30 |
+
|
| 31 |
+
X = []
|
| 32 |
+
for i in range(n_samples):
|
| 33 |
+
# Generate signal
|
| 34 |
+
term_diff = np.exp(-b_values * D[i] + (1/6) * (b_values**2) * (D[i]**2) * K[i])
|
| 35 |
+
term_perf = np.exp(-b_values * D_star[i])
|
| 36 |
+
S0 = 1000
|
| 37 |
+
signal = S0 * (f[i] * term_perf + (1 - f[i]) * term_diff)
|
| 38 |
+
|
| 39 |
+
# Add noise
|
| 40 |
+
noise = np.random.normal(0, 0.02 * S0, size=len(b_values))
|
| 41 |
+
signal_noisy = signal + noise
|
| 42 |
+
signal_noisy[signal_noisy < 0] = 0
|
| 43 |
+
|
| 44 |
+
# Normalize
|
| 45 |
+
signal_norm = signal_noisy / np.max(signal_noisy)
|
| 46 |
+
X.append(signal_norm)
|
| 47 |
+
|
| 48 |
+
X = np.array(X)
|
| 49 |
+
# Targets: Let's predict D for this example
|
| 50 |
+
y = D
|
| 51 |
+
|
| 52 |
+
return X, y, b_values
|
| 53 |
+
|
| 54 |
+
def main():
|
| 55 |
+
print("=== ML Model Training Demo ===")
|
| 56 |
+
|
| 57 |
+
# 1. Get Data (Synthetic for demo, but structure allows loading real data)
|
| 58 |
+
# In a real scenario, you would use:
|
| 59 |
+
# X, Y = load_training_data('data/MR701')
|
| 60 |
+
# y = Y[:, 0] # D column
|
| 61 |
+
|
| 62 |
+
X, y, b_values = generate_synthetic_training_data()
|
| 63 |
+
|
| 64 |
+
# 2. Initialize Model
|
| 65 |
+
# We use Random Forest as it was one of the best performers in the paper
|
| 66 |
+
regressor = IVIMRegressor(model_type='random_forest', params={'n_estimators': 50})
|
| 67 |
+
|
| 68 |
+
# 3. Train
|
| 69 |
+
print("\nTraining Random Forest to predict Diffusion Coefficient (D)...")
|
| 70 |
+
metrics = regressor.train(X, y)
|
| 71 |
+
|
| 72 |
+
# 4. Save Model
|
| 73 |
+
if not os.path.exists('models'):
|
| 74 |
+
os.makedirs('models')
|
| 75 |
+
regressor.save('models/rf_diffusion_model.joblib')
|
| 76 |
+
|
| 77 |
+
# 5. Compare with Analytical Fit on a few samples
|
| 78 |
+
print("\n--- Comparison: ML vs Analytical ---")
|
| 79 |
+
indices = np.random.choice(len(X), 5)
|
| 80 |
+
|
| 81 |
+
for idx in indices:
|
| 82 |
+
signal = X[idx]
|
| 83 |
+
true_val = y[idx]
|
| 84 |
+
|
| 85 |
+
# ML Prediction
|
| 86 |
+
ml_pred = regressor.predict([signal])[0]
|
| 87 |
+
|
| 88 |
+
# Analytical Fit
|
| 89 |
+
# Note: calculate_ivim_params expects raw signal, but we normalized.
|
| 90 |
+
# It handles normalization internally, so passing normalized is fine (S0=1).
|
| 91 |
+
_, d_ana, _, _, _ = calculate_ivim_params(b_values, signal, gof=0.9)
|
| 92 |
+
|
| 93 |
+
print(f"True D: {true_val:.6f} | ML: {ml_pred:.6f} | Analytical: {d_ana:.6f}")
|
| 94 |
+
|
| 95 |
+
if __name__ == "__main__":
|
| 96 |
+
main()
|
train_pretrained.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import joblib
|
| 3 |
+
import os
|
| 4 |
+
import sys
|
| 5 |
+
from sklearn.ensemble import ExtraTreesRegressor
|
| 6 |
+
from sklearn.model_selection import train_test_split
|
| 7 |
+
from sklearn.metrics import mean_squared_error
|
| 8 |
+
|
| 9 |
+
# Ensure output directory exists
|
| 10 |
+
if not os.path.exists('models'):
|
| 11 |
+
os.makedirs('models')
|
| 12 |
+
|
| 13 |
+
def generate_comprehensive_training_data(n_samples=100000):
|
| 14 |
+
"""
|
| 15 |
+
Generates a large, comprehensive dataset covering a wide range of b-values
|
| 16 |
+
and physiological parameters to create a robust pre-trained model.
|
| 17 |
+
"""
|
| 18 |
+
print(f"Generating {n_samples} synthetic samples for pre-training...")
|
| 19 |
+
|
| 20 |
+
# Extended set of b-values to cover various clinical protocols (Prostate, Brain, etc.)
|
| 21 |
+
# We include a superset of common b-values.
|
| 22 |
+
# NOTE: For the pre-trained model to be universal, it ideally needs to be trained
|
| 23 |
+
# on the SPECIFIC b-values of the target protocol.
|
| 24 |
+
# However, for a general purpose 'demo' model, we will use a standard clinical set.
|
| 25 |
+
# Users should ideally retrain for their specific protocol, but this serves as a strong baseline.
|
| 26 |
+
b_values = np.array([0, 10, 20, 30, 50, 80, 100, 150, 200, 400, 600, 800, 1000, 1500, 2000])
|
| 27 |
+
|
| 28 |
+
# Random parameters within broad physiological ranges
|
| 29 |
+
D = np.random.uniform(0.0001, 0.004, n_samples) # Diffusion coefficient (mm2/s)
|
| 30 |
+
f = np.random.uniform(0.01, 0.4, n_samples) # Perfusion fraction
|
| 31 |
+
D_star = np.random.uniform(0.005, 0.1, n_samples) # Pseudo-diffusion (mm2/s)
|
| 32 |
+
K = np.random.uniform(0, 2.5, n_samples) # Kurtosis (dimensionless)
|
| 33 |
+
|
| 34 |
+
X = []
|
| 35 |
+
Y = [] # Targets: [D, f, D*, K]
|
| 36 |
+
|
| 37 |
+
for i in range(n_samples):
|
| 38 |
+
# Signal model: IVIM-DKI
|
| 39 |
+
# S/S0 = f * exp(-b*D*) + (1-f) * exp(-b*D + 1/6 * b^2 * D^2 * K)
|
| 40 |
+
|
| 41 |
+
# Diffusion term with Kurtosis
|
| 42 |
+
# Note: We clip the exponent to avoid overflow/numerical issues at high b-values/K
|
| 43 |
+
exponent_diff = -b_values * D[i] + (1/6) * (b_values**2) * (D[i]**2) * K[i]
|
| 44 |
+
term_diff = np.exp(exponent_diff)
|
| 45 |
+
|
| 46 |
+
# Perfusion term
|
| 47 |
+
term_perf = np.exp(-b_values * D_star[i])
|
| 48 |
+
|
| 49 |
+
signal = f[i] * term_perf + (1 - f[i]) * term_diff
|
| 50 |
+
|
| 51 |
+
# Add Rician noise (approximated as Gaussian for simplicity in this large batch)
|
| 52 |
+
# SNR varying between 20 and 100
|
| 53 |
+
snr = np.random.uniform(20, 100)
|
| 54 |
+
sigma = 1.0 / snr
|
| 55 |
+
noise = np.random.normal(0, sigma, size=len(b_values))
|
| 56 |
+
|
| 57 |
+
signal_noisy = signal + noise
|
| 58 |
+
|
| 59 |
+
# Rician correction (magnitude)
|
| 60 |
+
signal_noisy = np.sqrt(signal_noisy**2 + noise**2) # Simple approximation
|
| 61 |
+
|
| 62 |
+
# Normalize (though it's already relative to S0=1)
|
| 63 |
+
signal_norm = signal_noisy / np.max(signal_noisy)
|
| 64 |
+
|
| 65 |
+
X.append(signal_norm)
|
| 66 |
+
Y.append([D[i], f[i], D_star[i], K[i]])
|
| 67 |
+
|
| 68 |
+
return np.array(X), np.array(Y), b_values
|
| 69 |
+
|
| 70 |
+
def train_and_save():
|
| 71 |
+
X, Y, b_values = generate_comprehensive_training_data(n_samples=50000)
|
| 72 |
+
|
| 73 |
+
print("Training ExtraTreesRegressor (Multi-output)...")
|
| 74 |
+
# Optimized for model size < 100MB for GitHub hosting
|
| 75 |
+
model = ExtraTreesRegressor(
|
| 76 |
+
n_estimators=50, # Reduced from 100
|
| 77 |
+
max_depth=20, # Limit depth to prevent massive trees
|
| 78 |
+
min_samples_leaf=5, # Prune leaves to reduce size
|
| 79 |
+
min_samples_split=10,
|
| 80 |
+
n_jobs=-1,
|
| 81 |
+
random_state=42,
|
| 82 |
+
verbose=1
|
| 83 |
+
)
|
| 84 |
+
|
| 85 |
+
model.fit(X, Y)
|
| 86 |
+
|
| 87 |
+
# Evaluate on a small subset
|
| 88 |
+
preds = model.predict(X[:1000])
|
| 89 |
+
mse = mean_squared_error(Y[:1000], preds)
|
| 90 |
+
print(f"Training MSE: {mse:.6f}")
|
| 91 |
+
|
| 92 |
+
# Save model
|
| 93 |
+
model_path = 'models/ivim_dki_extratrees.joblib'
|
| 94 |
+
joblib.dump(model, model_path, compress=3) # Compress to keep file size small
|
| 95 |
+
print(f"Model saved to {model_path}")
|
| 96 |
+
|
| 97 |
+
# Save b-values metadata so we know what protocol this model expects
|
| 98 |
+
np.save('models/b_values_config.npy', b_values)
|
| 99 |
+
print("Configuration saved.")
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
train_and_save()
|