Esten Leonardsen commited on
Commit
feeef85
·
1 Parent(s): 55880f9

Small changes to BIDS scripts

Browse files
pyment/models/sfcn/__init__.py CHANGED
@@ -6,6 +6,8 @@ from .sfcn_reg import RegressionSFCN
6
  def sfcn_factory(model_type: str):
7
  if model_type in ['sfcn-reg', 'regression']:
8
  return RegressionSFCN
 
 
9
 
10
  raise ValueError(f'Unknown SFCN type {model_type}')
11
 
 
6
  def sfcn_factory(model_type: str):
7
  if model_type in ['sfcn-reg', 'regression']:
8
  return RegressionSFCN
9
+ elif model_type in ['sfcn-multi', 'multi']:
10
+ return MultiTaskSFCN
11
 
12
  raise ValueError(f'Unknown SFCN type {model_type}')
13
 
scripts/finetune_from_bids_folder.py CHANGED
@@ -162,10 +162,6 @@ def finetune_from_configuration(configuration: str):
162
  target=configuration.training.target
163
  )
164
 
165
- # strategy = tf.distribute.MirroredStrategy()
166
-
167
- # with strategy.scope():
168
-
169
  loss_cls = loss_factory(configuration.training.loss)
170
  loss = loss_cls()
171
 
 
162
  target=configuration.training.target
163
  )
164
 
 
 
 
 
165
  loss_cls = loss_factory(configuration.training.loss)
166
  loss = loss_cls()
167
 
scripts/predict_from_bids_folder.py CHANGED
@@ -20,7 +20,10 @@ logging.basicConfig(
20
  logger = logging.getLogger(__name__)
21
 
22
  def _extract_run(filename: str) -> str:
23
- match = re.fullmatch(r'.*_run-(.*)(?:_.*)?(?:\.nii(?:\.gz)?|\.mgz)', filename)
 
 
 
24
 
25
  if not match:
26
  logger.warning('Unable to extract run for filename %s', filename)
@@ -28,6 +31,17 @@ def _extract_run(filename: str) -> str:
28
 
29
  return match.groups()[0]
30
 
 
 
 
 
 
 
 
 
 
 
 
31
  def predict_from_bids_folder(
32
  source: str,
33
  weights: str,
@@ -61,6 +75,7 @@ def predict_from_bids_folder(
61
  for filename in os.listdir(anat_folder):
62
  path = os.path.join(anat_folder, filename)
63
  run = _extract_run(filename)
 
64
 
65
  logger.debug(f'Loading image {path}')
66
  image = nib.load(os.path.join(anat_folder, filename))
@@ -81,9 +96,10 @@ def predict_from_bids_folder(
81
  results.append({
82
  **{
83
  'source': path,
84
- 'subject': subject,
85
- 'session': session,
86
- 'run': run
 
87
  },
88
  **{targets[i]: predictions[i] for i in range(len(targets))}
89
  })
 
20
  logger = logging.getLogger(__name__)
21
 
22
  def _extract_run(filename: str) -> str:
23
+ match = re.fullmatch(
24
+ r'.*_run-([^_]+)(?:_.*)?(?:\.nii(?:\.gz)?|\.mgz)',
25
+ filename
26
+ )
27
 
28
  if not match:
29
  logger.warning('Unable to extract run for filename %s', filename)
 
31
 
32
  return match.groups()[0]
33
 
34
+ def _extract_modality(filename: str) -> str:
35
+ match = re.fullmatch(
36
+ r'.*_run-(?:[^_]+)(?:_(.*))?(?:\.nii(?:\.gz)?|\.mgz)',
37
+ filename
38
+ )
39
+
40
+ if not match:
41
+ logger.warning('Unable to extract modality for filename %s', filename)
42
+ return None
43
+
44
+ return match.groups()[0]
45
  def predict_from_bids_folder(
46
  source: str,
47
  weights: str,
 
75
  for filename in os.listdir(anat_folder):
76
  path = os.path.join(anat_folder, filename)
77
  run = _extract_run(filename)
78
+ modality = _extract_modality(filename)
79
 
80
  logger.debug(f'Loading image {path}')
81
  image = nib.load(os.path.join(anat_folder, filename))
 
96
  results.append({
97
  **{
98
  'source': path,
99
+ 'subject': subject.replace('sub-', ''),
100
+ 'session': session.replace('ses-', ''),
101
+ 'run': run,
102
+ 'modality': modality
103
  },
104
  **{targets[i]: predictions[i] for i in range(len(targets))}
105
  })