estenhl commited on
Commit
b42b662
·
1 Parent(s): 2d0f971

Finished manual tutorial

Browse files
README.md CHANGED
@@ -21,7 +21,7 @@ sudo cp /usr/include/cudnn*.h /usr/local/cuda-11.2/include/
21
  sudo cp -P /usr/lib/x86_64-linux-gnu/libcudnn*.so* /usr/local/cuda-11.2/lib64/
22
  sudo ldconfig
23
  ```
24
- Finally, we must configure the system paths (or add these lines to ~/.bashrc:
25
  ```
26
  echo 'export CUDA_HOME=/usr/local/cuda-11.2' >> ~/.bashrc
27
  echo 'export PATH=$CUDA_HOME/bin:$PATH' >> ~/.bashrc
@@ -135,8 +135,9 @@ After preprocessing, we can generate predictions for the IXI dataset using the s
135
  ```
136
  eval $(poetry env activate)
137
  ```
138
- Next, run the prediction-script:
139
  ```
140
- python scripts/predict_from_fastsurfer_folder.py
 
141
  ```
142
  </details>
 
21
  sudo cp -P /usr/lib/x86_64-linux-gnu/libcudnn*.so* /usr/local/cuda-11.2/lib64/
22
  sudo ldconfig
23
  ```
24
+ Finally, we must configure the system paths in ~/.bashrc:
25
  ```
26
  echo 'export CUDA_HOME=/usr/local/cuda-11.2' >> ~/.bashrc
27
  echo 'export PATH=$CUDA_HOME/bin:$PATH' >> ~/.bashrc
 
135
  ```
136
  eval $(poetry env activate)
137
  ```
138
+ Next, make an output folder for the predictions and run the prediction script:
139
  ```
140
+ mkdir ~/data/ixi/outputs
141
+ python scripts/predict_from_fastsurfer_folder.py ~/data/ixi/preprocessed -d ~/data/ixi/outputs/predictions.csv
142
  ```
143
  </details>
poetry.lock CHANGED
@@ -2714,6 +2714,74 @@ files = [
2714
  [package.extras]
2715
  diagrams = ["jinja2", "railroad-diagrams"]
2716
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2717
  [[package]]
2718
  name = "pytest"
2719
  version = "8.3.3"
@@ -4088,4 +4156,4 @@ test = ["pytest", "pytest-cov"]
4088
  [metadata]
4089
  lock-version = "2.1"
4090
  python-versions = "3.10.4"
4091
- content-hash = "5bcab6231b21344133a42a97b439b6c847b1ef803c281d2518a40bbfa7d9cef2"
 
2714
  [package.extras]
2715
  diagrams = ["jinja2", "railroad-diagrams"]
2716
 
2717
+ [[package]]
2718
+ name = "pyqt5"
2719
+ version = "5.15.11"
2720
+ description = "Python bindings for the Qt cross platform application toolkit"
2721
+ optional = false
2722
+ python-versions = ">=3.8"
2723
+ groups = ["main"]
2724
+ files = [
2725
+ {file = "PyQt5-5.15.11-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:c8b03dd9380bb13c804f0bdb0f4956067f281785b5e12303d529f0462f9afdc2"},
2726
+ {file = "PyQt5-5.15.11-cp38-abi3-macosx_11_0_x86_64.whl", hash = "sha256:6cd75628f6e732b1ffcfe709ab833a0716c0445d7aec8046a48d5843352becb6"},
2727
+ {file = "PyQt5-5.15.11-cp38-abi3-manylinux_2_17_x86_64.whl", hash = "sha256:cd672a6738d1ae33ef7d9efa8e6cb0a1525ecf53ec86da80a9e1b6ec38c8d0f1"},
2728
+ {file = "PyQt5-5.15.11-cp38-abi3-win32.whl", hash = "sha256:76be0322ceda5deecd1708a8d628e698089a1cea80d1a49d242a6d579a40babd"},
2729
+ {file = "PyQt5-5.15.11-cp38-abi3-win_amd64.whl", hash = "sha256:bdde598a3bb95022131a5c9ea62e0a96bd6fb28932cc1619fd7ba211531b7517"},
2730
+ {file = "PyQt5-5.15.11.tar.gz", hash = "sha256:fda45743ebb4a27b4b1a51c6d8ef455c4c1b5d610c90d2934c7802b5c1557c52"},
2731
+ ]
2732
+
2733
+ [package.dependencies]
2734
+ PyQt5-Qt5 = ">=5.15.2,<5.16.0"
2735
+ PyQt5-sip = ">=12.15,<13"
2736
+
2737
+ [[package]]
2738
+ name = "pyqt5-qt5"
2739
+ version = "5.15.17"
2740
+ description = "The subset of a Qt installation needed by PyQt5."
2741
+ optional = false
2742
+ python-versions = "*"
2743
+ groups = ["main"]
2744
+ files = [
2745
+ {file = "PyQt5_Qt5-5.15.17-py3-none-macosx_10_13_x86_64.whl", hash = "sha256:d8b8094108e748b4bbd315737cfed81291d2d228de43278f0b8bd7d2b808d2b9"},
2746
+ {file = "PyQt5_Qt5-5.15.17-py3-none-macosx_11_0_arm64.whl", hash = "sha256:b68628f9b8261156f91d2f72ebc8dfb28697c4b83549245d9a68195bd2d74f0c"},
2747
+ {file = "PyQt5_Qt5-5.15.17-py3-none-manylinux2014_x86_64.whl", hash = "sha256:b018f75d1cc61146396fa5af14da1db77c5d6318030e5e366f09ffdf7bd358d8"},
2748
+ ]
2749
+
2750
+ [[package]]
2751
+ name = "pyqt5-sip"
2752
+ version = "12.17.1"
2753
+ description = "The sip module support for PyQt5"
2754
+ optional = false
2755
+ python-versions = ">=3.9"
2756
+ groups = ["main"]
2757
+ files = [
2758
+ {file = "pyqt5_sip-12.17.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:bd4f73b1ebd5e0bd8d4539a8e55132318efc70a92f648ef0f9d93329ad50adeb"},
2759
+ {file = "pyqt5_sip-12.17.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b52e85520dbfe5c3d0c0c47aa2c10fc1853d892ae60ebebfe8154b052394da50"},
2760
+ {file = "pyqt5_sip-12.17.1-cp310-cp310-win32.whl", hash = "sha256:71a67e2c9b77a74e943e220db0a341c702fd9bcf83c4a2e07342dfce691742ae"},
2761
+ {file = "pyqt5_sip-12.17.1-cp310-cp310-win_amd64.whl", hash = "sha256:2710effb921bf6955b902779c763d890bb593da6325f0e128a0e3991cc855e9f"},
2762
+ {file = "pyqt5_sip-12.17.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:5134d637efadd108a70306bab55b3d7feaa951bf6b8162161a67ae847bea9130"},
2763
+ {file = "pyqt5_sip-12.17.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:155cf755266c8bf64428916e2ff720d5efa1aec003d4ccc40c003b147dbdac03"},
2764
+ {file = "pyqt5_sip-12.17.1-cp311-cp311-win32.whl", hash = "sha256:9dfa7fe4ac93b60004430699c4bf56fef842a356d64dfea7cbc6d580d0427d6d"},
2765
+ {file = "pyqt5_sip-12.17.1-cp311-cp311-win_amd64.whl", hash = "sha256:2ddd214cf40119b86942a5da2da5a7345334955ab00026d8dcc56326b30e6d3c"},
2766
+ {file = "pyqt5_sip-12.17.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:c362606de782d2d46374a38523632786f145c517ee62de246a6069e5f2c5f336"},
2767
+ {file = "pyqt5_sip-12.17.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:140cc582151456103ebb149fefc678f3cae803e7720733db51212af5219cd45c"},
2768
+ {file = "pyqt5_sip-12.17.1-cp312-cp312-win32.whl", hash = "sha256:9dc1f1525d4d42c080f6cfdfc70d78239f8f67b0a48ea0745497251d8d848b1d"},
2769
+ {file = "pyqt5_sip-12.17.1-cp312-cp312-win_amd64.whl", hash = "sha256:d5e2e9e175559017cd161d661e0ee0b551684f824bb90800c5a8c8a3bea9355e"},
2770
+ {file = "pyqt5_sip-12.17.1-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:9ebbd7769ccdaaa6295e9c872553b6cde17f38e171056f17300d8af9a14d1fc8"},
2771
+ {file = "pyqt5_sip-12.17.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:b023da906a70af2cf5e6fc1932f441ede07530f3e164dd52c6c2bb5ab7c6f424"},
2772
+ {file = "pyqt5_sip-12.17.1-cp313-cp313-win32.whl", hash = "sha256:36dbef482bd638786b909f3bda65b7b3d5cbd6cbf16797496de38bae542da307"},
2773
+ {file = "pyqt5_sip-12.17.1-cp313-cp313-win_amd64.whl", hash = "sha256:d04e5551bbc3bcec98acc63b3b0618ddcbf31ff107349225b516fe7e7c0a7c8b"},
2774
+ {file = "pyqt5_sip-12.17.1-cp314-cp314-macosx_10_15_universal2.whl", hash = "sha256:c49918287e1ad77956d1589f1d3d432a0be7630c646ea02cf652413a48e14458"},
2775
+ {file = "pyqt5_sip-12.17.1-cp314-cp314-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:944a4bf1e1ee18ad03a54964c1c6433fb6de582313a1f0b17673e7203e22fc83"},
2776
+ {file = "pyqt5_sip-12.17.1-cp314-cp314-win32.whl", hash = "sha256:99a2935fd662a67748625b1e6ffa0a2d1f2da068b9df6db04fa59a4a5d4ee613"},
2777
+ {file = "pyqt5_sip-12.17.1-cp314-cp314-win_amd64.whl", hash = "sha256:aaa33232cc80793d14fdb3b149b27eec0855612ed66aad480add5ac49b9cee63"},
2778
+ {file = "pyqt5_sip-12.17.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:6fdc457bd528e5909a5893db0a7dee0066d5f22e08234c9152db0ae6df9a367f"},
2779
+ {file = "pyqt5_sip-12.17.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.whl", hash = "sha256:06ea59741c1bffb198d99b00d26594791f45fb11b10f774c8105aea5962e3835"},
2780
+ {file = "pyqt5_sip-12.17.1-cp39-cp39-win32.whl", hash = "sha256:b9ef23869d35c6740a95fcb1f387f4aea8d8fac80e19096fbaf1a64e18409c4b"},
2781
+ {file = "pyqt5_sip-12.17.1-cp39-cp39-win_amd64.whl", hash = "sha256:90eed15f19557dfab22e68c7763e3690053cc8dd30d93ade2523d1b5a04a87be"},
2782
+ {file = "pyqt5_sip-12.17.1.tar.gz", hash = "sha256:0eab72bcb628f1926bf5b9ac51259d4fa18e8b2a81d199071135458f7d087ea8"},
2783
+ ]
2784
+
2785
  [[package]]
2786
  name = "pytest"
2787
  version = "8.3.3"
 
4156
  [metadata]
4157
  lock-version = "2.1"
4158
  python-versions = "3.10.4"
4159
+ content-hash = "f4ed40969ecb1dec94a4d0f7f994e17cea56a8e86b72ea95a021f5247942d8a1"
pyment/models/utils/ensure_weights.py CHANGED
@@ -28,7 +28,8 @@ def _lookup_identifier(identifier: str, local_cache: str) -> str:
28
  local_cache,
29
  f'{identifier}.data-00000-of-00001'
30
  ),
31
- description=f'Downloading {identifier} data'
 
32
  )
33
  download_file(
34
  url=BASE_URL + '/' + IDENTIFIERS[identifier]['index'],
@@ -36,7 +37,8 @@ def _lookup_identifier(identifier: str, local_cache: str) -> str:
36
  local_cache,
37
  f'{identifier}.index'
38
  ),
39
- description=f'Downloading {identifier} index'
 
40
  )
41
 
42
  return os.path.join(local_cache, identifier)
 
28
  local_cache,
29
  f'{identifier}.data-00000-of-00001'
30
  ),
31
+ description=f'Downloading {identifier} data',
32
+ decode_github=True
33
  )
34
  download_file(
35
  url=BASE_URL + '/' + IDENTIFIERS[identifier]['index'],
 
37
  local_cache,
38
  f'{identifier}.index'
39
  ),
40
+ description=f'Downloading {identifier} index',
41
+ decode_github=True
42
  )
43
 
44
  return os.path.join(local_cache, identifier)
pyment/preprocessing/conform.py CHANGED
@@ -3,14 +3,10 @@ import numpy as np
3
  from typing import Tuple
4
 
5
 
6
- logging.basicConfig(
7
- format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
8
- level=logging.DEBUG
9
- )
10
  logger = logging.getLogger(__name__)
11
 
12
  def _pad_if_necessary(
13
- image: np.ndarray,
14
  target_shape: Tuple[int, int, int]
15
  ) -> np.ndarray:
16
  pad = [(0, 0)] * 3
@@ -25,7 +21,7 @@ def _pad_if_necessary(
25
  return np.pad(image, tuple(pad), mode='constant', constant_values=0)
26
 
27
  def _crop_if_necessary(
28
- image: np.ndarray,
29
  target_shape: Tuple[int, int, int]
30
  ) -> np.ndarray:
31
  nonzero = np.where(image != 0)
@@ -36,7 +32,7 @@ def _crop_if_necessary(
36
  if image.shape[dim] > target_shape[dim]:
37
  extrafluous = target_shape[dim] / 2
38
  center = np.round(np.mean([
39
- np.amin(nonzero[dim]),
40
  np.amax(nonzero[dim])
41
  ]))
42
  min_idx = int(center - extrafluous)
@@ -45,7 +41,7 @@ def _crop_if_necessary(
45
  if min_idx < 0:
46
  max_idx -= min_idx
47
  min_idx = 0
48
-
49
  if max_idx > image.shape[dim]:
50
  diff = max_idx - image.shape[dim]
51
  min_idx -= diff
@@ -58,7 +54,7 @@ def _crop_if_necessary(
58
  return image
59
 
60
  def _center_crop_or_pad(
61
- image: np.ndarray,
62
  target_shape: Tuple[int, int, int]
63
  ) -> np.ndarray:
64
  image = _pad_if_necessary(image, target_shape)
@@ -67,22 +63,22 @@ def _center_crop_or_pad(
67
  return image
68
 
69
  def conform(
70
- image: np.ndarray,
71
  relative_normalization: bool = False
72
  ) -> np.ndarray:
73
- """Conforms an image to the expected format if necessary. The
74
  expected format means an image of shape 224x192x224 with voxel
75
- values spanning the range [0, 1]. If the image has a redundant
76
- channel-dimension, this is removed. If the image is currently too
77
  large along any dimension, a "central" crop is made by determining
78
- the bound of the brain (e.g. non-zero voxels) and retaining
79
  equivalent padding on each side. If the image is currently too small
80
- along either axis, the image is zero-padded equally on each side.
81
  If the voxel-values does not fall within the expected range, they
82
- are normalized. If the relative_normalization-flag is set, the
83
  values are normalized by dividing by the image max, otherwise they
84
- are divided by 255. However, if the largest value is >255, this
85
- indicates that the image has not been processed with FastSurfer,
86
  and an error is raised.
87
 
88
  Parameters
@@ -106,7 +102,7 @@ def conform(
106
  if len(image.shape) == 4:
107
  if image.shape[-1] != 1:
108
  raise ValueError(f'Unable to handle multi-channel images')
109
-
110
  image = image[...,0]
111
 
112
  if image.shape != (224, 192, 224):
@@ -114,7 +110,7 @@ def conform(
114
 
115
  logger.debug('Conformed image shape: %s', str(image.shape))
116
  logger.debug(
117
- 'Original image voxel value range: %f-%f',
118
  np.amin(image), np.amax(image)
119
  )
120
 
@@ -124,8 +120,8 @@ def conform(
124
  image *= 255.0
125
 
126
  logger.debug(
127
- 'Conformed image voxel value range: %f-%f',
128
  np.amin(image), np.amax(image)
129
  )
130
 
131
- return image
 
3
  from typing import Tuple
4
 
5
 
 
 
 
 
6
  logger = logging.getLogger(__name__)
7
 
8
  def _pad_if_necessary(
9
+ image: np.ndarray,
10
  target_shape: Tuple[int, int, int]
11
  ) -> np.ndarray:
12
  pad = [(0, 0)] * 3
 
21
  return np.pad(image, tuple(pad), mode='constant', constant_values=0)
22
 
23
  def _crop_if_necessary(
24
+ image: np.ndarray,
25
  target_shape: Tuple[int, int, int]
26
  ) -> np.ndarray:
27
  nonzero = np.where(image != 0)
 
32
  if image.shape[dim] > target_shape[dim]:
33
  extrafluous = target_shape[dim] / 2
34
  center = np.round(np.mean([
35
+ np.amin(nonzero[dim]),
36
  np.amax(nonzero[dim])
37
  ]))
38
  min_idx = int(center - extrafluous)
 
41
  if min_idx < 0:
42
  max_idx -= min_idx
43
  min_idx = 0
44
+
45
  if max_idx > image.shape[dim]:
46
  diff = max_idx - image.shape[dim]
47
  min_idx -= diff
 
54
  return image
55
 
56
  def _center_crop_or_pad(
57
+ image: np.ndarray,
58
  target_shape: Tuple[int, int, int]
59
  ) -> np.ndarray:
60
  image = _pad_if_necessary(image, target_shape)
 
63
  return image
64
 
65
  def conform(
66
+ image: np.ndarray,
67
  relative_normalization: bool = False
68
  ) -> np.ndarray:
69
+ """Conforms an image to the expected format if necessary. The
70
  expected format means an image of shape 224x192x224 with voxel
71
+ values spanning the range [0, 1]. If the image has a redundant
72
+ channel-dimension, this is removed. If the image is currently too
73
  large along any dimension, a "central" crop is made by determining
74
+ the bound of the brain (e.g. non-zero voxels) and retaining
75
  equivalent padding on each side. If the image is currently too small
76
+ along either axis, the image is zero-padded equally on each side.
77
  If the voxel-values does not fall within the expected range, they
78
+ are normalized. If the relative_normalization-flag is set, the
79
  values are normalized by dividing by the image max, otherwise they
80
+ are divided by 255. However, if the largest value is >255, this
81
+ indicates that the image has not been processed with FastSurfer,
82
  and an error is raised.
83
 
84
  Parameters
 
102
  if len(image.shape) == 4:
103
  if image.shape[-1] != 1:
104
  raise ValueError(f'Unable to handle multi-channel images')
105
+
106
  image = image[...,0]
107
 
108
  if image.shape != (224, 192, 224):
 
110
 
111
  logger.debug('Conformed image shape: %s', str(image.shape))
112
  logger.debug(
113
+ 'Original image voxel value range: %f-%f',
114
  np.amin(image), np.amax(image)
115
  )
116
 
 
120
  image *= 255.0
121
 
122
  logger.debug(
123
+ 'Conformed image voxel value range: %f-%f',
124
  np.amin(image), np.amax(image)
125
  )
126
 
127
+ return image
pyment/utils/download_file.py CHANGED
@@ -1,3 +1,5 @@
 
 
1
  import math
2
  import requests
3
  from tqdm import tqdm
@@ -6,7 +8,8 @@ from tqdm import tqdm
6
  def download_file(
7
  url: str,
8
  destination: str,
9
- description: str = None
 
10
  ) -> str:
11
  with requests.get(url, stream=True) as response:
12
  response.raise_for_status()
@@ -28,3 +31,13 @@ def download_file(
28
  with open(destination, 'wb') as f:
29
  for chunk in progress_bar:
30
  f.write(chunk)
 
 
 
 
 
 
 
 
 
 
 
1
+ import base64
2
+ import json
3
  import math
4
  import requests
5
  from tqdm import tqdm
 
8
  def download_file(
9
  url: str,
10
  destination: str,
11
+ description: str = None,
12
+ decode_github: bool = False
13
  ) -> str:
14
  with requests.get(url, stream=True) as response:
15
  response.raise_for_status()
 
31
  with open(destination, 'wb') as f:
32
  for chunk in progress_bar:
33
  f.write(chunk)
34
+
35
+ if decode_github:
36
+ # Assumes a JSON file downloaded from GitHub
37
+ with open(destination, 'rb') as f:
38
+ data = json.load(f)
39
+
40
+ data = base64.b64decode(data['content'])
41
+
42
+ with open(destination, 'wb') as f:
43
+ f.write(data)
pyproject.toml CHANGED
@@ -37,6 +37,7 @@ pytest = "8.3.3"
37
  scikit-learn = "1.5.1"
38
  xlrd = "2.0.1"
39
  pydantic = "2.10"
 
40
 
41
  [build-system]
42
  requires = ["poetry-core>=2.0.0,<3.0.0"]
 
37
  scikit-learn = "1.5.1"
38
  xlrd = "2.0.1"
39
  pydantic = "2.10"
40
+ pyqt5 = "^5.15.11"
41
 
42
  [build-system]
43
  requires = ["poetry-core>=2.0.0,<3.0.0"]
scripts/predict_from_fastsurfer_folder.py CHANGED
@@ -13,10 +13,6 @@ from pyment.models.sfcn import sfcn_factory
13
  from pyment.preprocessing.conform import conform
14
 
15
 
16
- logging.basicConfig(
17
- format='%(asctime)s - %(levelname)s - %(name)s: %(message)s',
18
- level=logging.DEBUG
19
- )
20
  logger = logging.getLogger(__name__)
21
 
22
  def _parse_folder_name(name: str) -> Tuple[str, str, str]:
@@ -71,7 +67,10 @@ def predict_from_fastsurfer_folder(
71
  logger.debug('Conforming image from %s', os.path.join(source, folder))
72
  image = conform(image)
73
 
74
- predictions = model.predict(np.expand_dims(image, axis=0))[0]
 
 
 
75
  logger.debug('Predictions for %s: %s', folder, str(predictions))
76
 
77
  results.append({
 
13
  from pyment.preprocessing.conform import conform
14
 
15
 
 
 
 
 
16
  logger = logging.getLogger(__name__)
17
 
18
  def _parse_folder_name(name: str) -> Tuple[str, str, str]:
 
67
  logger.debug('Conforming image from %s', os.path.join(source, folder))
68
  image = conform(image)
69
 
70
+ predictions = model.predict(
71
+ np.expand_dims(image, axis=0),
72
+ verbose=0
73
+ )[0]
74
  logger.debug('Predictions for %s: %s', folder, str(predictions))
75
 
76
  results.append({
tutorials/evaluate_ixi_predictions.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import matplotlib
3
+ import matplotlib.pyplot as plt
4
+ import numpy as np
5
+ import os
6
+ import pandas as pd
7
+
8
+ from download_ixi import DEFAULT_DESTINATION
9
+
10
+
11
+ def evaluate_ixi_predictions(
12
+ labels: str,
13
+ predictions: str
14
+ ) -> None:
15
+ labels = pd.read_excel(labels)
16
+ print(labels.head())
17
+ predictions = pd.read_csv(predictions)
18
+ print(predictions.head())
19
+ predictions['IXI_ID'] = predictions['source'].apply(
20
+ lambda path: int(path.split('/')[-1][3:6])
21
+ )
22
+ predictions['age_prediction'] = predictions['age']
23
+ predictions = pd.merge(
24
+ predictions[['IXI_ID', 'age_prediction']],
25
+ labels[['IXI_ID', 'AGE']],
26
+ on='IXI_ID',
27
+ how='left'
28
+ )
29
+
30
+ mae = np.mean(np.abs(predictions['AGE'] - predictions['age_prediction']))
31
+ print(f'MAE: {mae}')
32
+
33
+ plt.scatter(predictions['AGE'], predictions['age_prediction'])
34
+ plt.xlabel('True age')
35
+ plt.ylabel('Predicted age')
36
+ plt.title('Age prediction')
37
+ plt.show()
38
+
39
+
40
+ if __name__ == '__main__':
41
+ parser = argparse.ArgumentParser(
42
+ 'Evaluates predictions for the IXI dataset'
43
+ )
44
+ parser.add_argument(
45
+ '-l', '--labels',
46
+ required=False,
47
+ default=os.path.join(DEFAULT_DESTINATION, 'IXI.xls'),
48
+ help='Path to XLSX containing labels'
49
+ )
50
+ parser.add_argument(
51
+ '-p', '--predictions',
52
+ required=False,
53
+ default=os.path.join(
54
+ DEFAULT_DESTINATION,
55
+ 'outputs',
56
+ 'predictions.csv'
57
+ ),
58
+ help='Path to CSV containing predictions'
59
+ )
60
+
61
+ args = parser.parse_args()
62
+
63
+ evaluate_ixi_predictions(
64
+ labels=args.labels,
65
+ predictions=args.predictions
66
+ )