estenhl committed · Commit 66269ec · Parent: 2a12e39

Working on setting up automatic weight downloading
README.md CHANGED
@@ -1,5 +1,34 @@
 # Installation
 
+## Configure system
+<details>
+<summary>Ubuntu</summary>
+
+First we need to download and install CUDA 11.2:
+```
+wget https://developer.download.nvidia.com/compute/cuda/11.2.2/local_installers/cuda_11.2.2_460.32.03_linux.run
+sudo sh cuda_11.2.2_460.32.03_linux.run --silent --toolkit --installpath=/usr/local/cuda-11.2
+```
+
+Next, cuDNN must be installed. Download a suitable .deb file from
+https://developer.nvidia.com/rdp/cudnn-archive, then install it:
+```
+sudo dpkg -i ~/Downloads/cudnn-local-repo-ubuntu2204-8.9.7.29_1.0-1_amd64.deb
+sudo cp /var/cudnn-local-repo-ubuntu2204-8.9.7.29/cudnn-local-*-keyring.gpg /usr/share/keyrings/
+sudo apt update
+sudo apt install libcudnn8 libcudnn8-dev
+sudo cp /usr/include/cudnn*.h /usr/local/cuda-11.2/include/
+sudo cp -P /usr/lib/x86_64-linux-gnu/libcudnn*.so* /usr/local/cuda-11.2/lib64/
+sudo ldconfig
+```
+Finally, we must configure the system paths by appending these lines to ~/.bashrc:
+```
+echo 'export CUDA_HOME=/usr/local/cuda-11.2' >> ~/.bashrc
+echo 'export PATH=$CUDA_HOME/bin:$PATH' >> ~/.bashrc
+echo 'export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$CUDA_HOME/extras/CUPTI/lib64' >> ~/.bashrc
+```
+
+</details>
 
 ## Install pyenv and Python
 
@@ -100,4 +129,14 @@ Finally, we can run the preprocessing script, pointing towards the python from t
 ```
 sh scripts/preprocess.sh --license <path-to-license> --python ~/venvs/fastsurfer/bin/python ~/data/ixi/images ~/data/ixi/preprocessed
 ```
+
+### Generate predictions
+After preprocessing, we can generate predictions for the IXI dataset using the scripts in the repository. First, ensure the virtual environment is loaded:
+```
+eval $(poetry env activate)
+```
+Next, run the prediction script:
+```
+python scripts/predict_from_fastsurfer_folder.py
+```
 </details>
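The prediction script invoked above expects the preprocessed folder to contain one subfolder per subject, each holding an `mri` directory with `orig.mgz` and `mask.mgz` (per the argparse help in `scripts/predict_from_fastsurfer_folder.py` below). A minimal sketch for verifying that layout before running the script; the root path is illustrative:
```
import os

# Minimal layout check for a FastSurfer output folder, based on the help
# text of the prediction script in this commit. The root path below is
# illustrative, not part of the repository.
root = os.path.expanduser('~/data/ixi/preprocessed')

for folder in sorted(os.listdir(root)):
    mri = os.path.join(root, folder, 'mri')
    for name in ('orig.mgz', 'mask.mgz'):
        if not os.path.isfile(os.path.join(mri, name)):
            print(f'{folder}: missing {name}')
```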
poetry.lock CHANGED
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 2.2.0 and should not be changed by hand.
+# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
 
 [[package]]
 name = "absl-py"
@@ -2807,6 +2807,13 @@ optional = false
 python-versions = ">=3.8"
 groups = ["main"]
 files = [
+    {file = "PyYAML-6.0.3-cp38-cp38-macosx_10_13_x86_64.whl", hash = "sha256:c2514fceb77bc5e7a2f7adfaa1feb2fb311607c9cb518dbc378688ec73d8292f"},
+    {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9c57bb8c96f6d1808c030b1687b9b5fb476abaa47f0db9c0101f5e9f394e97f4"},
+    {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_s390x.manylinux_2_17_s390x.manylinux_2_28_s390x.whl", hash = "sha256:efd7b85f94a6f21e4932043973a7ba2613b059c4a000551892ac9f1d11f5baf3"},
+    {file = "PyYAML-6.0.3-cp38-cp38-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:22ba7cfcad58ef3ecddc7ed1db3409af68d023b7f940da23c6c2a1890976eda6"},
+    {file = "PyYAML-6.0.3-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:6344df0d5755a2c9a276d4473ae6b90647e216ab4757f8426893b5dd2ac3f369"},
+    {file = "PyYAML-6.0.3-cp38-cp38-win32.whl", hash = "sha256:3ff07ec89bae51176c0549bc4c63aa6202991da2d9a6129d7aef7f1407d3f295"},
+    {file = "PyYAML-6.0.3-cp38-cp38-win_amd64.whl", hash = "sha256:5cf4e27da7e3fbed4d6c3d8e797387aaad68102272f8f9752883bc32d61cb87b"},
     {file = "pyyaml-6.0.3-cp310-cp310-macosx_10_13_x86_64.whl", hash = "sha256:214ed4befebe12df36bcc8bc2b64b396ca31be9304b8f59e25c11cf94a4c033b"},
     {file = "pyyaml-6.0.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:02ea2dfa234451bbb8772601d7b8e426c2bfa197136796224e50e35a78777956"},
     {file = "pyyaml-6.0.3-cp310-cp310-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:b30236e45cf30d2b8e7b3e85881719e98507abed1011bf463a8fa23e9c3e98a8"},
pyment/models/sfcn/sfcn.py CHANGED
@@ -79,7 +79,10 @@ class SFCN(Model):
         super().__init__(self.inputs, x)
 
         if weights:
+            print(weights)
             weights = ensure_weights(weights)
+            print(weights)
+            weights = 'checkpoints/pyment/sfcn-multi'
             status = self.load_weights(weights)
 
             print(weights)
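The hardcoded `weights = 'checkpoints/pyment/sfcn-multi'` line and the surrounding `print` calls are work-in-progress debugging; the intent of the change is that a keyword identifier is resolved (and, if needed, downloaded) by `ensure_weights` before being passed to `load_weights`. A sketch of that intended usage; the import path and any constructor arguments besides `weights` are assumptions based on the file layout in this commit:
```
# Sketch only: assumes the SFCN constructor accepts a `weights` keyword
# as shown in the hunk above.
from pyment.models.sfcn.sfcn import SFCN

# A keyword identifier such as 'multi-2025' should be resolved by
# ensure_weights, downloading the checkpoint files on first use.
model = SFCN(weights='multi-2025')
```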
pyment/models/utils/ensure_weights.py CHANGED
@@ -1,15 +1,60 @@
 import os
 
+from pyment.utils.download_file import download_file
 
-def ensure_weights(identifier: str) -> str:
-    """Takes either a path or an identifier for a valid weight
+
+IDENTIFIERS = {
+    'multi-2025': {
+        'data': 'f4054d701fa59971fb7000d38cf9e63a202bd66a',
+        'index': '9c208ca0bcc3969ceb281ba63a8cee4944a63157'
+    }
+}
+BASE_URL = 'https://api.github.com/repos/estenhl/pyment-public/git/blobs'
+
+def _lookup_identifier(identifier: str, local_cache: str) -> str:
+    if not (
+        os.path.isfile(
+            os.path.join(local_cache, f'{identifier}.index')
+        ) and os.path.isfile(
+            os.path.join(local_cache, f'{identifier}.data-00000-of-00001')
+        )
+    ):
+        if not os.path.isdir(local_cache):
+            os.makedirs(local_cache, exist_ok=True)
+
+        download_file(
+            url=BASE_URL + '/' + IDENTIFIERS[identifier]['data'],
+            destination=os.path.join(
+                local_cache,
+                f'{identifier}.data-00000-of-00001'
+            ),
+            description=f'Downloading {identifier} data'
+        )
+        download_file(
+            url=BASE_URL + '/' + IDENTIFIERS[identifier]['index'],
+            destination=os.path.join(
+                local_cache,
+                f'{identifier}.index'
+            ),
+            description=f'Downloading {identifier} index'
+        )
+
+    return os.path.join(local_cache, identifier)
+
+def ensure_weights(
+    identifier: str,
+    local_cache: str = os.path.join(
+        os.path.expanduser('~'), '.pyment', 'weights'
+    )
+) -> str:
+    """Takes either a path or an identifier for a valid weight
     configuration as an argument, and returns a path-prefix to files
     containing the weights. If necessary, the weights are downloaded.
 
     Parameters
     ----------
     identifier : str
-        Points to either a filename or a valid keyword identifiying a
+        Points to either a filename or a valid keyword identifying a
         weight file.
 
     Returns
@@ -21,20 +66,24 @@ def ensure_weights(identifier: str) -> str:
     ------
     KeyError
         If the identifier is not a valid identifier and there does not
-        exist either a single file <identifier> or files
+        exist either a single file <identifier> or files
         <identifier>.index and <identifier>.data-00000-of-00001 on the
         local file system.
     """
-    if not (
+    if (
         (
-            os.path.isfile(f'{identifier}.index') and
+            os.path.isfile(f'{identifier}.index') and
             os.path.isfile(f'{identifier}.data-00000-of-00001')
         ) or (
             os.path.isfile(identifier)
         )
    ):
+        return identifier
+    elif identifier in IDENTIFIERS:
+        return _lookup_identifier(identifier, local_cache)
+    else:
        raise NotImplementedError(
            f'Identifier-based lookups are not supported'
        )
 
-    return identifier
+    return identifier
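The contract described by the docstring is that `ensure_weights` returns a local path prefix such that `<prefix>.index` and `<prefix>.data-00000-of-00001` exist, which is the TensorFlow checkpoint naming convention the rest of the code relies on. (Note that the final `return identifier` after the `else` branch is unreachable, since every branch above it either returns or raises.) A minimal usage sketch:
```
import os

from pyment.models.utils.ensure_weights import ensure_weights

# For a known identifier, the checkpoint files are downloaded into the
# local cache (~/.pyment/weights by default) on first use; afterwards the
# cached copy is returned directly.
prefix = ensure_weights('multi-2025')

assert os.path.isfile(f'{prefix}.index')
assert os.path.isfile(f'{prefix}.data-00000-of-00001')
```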
pyment/utils/download_file.py ADDED
@@ -0,0 +1,30 @@
+import math
+import requests
+from tqdm import tqdm
+
+
+def download_file(
+    url: str,
+    destination: str,
+    description: str = None
+) -> str:
+    with requests.get(url, stream=True) as response:
+        response.raise_for_status()
+        total_size = int(response.headers.get('content-length', 0))
+
+        # 1 MB chunks
+        chunk_size = 1<<20
+
+        progress_bar = tqdm(
+            response.iter_content(chunk_size=chunk_size),
+            total=int(math.ceil(total_size / chunk_size)),
+            unit='mb',
+            unit_scale=True,
+            unit_divisor=1024,
+            desc=description
+        )
+        progress_bar.format_dict['rate'] = f'mb/s'
+
+        with open(destination, 'wb') as f:
+            for chunk in progress_bar:
+                f.write(chunk)
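A minimal usage sketch of the helper; the URL and destination below are illustrative. Note that `download_file` writes the response body to disk verbatim, so the URL must serve the raw bytes of the file:
```
from pyment.utils.download_file import download_file

# Illustrative URL and destination: download_file streams the body to disk
# in 1 MB chunks while reporting progress via tqdm.
download_file(
    url='https://example.com/checkpoints/weights.data-00000-of-00001',
    destination='/tmp/weights.data-00000-of-00001',
    description='Downloading example weights'
)
```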
scripts/predict_from_fastsurfer_folder.py CHANGED
@@ -28,7 +28,7 @@ def _parse_folder_name(name: str) -> Tuple[str, str, str]:
     return match.groups()
 
 def predict_from_fastsurfer_folder(
-    source: str,
+    source: str,
     weights: str,
     model_name: str = 'sfcn-multi',
     targets: List[str] = [
@@ -38,7 +38,7 @@
 ) -> pd.DataFrame:
     if destination is not None and os.path.isfile(destination):
         raise ValueError(f'Destination {destination} already exists')
-
+
     logger.info('Loading multi-task model with weights %s', weights)
 
     model_class = sfcn_factory(model_name)
@@ -62,7 +62,7 @@
         if not os.path.isfile(brainmask):
             logger.warning('No mask.mgz file for folder %s', folder)
             continue
-
+
         brainmask = nib.load(brainmask)
         brainmask = brainmask.get_fdata()
 
@@ -72,10 +72,8 @@
         image = conform(image)
 
         predictions = model.predict(np.expand_dims(image, axis=0))[0]
-        print(predictions.shape)
-        print(predictions)
         logger.debug('Predictions for %s: %s', folder, str(predictions))
-
+
         results.append({
             **{
                 'source': os.path.join(source, folder),
@@ -100,7 +98,7 @@
     )
 
     parser.add_argument(
-        'root',
+        'root',
         help=(
             'Path to FastSurfer folder. Should contain subfolders that have '
             'an \'mri\' subfolder that contains files orig.mgz and mask.mgz'
@@ -108,15 +106,16 @@
     )
     parser.add_argument(
         '-w', '--weights',
-        required=True,
+        required=False,
+        default='multi-2025',
         help=(
             'Weights to use. Should either point to a local file path, or a '
-            'known keyword. If a local file path <path> is used, there should '
+            'known identifier. If a local file path <path> is used, there should '
             'exist files named <path>.index and <path>.data-00000-of-00001'
         )
     )
     parser.add_argument(
-        '-m', '--model',
+        '-m', '--model',
         required=False,
         default='sfcn-multi',
         help=(
@@ -128,7 +127,7 @@
         required=False,
         nargs='+',
         default=[
-            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
+            'age', 'sex', 'handedness', 'bmi', 'fluid_intelligence',
             'neuroticism'
         ],
         help='Name to use for each of the prediction heads in the output CSV'
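With the new defaults, the CLI only requires the positional FastSurfer folder; the weights fall back to the `multi-2025` identifier. The function can also be called directly. A sketch, assuming `scripts/` is importable from the repository root and that the signature includes the `destination` keyword referenced in the body (only part of the signature is visible in this diff):
```
from scripts.predict_from_fastsurfer_folder import predict_from_fastsurfer_folder

# Illustrative path; the weights identifier is resolved by ensure_weights.
df = predict_from_fastsurfer_folder(
    source='/home/user/data/ixi/preprocessed',
    weights='multi-2025',
    model_name='sfcn-multi'
)
print(df.head())
```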
scripts/utils/upload_weights_to_github.py ADDED
@@ -0,0 +1,68 @@
+import argparse
+import base64
+import os
+import requests
+
+
+def upload_weights_to_github(filename: str, token: str, user: str, repo: str):
+    with open(filename, 'rb') as f:
+        bytes = f.read()
+
+    bytes = base64.b64encode(bytes).decode()
+
+    if os.path.isfile(token):
+        with open(token, 'r') as f:
+            token = f.read().strip()
+
+    headers = {
+        'Accept': 'application/vnd.github+json',
+        'Authorization': f'Bearer {token}',
+        'X-GitHub-Api-Version': '2022-11-28'
+    }
+    content = {
+        'content': bytes,
+        'encoding': 'base64'
+    }
+
+    url = f'https://api.github.com/repos/{user}/{repo}/git/blobs'
+
+    response = requests.post(url, json=content, headers=headers)
+
+    return response
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser('Uploads a weights-file to github')
+
+    parser.add_argument(
+        '-f', '--filename',
+        required=True,
+        help='Path to file containing weights'
+    )
+    parser.add_argument(
+        '-t', '--token',
+        required=True,
+        help='Token for the GitHub API'
+    )
+    parser.add_argument(
+        '-u', '--user',
+        required=False,
+        default='estenhl',
+        help='Owner of the github repo'
+    )
+    parser.add_argument(
+        '-r', '--repo',
+        required=False,
+        default='pyment-public',
+        help='Name of the github repo'
+    )
+
+    args = parser.parse_args()
+
+    response = upload_weights_to_github(
+        filename=args.filename,
+        token=args.token,
+        user=args.user,
+        repo=args.repo
+    )
+
+    print(response.text)
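For the round trip, `ensure_weights` later fetches these blobs from the same endpoint. One caveat worth flagging (an assumption to verify against the GitHub REST documentation): `GET /repos/{owner}/{repo}/git/blobs/{sha}` returns a base64-encoded JSON envelope by default, so a raw media type is needed for `download_file`, which writes the body verbatim. A retrieval sketch:
```
import requests

# <sha> is a placeholder for the blob SHA returned by the upload above.
url = 'https://api.github.com/repos/estenhl/pyment-public/git/blobs/<sha>'

# Request the raw blob bytes rather than the default base64 JSON envelope.
response = requests.get(url, headers={'Accept': 'application/vnd.github.raw+json'})
response.raise_for_status()

with open('weights.blob', 'wb') as f:
    f.write(response.content)
```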
tutorials/download_ixi.py CHANGED
@@ -5,34 +5,10 @@ import requests
 import tarfile
 from tqdm import tqdm
 
+from pyment.utils.download_file import download_file
 
-DEFAULT_DESTINATION = os.path.join(os.path.expanduser('~'), 'data', 'ixi')
-
-def download_file(
-    url: str,
-    destination: str,
-    description: str = None
-) -> str:
-    with requests.get(url, stream=True) as response:
-        response.raise_for_status()
-        total_size = int(response.headers.get('content-length', 0))
-
-        # 1 MB chunks
-        chunk_size = 1<<20
 
-        progress_bar = tqdm(
-            response.iter_content(chunk_size=chunk_size),
-            total=int(math.ceil(total_size / chunk_size)),
-            unit='mb',
-            unit_scale=True,
-            unit_divisor=1024,
-            desc=description
-        )
-        progress_bar.format_dict['rate'] = f'mb/s'
-
-        with open(destination, 'wb') as f:
-            for chunk in progress_bar:
-                f.write(chunk)
+DEFAULT_DESTINATION = os.path.join(os.path.expanduser('~'), 'data', 'ixi')
 
 def download_tar(tar_path: str) -> str:
     url = (
@@ -82,4 +58,4 @@
 
     args = parser.parse_args()
 
-    download_ixi(args.destination)
+    download_ixi(args.destination)