Finished manual tutorial
Browse files- README.md +4 -3
- poetry.lock +69 -1
- pyment/models/utils/ensure_weights.py +4 -2
- pyment/preprocessing/conform.py +18 -22
- pyment/utils/download_file.py +14 -1
- pyproject.toml +1 -0
- scripts/predict_from_fastsurfer_folder.py +4 -5
- tutorials/evaluate_ixi_predictions.py +66 -0
README.md
CHANGED
|
@@ -21,7 +21,7 @@ sudo cp /usr/include/cudnn*.h /usr/local/cuda-11.2/include/
|
|
| 21 |
sudo cp -P /usr/lib/x86_64-linux-gnu/libcudnn*.so* /usr/local/cuda-11.2/lib64/
|
| 22 |
sudo ldconfig
|
| 23 |
```
|
| 24 |
-
Finally, we must configure the system paths
|
| 25 |
```
|
| 26 |
echo 'export CUDA_HOME=/usr/local/cuda-11.2' >> ~/.bashrc
|
| 27 |
echo 'export PATH=$CUDA_HOME/bin:$PATH' >> ~/.bashrc
|
|
@@ -135,8 +135,9 @@ After preprocessing, we can generate predictions for the IXI dataset using the s
|
|
| 135 |
```
|
| 136 |
eval $(poetry env activate)
|
| 137 |
```
|
| 138 |
-
Next, run the prediction-script:
|
| 139 |
```
|
| 140 |
-
|
|
|
|
| 141 |
```
|
| 142 |
</details>
|
|
|
|
| 21 |
sudo cp -P /usr/lib/x86_64-linux-gnu/libcudnn*.so* /usr/local/cuda-11.2/lib64/
|
| 22 |
sudo ldconfig
|
| 23 |
```
|
| 24 |
+
Finally, we must configure the system paths in .bashrc:
|
| 25 |
```
|
| 26 |
echo 'export CUDA_HOME=/usr/local/cuda-11.2' >> ~/.bashrc
|
| 27 |
echo 'export PATH=$CUDA_HOME/bin:$PATH' >> ~/.bashrc
|
|
|
|
| 135 |
```
|
| 136 |
eval $(poetry env activate)
|
| 137 |
```
|
| 138 |
+
Next, make an output-folder for the predictions and run the prediction-script:
|
| 139 |
```
|
| 140 |
+
mkdir ~/data/ixi/outputs
|
| 141 |
+
python scripts/predict_from_fastsurfer_folder.py ~/data/ixi/preprocessed -d ~/data/ixi/outputs/predictions.csv
|
| 142 |
```
|
| 143 |
</details>
|
poetry.lock
CHANGED
|
@@ -2714,6 +2714,74 @@ files = [
|
|
| 2714 |
[package.extras]
|
| 2715 |
diagrams = ["jinja2", "railroad-diagrams"]
|
| 2716 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2717 |
[[package]]
|
| 2718 |
name = "pytest"
|
| 2719 |
version = "8.3.3"
|
|
@@ -4088,4 +4156,4 @@ test = ["pytest", "pytest-cov"]
|
|
| 4088 |
[metadata]
|
| 4089 |
lock-version = "2.1"
|
| 4090 |
python-versions = "3.10.4"
|
| 4091 |
-
content-hash = "
|
|
|
|
| 2714 |
[package.extras]
|
| 2715 |
diagrams = ["jinja2", "railroad-diagrams"]
|
| 2716 |
|
| 2717 |
+
[[package]]
|
| 2718 |
+
name = "pyqt5"
|
| 2719 |
+
version = "5.15.11"
|
| 2720 |
+
description = "Python bindings for the Qt cross platform application toolkit"
|
| 2721 |
+
optional = false
|
| 2722 |
+
python-versions = ">=3.8"
|
| 2723 |
+
groups = ["main"]
|
| 2724 |
+
files = [
|
| 2725 |
+
{file = "PyQt5-5.15.11-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c8b03dd9380bb13c804f0bdb0f4956067f281785b5e12303d529f0462f9afdc2"},
|
| 2726 |
+
{file = "PyQt5-5.15.11-cp38-abi3-macosx_11_0_x86_64.whl", hash = "sha256:6cd75628f6e732b1ffcfe709ab833a0716c0445d7aec8046a48d5843352becb6"},
|
| 2727 |
+
{file = "PyQt5-5.15.11-cp38-abi3-manylinux_2_17_x86_64.whl", hash = "sha256:cd672a6738d1ae33ef7d9efa8e6cb0a1525ecf53ec86da80a9e1b6ec38c8d0f1"},
|
| 2728 |
+
{file = "PyQt5-5.15.11-cp38-abi3-win32.whl", hash = "sha256:76be0322ceda5deecd1708a8d628e698089a1cea80d1a49d242a6d579a40babd"},
|
| 2729 |
+
{file = "PyQt5-5.15.11-cp38-abi3-win_amd64.whl", hash = "sha256:bdde598a3bb95022131a5c9ea62e0a96bd6fb28932cc1619fd7ba211531b7517"},
|
| 2730 |
+
{file = "PyQt5-5.15.11.tar.gz", hash = "sha256:fda45743ebb4a27b4b1a51c6d8ef455c4c1b5d610c90d2934c7802b5c1557c52"},
|
| 2731 |
+
]
|
| 2732 |
+
|
| 2733 |
+
[package.dependencies]
|
| 2734 |
+
PyQt5-Qt5 = ">=5.15.2,<5.16.0"
|
| 2735 |
+
PyQt5-sip = ">=12.15,<13"
|
| 2736 |
+
|
| 2737 |
+
[[package]]
|
| 2738 |
+
name = "pyqt5-qt5"
|
| 2739 |
+
version = "5.15.17"
|
| 2740 |
+
description = "The subset of a Qt installation needed by PyQt5."
|
| 2741 |
+
optional = false
|
| 2742 |
+
python-versions = "*"
|
| 2743 |
+
groups = ["main"]
|
| 2744 |
+
files = [
|
| 2745 |
+
{file = "PyQt5_Qt5-5.15.17-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:d8b8094108e748b4bbd315737cfed81291d2d228de43278f0b8bd7d2b808d2b9"},
|
| 2746 |
+
{file = "PyQt5_Qt5-5.15.17-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b68628f9b8261156f91d2f72ebc8dfb28697c4b83549245d9a68195bd2d74f0c"},
|
| 2747 |
+
{file = "PyQt5_Qt5-5.15.17-py3-none-manylinux2014_x86_64.whl", hash = "sha256:b018f75d1cc61146396fa5af14da1db77c5d6318030e5e366f09ffdf7bd358d8"},
|
| 2748 |
+
]
|
| 2749 |
+
|
| 2750 |
+
[[package]]
|
| 2751 |
+
name = "pyqt5-sip"
|
| 2752 |
+
version = "12.17.1"
|
| 2753 |
+
description = "The sip module support for PyQt5"
|
| 2754 |
+
optional = false
|
| 2755 |
+
python-versions = ">=3.9"
|
| 2756 |
+
groups = ["main"]
|
| 2757 |
+
files = [
|
| 2758 |
+
{file = "pyqt5_sip-12.17.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bd4f73b1ebd5e0bd8d4539a8e55132318efc70a92f648ef0f9d93329ad50adeb"},
|
| 2759 |
+
{file = "pyqt5_sip-12.17.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b52e85520dbfe5c3d0c0c47aa2c10fc1853d892ae60ebebfe8154b052394da50"},
|
| 2760 |
+
{file = "pyqt5_sip-12.17.1-cp310-cp310-win32.whl", hash = "sha256:71a67e2c9b77a74e943e220db0a341c702fd9bcf83c4a2e07342dfce691742ae"},
|
| 2761 |
+
{file = "pyqt5_sip-12.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:2710effb921bf6955b902779c763d890bb593da6325f0e128a0e3991cc855e9f"},
|
| 2762 |
+
{file = "pyqt5_sip-12.17.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5134d637efadd108a70306bab55b3d7feaa951bf6b8162161a67ae847bea9130"},
|
| 2763 |
+
{file = "pyqt5_sip-12.17.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:155cf755266c8bf64428916e2ff720d5efa1aec003d4ccc40c003b147dbdac03"},
|
| 2764 |
+
{file = "pyqt5_sip-12.17.1-cp311-cp311-win32.whl", hash = "sha256:9dfa7fe4ac93b60004430699c4bf56fef842a356d64dfea7cbc6d580d0427d6d"},
|
| 2765 |
+
{file = "pyqt5_sip-12.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:2ddd214cf40119b86942a5da2da5a7345334955ab00026d8dcc56326b30e6d3c"},
|
| 2766 |
+
{file = "pyqt5_sip-12.17.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c362606de782d2d46374a38523632786f145c517ee62de246a6069e5f2c5f336"},
|
| 2767 |
+
{file = "pyqt5_sip-12.17.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:140cc582151456103ebb149fefc678f3cae803e7720733db51212af5219cd45c"},
|
| 2768 |
+
{file = "pyqt5_sip-12.17.1-cp312-cp312-win32.whl", hash = "sha256:9dc1f1525d4d42c080f6cfdfc70d78239f8f67b0a48ea0745497251d8d848b1d"},
|
| 2769 |
+
{file = "pyqt5_sip-12.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:d5e2e9e175559017cd161d661e0ee0b551684f824bb90800c5a8c8a3bea9355e"},
|
| 2770 |
+
{file = "pyqt5_sip-12.17.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:9ebbd7769ccdaaa6295e9c872553b6cde17f38e171056f17300d8af9a14d1fc8"},
|
| 2771 |
+
{file = "pyqt5_sip-12.17.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b023da906a70af2cf5e6fc1932f441ede07530f3e164dd52c6c2bb5ab7c6f424"},
|
| 2772 |
+
{file = "pyqt5_sip-12.17.1-cp313-cp313-win32.whl", hash = "sha256:36dbef482bd638786b909f3bda65b7b3d5cbd6cbf16797496de38bae542da307"},
|
| 2773 |
+
{file = "pyqt5_sip-12.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:d04e5551bbc3bcec98acc63b3b0618ddcbf31ff107349225b516fe7e7c0a7c8b"},
|
| 2774 |
+
{file = "pyqt5_sip-12.17.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:c49918287e1ad77956d1589f1d3d432a0be7630c646ea02cf652413a48e14458"},
|
| 2775 |
+
{file = "pyqt5_sip-12.17.1-cp314-cp314-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:944a4bf1e1ee18ad03a54964c1c6433fb6de582313a1f0b17673e7203e22fc83"},
|
| 2776 |
+
{file = "pyqt5_sip-12.17.1-cp314-cp314-win32.whl", hash = "sha256:99a2935fd662a67748625b1e6ffa0a2d1f2da068b9df6db04fa59a4a5d4ee613"},
|
| 2777 |
+
{file = "pyqt5_sip-12.17.1-cp314-cp314-win_amd64.whl", hash = "sha256:aaa33232cc80793d14fdb3b149b27eec0855612ed66aad480add5ac49b9cee63"},
|
| 2778 |
+
{file = "pyqt5_sip-12.17.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6fdc457bd528e5909a5893db0a7dee0066d5f22e08234c9152db0ae6df9a367f"},
|
| 2779 |
+
{file = "pyqt5_sip-12.17.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:06ea59741c1bffb198d99b00d26594791f45fb11b10f774c8105aea5962e3835"},
|
| 2780 |
+
{file = "pyqt5_sip-12.17.1-cp39-cp39-win32.whl", hash = "sha256:b9ef23869d35c6740a95fcb1f387f4aea8d8fac80e19096fbaf1a64e18409c4b"},
|
| 2781 |
+
{file = "pyqt5_sip-12.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:90eed15f19557dfab22e68c7763e3690053cc8dd30d93ade2523d1b5a04a87be"},
|
| 2782 |
+
{file = "pyqt5_sip-12.17.1.tar.gz", hash = "sha256:0eab72bcb628f1926bf5b9ac51259d4fa18e8b2a81d199071135458f7d087ea8"},
|
| 2783 |
+
]
|
| 2784 |
+
|
| 2785 |
[[package]]
|
| 2786 |
name = "pytest"
|
| 2787 |
version = "8.3.3"
|
|
|
|
| 4156 |
[metadata]
|
| 4157 |
lock-version = "2.1"
|
| 4158 |
python-versions = "3.10.4"
|
| 4159 |
+
content-hash = "f4ed40969ecb1dec94a4d0f7f994e17cea56a8e86b72ea95a021f5247942d8a1"
|
pyment/models/utils/ensure_weights.py
CHANGED
|
@@ -28,7 +28,8 @@ def _lookup_identifier(identifier: str, local_cache: str) -> str:
|
|
| 28 |
local_cache,
|
| 29 |
f'{identifier}.data-00000-of-00001'
|
| 30 |
),
|
| 31 |
-
description=f'Downloading {identifier} data'
|
|
|
|
| 32 |
)
|
| 33 |
download_file(
|
| 34 |
url=BASE_URL + '/' + IDENTIFIERS[identifier]['index'],
|
|
@@ -36,7 +37,8 @@ def _lookup_identifier(identifier: str, local_cache: str) -> str:
|
|
| 36 |
local_cache,
|
| 37 |
f'{identifier}.index'
|
| 38 |
),
|
| 39 |
-
description=f'Downloading {identifier} index'
|
|
|
|
| 40 |
)
|
| 41 |
|
| 42 |
return os.path.join(local_cache, identifier)
|
|
|
|
| 28 |
local_cache,
|
| 29 |
f'{identifier}.data-00000-of-00001'
|
| 30 |
),
|
| 31 |
+
description=f'Downloading {identifier} data',
|
| 32 |
+
decode_github=True
|
| 33 |
)
|
| 34 |
download_file(
|
| 35 |
url=BASE_URL + '/' + IDENTIFIERS[identifier]['index'],
|
|
|
|
| 37 |
local_cache,
|
| 38 |
f'{identifier}.index'
|
| 39 |
),
|
| 40 |
+
description=f'Downloading {identifier} index',
|
| 41 |
+
decode_github=True
|
| 42 |
)
|
| 43 |
|
| 44 |
return os.path.join(local_cache, identifier)
|
pyment/preprocessing/conform.py
CHANGED
|
@@ -3,14 +3,10 @@ import numpy as np
|
|
| 3 |
from typing import Tuple
|
| 4 |
|
| 5 |
|
| 6 |
-
logging.basicConfig(
|
| 7 |
-
format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
|
| 8 |
-
level=logging.DEBUG
|
| 9 |
-
)
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
def _pad_if_necessary(
|
| 13 |
-
image: np.ndarray,
|
| 14 |
target_shape: Tuple[int, int, int]
|
| 15 |
) -> np.ndarray:
|
| 16 |
pad = [(0, 0)] * 3
|
|
@@ -25,7 +21,7 @@ def _pad_if_necessary(
|
|
| 25 |
return np.pad(image, tuple(pad), mode='constant', constant_values=0)
|
| 26 |
|
| 27 |
def _crop_if_necessary(
|
| 28 |
-
image: np.ndarray,
|
| 29 |
target_shape: Tuple[int, int, int]
|
| 30 |
) -> np.ndarray:
|
| 31 |
nonzero = np.where(image != 0)
|
|
@@ -36,7 +32,7 @@ def _crop_if_necessary(
|
|
| 36 |
if image.shape[dim] > target_shape[dim]:
|
| 37 |
extrafluous = target_shape[dim] / 2
|
| 38 |
center = np.round(np.mean([
|
| 39 |
-
np.amin(nonzero[dim]),
|
| 40 |
np.amax(nonzero[dim])
|
| 41 |
]))
|
| 42 |
min_idx = int(center - extrafluous)
|
|
@@ -45,7 +41,7 @@ def _crop_if_necessary(
|
|
| 45 |
if min_idx < 0:
|
| 46 |
max_idx -= min_idx
|
| 47 |
min_idx = 0
|
| 48 |
-
|
| 49 |
if max_idx > image.shape[dim]:
|
| 50 |
diff = max_idx - image.shape[dim]
|
| 51 |
min_idx -= diff
|
|
@@ -58,7 +54,7 @@ def _crop_if_necessary(
|
|
| 58 |
return image
|
| 59 |
|
| 60 |
def _center_crop_or_pad(
|
| 61 |
-
image: np.ndarray,
|
| 62 |
target_shape: Tuple[int, int, int]
|
| 63 |
) -> np.ndarray:
|
| 64 |
image = _pad_if_necessary(image, target_shape)
|
|
@@ -67,22 +63,22 @@ def _center_crop_or_pad(
|
|
| 67 |
return image
|
| 68 |
|
| 69 |
def conform(
|
| 70 |
-
image: np.ndarray,
|
| 71 |
relative_normalization: bool = False
|
| 72 |
) -> np.ndarray:
|
| 73 |
-
"""Conforms an image to the expected format if necessary. The
|
| 74 |
expected format means an image of shape 224x192x224 with voxel
|
| 75 |
-
values spanning the range [0, 1]. If the image has a redundant
|
| 76 |
-
channel-dimension, this is removed. If the image is currently too
|
| 77 |
large along any dimension, a "central" crop is made by determining
|
| 78 |
-
the bound of the brain (e.g. non-zero voxels) and retaining
|
| 79 |
equivalent padding on each side. If the image is currently too small
|
| 80 |
-
along either axis, the image is zero-padded equally on each side.
|
| 81 |
If the voxel-values does not fall within the expected range, they
|
| 82 |
-
are normalized. If the relative_normalization-flag is set, the
|
| 83 |
values are normalized by dividing by the image max, otherwise they
|
| 84 |
-
are divided by 255. However, if the largest value is >255, this
|
| 85 |
-
indicates that the image has not been processed with FastSurfer,
|
| 86 |
and an error is raised.
|
| 87 |
|
| 88 |
Parameters
|
|
@@ -106,7 +102,7 @@ def conform(
|
|
| 106 |
if len(image.shape) == 4:
|
| 107 |
if image.shape[-1] != 1:
|
| 108 |
raise ValueError(f'Unable to handle multi-channel images')
|
| 109 |
-
|
| 110 |
image = image[...,0]
|
| 111 |
|
| 112 |
if image.shape != (224, 192, 224):
|
|
@@ -114,7 +110,7 @@ def conform(
|
|
| 114 |
|
| 115 |
logger.debug('Conformed image shape: %s', str(image.shape))
|
| 116 |
logger.debug(
|
| 117 |
-
'Original image voxel value range: %f-%f',
|
| 118 |
np.amin(image), np.amax(image)
|
| 119 |
)
|
| 120 |
|
|
@@ -124,8 +120,8 @@ def conform(
|
|
| 124 |
image *= 255.0
|
| 125 |
|
| 126 |
logger.debug(
|
| 127 |
-
'Conformed image voxel value range: %f-%f',
|
| 128 |
np.amin(image), np.amax(image)
|
| 129 |
)
|
| 130 |
|
| 131 |
-
return image
|
|
|
|
| 3 |
from typing import Tuple
|
| 4 |
|
| 5 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 6 |
logger = logging.getLogger(__name__)
|
| 7 |
|
| 8 |
def _pad_if_necessary(
|
| 9 |
+
image: np.ndarray,
|
| 10 |
target_shape: Tuple[int, int, int]
|
| 11 |
) -> np.ndarray:
|
| 12 |
pad = [(0, 0)] * 3
|
|
|
|
| 21 |
return np.pad(image, tuple(pad), mode='constant', constant_values=0)
|
| 22 |
|
| 23 |
def _crop_if_necessary(
|
| 24 |
+
image: np.ndarray,
|
| 25 |
target_shape: Tuple[int, int, int]
|
| 26 |
) -> np.ndarray:
|
| 27 |
nonzero = np.where(image != 0)
|
|
|
|
| 32 |
if image.shape[dim] > target_shape[dim]:
|
| 33 |
extrafluous = target_shape[dim] / 2
|
| 34 |
center = np.round(np.mean([
|
| 35 |
+
np.amin(nonzero[dim]),
|
| 36 |
np.amax(nonzero[dim])
|
| 37 |
]))
|
| 38 |
min_idx = int(center - extrafluous)
|
|
|
|
| 41 |
if min_idx < 0:
|
| 42 |
max_idx -= min_idx
|
| 43 |
min_idx = 0
|
| 44 |
+
|
| 45 |
if max_idx > image.shape[dim]:
|
| 46 |
diff = max_idx - image.shape[dim]
|
| 47 |
min_idx -= diff
|
|
|
|
| 54 |
return image
|
| 55 |
|
| 56 |
def _center_crop_or_pad(
|
| 57 |
+
image: np.ndarray,
|
| 58 |
target_shape: Tuple[int, int, int]
|
| 59 |
) -> np.ndarray:
|
| 60 |
image = _pad_if_necessary(image, target_shape)
|
|
|
|
| 63 |
return image
|
| 64 |
|
| 65 |
def conform(
|
| 66 |
+
image: np.ndarray,
|
| 67 |
relative_normalization: bool = False
|
| 68 |
) -> np.ndarray:
|
| 69 |
+
"""Conforms an image to the expected format if necessary. The
|
| 70 |
expected format means an image of shape 224x192x224 with voxel
|
| 71 |
+
values spanning the range [0, 1]. If the image has a redundant
|
| 72 |
+
channel-dimension, this is removed. If the image is currently too
|
| 73 |
large along any dimension, a "central" crop is made by determining
|
| 74 |
+
the bound of the brain (e.g. non-zero voxels) and retaining
|
| 75 |
equivalent padding on each side. If the image is currently too small
|
| 76 |
+
along either axis, the image is zero-padded equally on each side.
|
| 77 |
If the voxel-values does not fall within the expected range, they
|
| 78 |
+
are normalized. If the relative_normalization-flag is set, the
|
| 79 |
values are normalized by dividing by the image max, otherwise they
|
| 80 |
+
are divided by 255. However, if the largest value is >255, this
|
| 81 |
+
indicates that the image has not been processed with FastSurfer,
|
| 82 |
and an error is raised.
|
| 83 |
|
| 84 |
Parameters
|
|
|
|
| 102 |
if len(image.shape) == 4:
|
| 103 |
if image.shape[-1] != 1:
|
| 104 |
raise ValueError(f'Unable to handle multi-channel images')
|
| 105 |
+
|
| 106 |
image = image[...,0]
|
| 107 |
|
| 108 |
if image.shape != (224, 192, 224):
|
|
|
|
| 110 |
|
| 111 |
logger.debug('Conformed image shape: %s', str(image.shape))
|
| 112 |
logger.debug(
|
| 113 |
+
'Original image voxel value range: %f-%f',
|
| 114 |
np.amin(image), np.amax(image)
|
| 115 |
)
|
| 116 |
|
|
|
|
| 120 |
image *= 255.0
|
| 121 |
|
| 122 |
logger.debug(
|
| 123 |
+
'Conformed image voxel value range: %f-%f',
|
| 124 |
np.amin(image), np.amax(image)
|
| 125 |
)
|
| 126 |
|
| 127 |
+
return image
|
pyment/utils/download_file.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import math
|
| 2 |
import requests
|
| 3 |
from tqdm import tqdm
|
|
@@ -6,7 +8,8 @@ from tqdm import tqdm
|
|
| 6 |
def download_file(
|
| 7 |
url: str,
|
| 8 |
destination: str,
|
| 9 |
-
description: str = None
|
|
|
|
| 10 |
) -> str:
|
| 11 |
with requests.get(url, stream=True) as response:
|
| 12 |
response.raise_for_status()
|
|
@@ -28,3 +31,13 @@ def download_file(
|
|
| 28 |
with open(destination, 'wb') as f:
|
| 29 |
for chunk in progress_bar:
|
| 30 |
f.write(chunk)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import json
|
| 3 |
import math
|
| 4 |
import requests
|
| 5 |
from tqdm import tqdm
|
|
|
|
| 8 |
def download_file(
|
| 9 |
url: str,
|
| 10 |
destination: str,
|
| 11 |
+
description: str = None,
|
| 12 |
+
decode_github: bool = False
|
| 13 |
) -> str:
|
| 14 |
with requests.get(url, stream=True) as response:
|
| 15 |
response.raise_for_status()
|
|
|
|
| 31 |
with open(destination, 'wb') as f:
|
| 32 |
for chunk in progress_bar:
|
| 33 |
f.write(chunk)
|
| 34 |
+
|
| 35 |
+
if decode_github:
|
| 36 |
+
# Assumes a JSON file downloaded from GitHub
|
| 37 |
+
with open(destination, 'rb') as f:
|
| 38 |
+
data = json.load(f)
|
| 39 |
+
|
| 40 |
+
data = base64.b64decode(data['content'])
|
| 41 |
+
|
| 42 |
+
with open(destination, 'wb') as f:
|
| 43 |
+
f.write(data)
|
pyproject.toml
CHANGED
|
@@ -37,6 +37,7 @@ pytest = "8.3.3"
|
|
| 37 |
scikit-learn = "1.5.1"
|
| 38 |
xlrd = "2.0.1"
|
| 39 |
pydantic = "2.10"
|
|
|
|
| 40 |
|
| 41 |
[build-system]
|
| 42 |
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
|
|
|
| 37 |
scikit-learn = "1.5.1"
|
| 38 |
xlrd = "2.0.1"
|
| 39 |
pydantic = "2.10"
|
| 40 |
+
pyqt5 = "^5.15.11"
|
| 41 |
|
| 42 |
[build-system]
|
| 43 |
requires = ["poetry-core>=2.0.0,<3.0.0"]
|
scripts/predict_from_fastsurfer_folder.py
CHANGED
|
@@ -13,10 +13,6 @@ from pyment.models.sfcn import sfcn_factory
|
|
| 13 |
from pyment.preprocessing.conform import conform
|
| 14 |
|
| 15 |
|
| 16 |
-
logging.basicConfig(
|
| 17 |
-
format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
|
| 18 |
-
level=logging.DEBUG
|
| 19 |
-
)
|
| 20 |
logger = logging.getLogger(__name__)
|
| 21 |
|
| 22 |
def _parse_folder_name(name: str) -> Tuple[str, str, str]:
|
|
@@ -71,7 +67,10 @@ def predict_from_fastsurfer_folder(
|
|
| 71 |
logger.debug('Conforming image from %s', os.path.join(source, folder))
|
| 72 |
image = conform(image)
|
| 73 |
|
| 74 |
-
predictions = model.predict(
|
|
|
|
|
|
|
|
|
|
| 75 |
logger.debug('Predictions for %s: %s', folder, str(predictions))
|
| 76 |
|
| 77 |
results.append({
|
|
|
|
| 13 |
from pyment.preprocessing.conform import conform
|
| 14 |
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
logger = logging.getLogger(__name__)
|
| 17 |
|
| 18 |
def _parse_folder_name(name: str) -> Tuple[str, str, str]:
|
|
|
|
| 67 |
logger.debug('Conforming image from %s', os.path.join(source, folder))
|
| 68 |
image = conform(image)
|
| 69 |
|
| 70 |
+
predictions = model.predict(
|
| 71 |
+
np.expand_dims(image, axis=0),
|
| 72 |
+
verbose=0
|
| 73 |
+
)[0]
|
| 74 |
logger.debug('Predictions for %s: %s', folder, str(predictions))
|
| 75 |
|
| 76 |
results.append({
|
tutorials/evaluate_ixi_predictions.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import matplotlib
|
| 3 |
+
import matplotlib.pyplot as plt
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
import pandas as pd
|
| 7 |
+
|
| 8 |
+
from download_ixi import DEFAULT_DESTINATION
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def evaluate_ixi_predictions(
|
| 12 |
+
labels: str,
|
| 13 |
+
predictions: str
|
| 14 |
+
) -> None:
|
| 15 |
+
labels = pd.read_excel(labels)
|
| 16 |
+
print(labels.head())
|
| 17 |
+
predictions = pd.read_csv(predictions)
|
| 18 |
+
print(predictions.head())
|
| 19 |
+
predictions['IXI_ID'] = predictions['source'].apply(
|
| 20 |
+
lambda path: int(path.split('/')[-1][3:6])
|
| 21 |
+
)
|
| 22 |
+
predictions['age_prediction'] = predictions['age']
|
| 23 |
+
predictions = pd.merge(
|
| 24 |
+
predictions[['IXI_ID', 'age_prediction']],
|
| 25 |
+
labels[['IXI_ID', 'AGE']],
|
| 26 |
+
on='IXI_ID',
|
| 27 |
+
how='left'
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
mae = np.mean(np.abs(predictions['AGE'] - predictions['age_prediction']))
|
| 31 |
+
print(f'MAE: {mae}')
|
| 32 |
+
|
| 33 |
+
plt.scatter(predictions['AGE'], predictions['age_prediction'])
|
| 34 |
+
plt.xlabel('True age')
|
| 35 |
+
plt.ylabel('Predicted age')
|
| 36 |
+
plt.title('Age prediction')
|
| 37 |
+
plt.show()
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
if __name__ == '__main__':
|
| 41 |
+
parser = argparse.ArgumentParser(
|
| 42 |
+
'Evaluates predictions for the IXI dataset'
|
| 43 |
+
)
|
| 44 |
+
parser.add_argument(
|
| 45 |
+
'-l', '--labels',
|
| 46 |
+
required=False,
|
| 47 |
+
default=os.path.join(DEFAULT_DESTINATION, 'IXI.xls'),
|
| 48 |
+
help='Path to XLSX containing labels'
|
| 49 |
+
)
|
| 50 |
+
parser.add_argument(
|
| 51 |
+
'-p', '--predictions',
|
| 52 |
+
required=False,
|
| 53 |
+
default=os.path.join(
|
| 54 |
+
DEFAULT_DESTINATION,
|
| 55 |
+
'outputs',
|
| 56 |
+
'predictions.csv'
|
| 57 |
+
),
|
| 58 |
+
help='Path to CSV containing predictions'
|
| 59 |
+
)
|
| 60 |
+
|
| 61 |
+
args = parser.parse_args()
|
| 62 |
+
|
| 63 |
+
evaluate_ixi_predictions(
|
| 64 |
+
labels=args.labels,
|
| 65 |
+
predictions=args.predictions
|
| 66 |
+
)
|