SreekarB committed · verified
Commit 9135a28 · 1 Parent(s): cb31997

Upload 9 files
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🧠
 colorFrom: blue
 colorTo: pink
 sdk: gradio
-sdk_version: 5.20.1
+sdk_version: 3.36.1
 app_file: app.py
 pinned: false
 ---
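
Editor's note: in a Hugging Face Space, the sdk_version field of the README front matter pins the Gradio version the Space is built with, so this change downgrades the runtime from Gradio 5.20.1 to 3.36.1.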
analysis.py CHANGED
@@ -1,16 +1,39 @@
 from scipy.stats import pearsonr
+import numpy as np
+import warnings
 
 def analyze_fc_patterns(latents, demographics):
     results = {}
-    for demo_name, demo_values in demographics.items():
-        if demo_name != 'sex':  # For continuous variables
-            correlations = []
-            p_values = []
-            for latent_dim in range(latents.shape[1]):
-                r, p = pearsonr(latents[:, latent_dim], demo_values)
-                correlations.append(r)
-                p_values.append(p)
-            results[demo_name] = {'correlations': correlations, 'p_values': p_values}
+
+    # Suppress scipy's ConstantInputWarning (a RuntimeWarning subclass)
+    with warnings.catch_warnings():
+        warnings.filterwarnings('ignore', category=RuntimeWarning)
+
+        for demo_name, demo_values in demographics.items():
+            # Check if the demographic is categorical or continuous
+            if demo_name not in ['sex', 'gender']:  # For continuous variables
+                correlations = []
+                p_values = []
+
+                for latent_dim in range(latents.shape[1]):
+                    # Check for constant values that would break the correlation
+                    if np.all(latents[:, latent_dim] == latents[0, latent_dim]) or np.all(demo_values == demo_values[0]):
+                        # If either array is constant, the correlation is undefined
+                        r, p = np.nan, np.nan
+                    else:
+                        try:
+                            # Convert to a float numpy array if not already
+                            demo_array = np.array(demo_values, dtype=float)
+                            # Calculate the correlation
+                            r, p = pearsonr(latents[:, latent_dim], demo_array)
+                        except (ValueError, TypeError) as e:
+                            print(f"Error calculating correlation for {demo_name}, dimension {latent_dim}: {e}")
+                            r, p = np.nan, np.nan
+
+                    correlations.append(r)
+                    p_values.append(p)
+
+                results[demo_name] = {'correlations': correlations, 'p_values': p_values}
 
     return results
 
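Editor's note: a minimal usage sketch for the updated analyze_fc_patterns (toy data; the shapes are assumptions inferred from the indexing above — latents is (n_subjects, n_latent_dims) and each demographics value is a length-n_subjects array):

    import numpy as np
    from analysis import analyze_fc_patterns

    rng = np.random.default_rng(0)
    latents = rng.normal(size=(20, 4))        # 20 subjects, 4 latent dimensions
    demographics = {
        'age': rng.uniform(40, 80, size=20),  # continuous -> correlated per dim
        'sex': np.array(['M', 'F'] * 10),     # categorical -> skipped
        'wab': np.ones(20),                   # constant -> NaN r and p
    }
    results = analyze_fc_patterns(latents, demographics)
    print(results['age']['p_values'])
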
data_preprocessing.py CHANGED
@@ -23,9 +23,40 @@ def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
     print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
 
     # For SreekarB/OSFData dataset, the data will be loaded from dataset features
-    if isinstance(dataset_or_niifiles, str) and dataset_or_niifiles == "SreekarB/OSFData":
-        print("Loading data from SreekarB/OSFData dataset")
-        dataset = load_dataset(dataset_or_niifiles, split="train")
+    if isinstance(dataset_or_niifiles, str):
+        dataset_name = dataset_or_niifiles
+        print(f"Loading data from dataset: {dataset_name}")
+        try:
+            # Try multiple approaches to load the dataset
+            approaches = [
+                lambda: load_dataset(dataset_name, split="train"),
+                lambda: load_dataset(dataset_name),  # Try without split
+                lambda: load_dataset(dataset_name, split="train", trust_remote_code=True),  # Try with trust_remote_code
+                (lambda: load_dataset(dataset_name.split("/")[-1], split="train")) if "/" in dataset_name else None
+            ]
+
+            dataset = None
+            last_error = None
+
+            for i, approach in enumerate(approaches):
+                if approach is None:
+                    continue
+
+                try:
+                    print(f"Attempt {i+1} to load dataset...")
+                    dataset = approach()
+                    print(f"Successfully loaded dataset with approach {i+1}!")
+                    break
+                except Exception as e:
+                    print(f"Attempt {i+1} failed: {e}")
+                    last_error = e
+
+            if dataset is None:
+                print(f"All attempts to load dataset failed. Last error: {last_error}")
+                raise ValueError(f"Could not load dataset {dataset_name}")
+        except Exception as e:
+            print(f"Error during dataset loading: {e}")
+            raise
 
         # Prepare demographics data from the dataset
         if demo_data is None:
@@ -50,93 +81,369 @@ def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
         print("Searching for NIfTI files in dataset columns...")
         nii_files = []
 
-        # First check if there are any columns with .nii files
-        for col in dataset.column_names:
-            # Check if column contains file paths
-            first_val = dataset[col][0] if len(dataset) > 0 else None
-            if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
-                print(f"Found column '{col}' with NIfTI file paths")
-
-                # Try to download files from HuggingFace Hub
-                from huggingface_hub import hf_hub_download
-                import tempfile
-
-                for item in dataset[col]:
-                    try:
-                        # Download the NIfTI file
-                        file_path = hf_hub_download(
-                            repo_id="SreekarB/OSFData",
-                            filename=item,
-                            repo_type="dataset"
-                        )
-                        nii_files.append(file_path)
-                        print(f"Downloaded {item} from HuggingFace Hub")
-                    except Exception as e:
-                        print(f"Error downloading {item}: {e}")
-                        # Try looking for the file locally
-                        local_path = os.path.join(os.getcwd(), item)
-                        if os.path.exists(local_path):
-                            nii_files.append(local_path)
-                            print(f"Found {item} locally")
-                        else:
-                            print(f"Warning: Could not find {item} locally or on HuggingFace")
-
-                # If we found NIfTI files, process them to FC matrices
-                if nii_files:
-                    print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
-
-                    # Load Power 264 atlas
-                    from nilearn import datasets
-                    power = datasets.fetch_coords_power_2011()
-                    coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
-
-                    masker = input_data.NiftiSpheresMasker(
-                        coords, radius=5,
-                        standardize=True,
-                        memory='nilearn_cache', memory_level=1,
-                        verbose=0,
-                        detrend=True,
-                        low_pass=0.1,
-                        high_pass=0.01,
-                        t_r=2.0  # Adjust TR according to your data
-                    )
-
-                    # Process fMRI data and compute FC matrices
-                    fc_matrices = []
-                    for nii_file in nii_files:
-                        try:
-                            fmri_img = load_img(nii_file)
-                            time_series = masker.fit_transform(fmri_img)
-
-                            correlation_measure = connectome.ConnectivityMeasure(
-                                kind='correlation',
-                                vectorize=False,
-                                discard_diagonal=False
-                            )
-
-                            fc_matrix = correlation_measure.fit_transform([time_series])[0]
-
-                            triu_indices = np.triu_indices_from(fc_matrix, k=1)
-                            fc_triu = fc_matrix[triu_indices]
-
-                            fc_triu = np.arctanh(fc_triu)  # Fisher z-transform
-
-                            fc_matrices.append(fc_triu)
-                            print(f"Processed {nii_file} to FC matrix")
-                        except Exception as e:
-                            print(f"Error processing {nii_file}: {e}")
-
-                    if fc_matrices:
-                        X = np.array(fc_matrices)
-                        # Normalize the FC data
-                        X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
-                        print(f"Created FC matrices with shape {X.shape}")
-                        return X, demo_data, demo_types
-                break  # Stop after finding one column with NIfTI files
-
-        # If we're here, we couldn't find or process the NIfTI files
-        print("Could not find or process NIfTI files from the dataset.")
+        # Create a temp directory for downloads
+        import tempfile
+        from huggingface_hub import hf_hub_download
+        import shutil
+
+        temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
+        print(f"Created temporary directory for NIfTI files: {temp_dir}")
+
+        try:
+            # First approach: Check if there are any columns containing file paths
+            nii_columns = []
+            for col in dataset.column_names:
+                # Check if column name suggests NIfTI files
+                if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
+                    nii_columns.append(col)
+                # Or check if column contains file paths
+                elif len(dataset) > 0:
+                    first_val = dataset[0][col]
+                    if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
+                        nii_columns.append(col)
+
+            if nii_columns:
+                print(f"Found columns that may contain NIfTI files: {nii_columns}")
+
+                for col in nii_columns:
+                    print(f"Processing column '{col}'...")
+
+                    for i, item in enumerate(dataset[col]):
+                        if not isinstance(item, str):
+                            print(f"Item {i} in column {col} is not a string but {type(item)}")
+                            continue
+
+                        if not (item.endswith('.nii') or item.endswith('.nii.gz')):
+                            print(f"Item {i} in column {col} is not a NIfTI file: {item}")
+                            continue
+
+                        print(f"Downloading {item} from dataset {dataset_name}...")
+
+                        try:
+                            # Attempt to download with explicit filename
+                            file_path = hf_hub_download(
+                                repo_id=dataset_name,
+                                filename=item,
+                                repo_type="dataset",
+                                cache_dir=temp_dir
+                            )
+                            nii_files.append(file_path)
+                            print(f"✓ Successfully downloaded {item}")
+                        except Exception as e1:
+                            print(f"Error downloading with explicit filename: {e1}")
+
+                            # Second attempt: try with the item's basename
+                            try:
+                                basename = os.path.basename(item)
+                                print(f"Trying with basename: {basename}")
+                                file_path = hf_hub_download(
+                                    repo_id=dataset_name,
+                                    filename=basename,
+                                    repo_type="dataset",
+                                    cache_dir=temp_dir
+                                )
+                                nii_files.append(file_path)
+                                print(f"✓ Successfully downloaded {basename}")
+                            except Exception as e2:
+                                print(f"Error downloading with basename: {e2}")
+
+                                # Third attempt: check if it's a binary blob in the dataset
+                                try:
+                                    if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
+                                        print("Found binary data in dataset, saving to temporary file...")
+                                        binary_data = dataset[i]['bytes']
+                                        temp_file = os.path.join(temp_dir, basename)
+                                        with open(temp_file, 'wb') as f:
+                                            f.write(binary_data)
+                                        nii_files.append(temp_file)
+                                        print(f"✓ Saved binary data to {temp_file}")
+                                except Exception as e3:
+                                    print(f"Error handling binary data: {e3}")
+
+                                # Last resort: look for the file locally
+                                local_path = os.path.join(os.getcwd(), item)
+                                if os.path.exists(local_path):
+                                    nii_files.append(local_path)
+                                    print(f"✓ Found {item} locally")
+                                else:
+                                    print(f"❌ Warning: Could not find {item} anywhere")
+
+            # Second approach: Try to find NIfTI files in dataset repository directly
+            if not nii_files:
+                print("No NIfTI files found in dataset columns. Trying direct repository search...")
+
+                try:
+                    from huggingface_hub import list_repo_files, hf_hub_download
+
+                    # Try to list all files in the repository
+                    try:
+                        print("Listing all repository files...")
+                        all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
+                        print(f"Found {len(all_repo_files)} files in repository")
+
+                        # First prioritize P*_rs.nii files
+                        p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
+
+                        # Then include all other NIfTI files
+                        other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
+
+                        # Combine, with P*_rs.nii files first
+                        nii_repo_files = p_rs_files + other_nii_files
+
+                        if nii_repo_files:
+                            print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
+
+                            # Download each file
+                            for nii_file in nii_repo_files:
+                                try:
+                                    file_path = hf_hub_download(
+                                        repo_id=dataset_name,
+                                        filename=nii_file,
+                                        repo_type="dataset",
+                                        cache_dir=temp_dir
+                                    )
+                                    nii_files.append(file_path)
+                                    print(f"✓ Downloaded {nii_file}")
+                                except Exception as e:
+                                    print(f"Error downloading {nii_file}: {e}")
+                    except Exception as e:
+                        print(f"Error listing repository files: {e}")
+                        print("Will try alternative approaches...")
+
+                    # If repo listing fails, try with common NIfTI file patterns directly
+                    if not nii_files:
+                        print("Trying common NIfTI file patterns...")
+
+                        # Focus specifically on P*_rs.nii pattern
+                        patterns = []
+
+                        # Generate P01_rs.nii through P30_rs.nii
+                        for i in range(1, 31):  # Try subjects 1-30
+                            patterns.append(f"P{i:02d}_rs.nii")
+
+                        # Also try with .nii.gz extension
+                        for i in range(1, 31):
+                            patterns.append(f"P{i:02d}_rs.nii.gz")
+
+                        # Include a few other common patterns as fallbacks
+                        patterns.extend([
+                            "sub-01_task-rest_bold.nii.gz",  # BIDS format
+                            "fmri.nii.gz", "bold.nii.gz",
+                            "rest.nii.gz"
+                        ])
+
+                        for pattern in patterns:
+                            try:
+                                print(f"Trying to download {pattern}...")
+                                file_path = hf_hub_download(
+                                    repo_id=dataset_name,
+                                    filename=pattern,
+                                    repo_type="dataset",
+                                    cache_dir=temp_dir
+                                )
+                                nii_files.append(file_path)
+                                print(f"✓ Successfully downloaded {pattern}")
+                            except Exception as e:
+                                print(f"× Failed to download {pattern}")
+
+                        # If we still couldn't find any files, check if data files are nested
+                        if not nii_files:
+                            print("Checking for nested data files...")
+                            nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
+
+                            for path in nested_paths:
+                                for pattern in patterns:
+                                    nested_file = f"{path}{pattern}"
+                                    try:
+                                        print(f"Trying to download {nested_file}...")
+                                        file_path = hf_hub_download(
+                                            repo_id=dataset_name,
+                                            filename=nested_file,
+                                            repo_type="dataset",
+                                            cache_dir=temp_dir
+                                        )
+                                        nii_files.append(file_path)
+                                        print(f"✓ Successfully downloaded {nested_file}")
+                                        # If we found one file in this directory, try to find all files in it
+                                        try:
+                                            all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
+                                            nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
+                                            print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
+
+                                            for nii_file in nii_files_in_dir:
+                                                if nii_file != nested_file:  # Skip the one we already downloaded
+                                                    try:
+                                                        file_path = hf_hub_download(
+                                                            repo_id=dataset_name,
+                                                            filename=nii_file,
+                                                            repo_type="dataset",
+                                                            cache_dir=temp_dir
+                                                        )
+                                                        nii_files.append(file_path)
+                                                        print(f"✓ Downloaded {nii_file}")
+                                                    except Exception as e:
+                                                        print(f"Error downloading {nii_file}: {e}")
+                                        except Exception as e:
+                                            print(f"Error finding additional files in {path}: {e}")
+                                    except Exception as e:
+                                        pass
+
+                except Exception as e:
+                    print(f"Error during repository exploration: {e}")
+
+            # If we still don't have any files, try to search for P*_rs.nii pattern specifically
+            if not nii_files:
+                print("Trying to find files matching P*_rs.nii pattern specifically...")
+
+                try:
+                    # List all files in the repository (if we haven't already)
+                    if 'all_repo_files' not in locals():
+                        from huggingface_hub import list_repo_files
+                        try:
+                            all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
+                        except Exception as e:
+                            print(f"Error listing repo files: {e}")
+                            all_repo_files = []
+
+                    # Look for files matching the pattern exactly (P*_rs.nii)
+                    pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
+
+                    # If we don't find any exact matches, try a more relaxed pattern
+                    if not pattern_files:
+                        pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
+
+                    if pattern_files:
+                        print(f"Found {len(pattern_files)} files matching rs.nii pattern")
+
+                        # Download each file
+                        for pattern_file in pattern_files:
+                            try:
+                                file_path = hf_hub_download(
+                                    repo_id=dataset_name,
+                                    filename=pattern_file,
+                                    repo_type="dataset",
+                                    cache_dir=temp_dir
+                                )
+                                nii_files.append(file_path)
+                                print(f"✓ Downloaded {pattern_file}")
+                            except Exception as e:
+                                print(f"Error downloading {pattern_file}: {e}")
+                except Exception as e:
+                    print(f"Error searching for pattern files: {e}")
+
+            print(f"Found total of {len(nii_files)} NIfTI files")
+        except Exception as e:
+            print(f"Unexpected error during NIfTI file search: {e}")
+            import traceback
+            traceback.print_exc()
+
+        # If we found NIfTI files, process them to FC matrices
+        if nii_files:
+            print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
+
+            # Load Power 264 atlas
+            from nilearn import datasets
+            power = datasets.fetch_coords_power_2011()
+            coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
+
+            masker = input_data.NiftiSpheresMasker(
+                coords, radius=5,
+                standardize=True,
+                memory='nilearn_cache', memory_level=1,
+                verbose=0,
+                detrend=True,
+                low_pass=0.1,
+                high_pass=0.01,
+                t_r=2.0  # Adjust TR according to your data
+            )
+
+            # Process fMRI data and compute FC matrices
+            fc_matrices = []
+            valid_files = 0
+            total_files = len(nii_files)
+
+            for nii_file in nii_files:
+                try:
+                    print(f"Processing {nii_file}...")
+                    fmri_img = load_img(nii_file)
+
+                    # Check image dimensions
+                    if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
+                        print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
+                        continue
+
+                    time_series = masker.fit_transform(fmri_img)
+
+                    # Validate time series data
+                    if np.isnan(time_series).any() or np.isinf(time_series).any():
+                        print(f"Warning: {nii_file} contains NaN or Inf values after masking")
+                        # Replace NaNs with zeros for this file
+                        time_series = np.nan_to_num(time_series)
+
+                    correlation_measure = connectome.ConnectivityMeasure(
+                        kind='correlation',
+                        vectorize=False,
+                        discard_diagonal=False
+                    )
+
+                    fc_matrix = correlation_measure.fit_transform([time_series])[0]
+
+                    # Check for invalid correlation values
+                    if np.isnan(fc_matrix).any():
+                        print(f"Warning: {nii_file} produced NaN correlation values")
+                        continue
+
+                    triu_indices = np.triu_indices_from(fc_matrix, k=1)
+                    fc_triu = fc_matrix[triu_indices]
+
+                    # Fisher z-transform with proper bounds check:
+                    # clip correlation values to the valid range for arctanh
+                    fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
+                    fc_triu = np.arctanh(fc_triu_clipped)
+
+                    fc_matrices.append(fc_triu)
+                    valid_files += 1
+                    print(f"Successfully processed {nii_file} to FC matrix")
+
+                except Exception as e:
+                    print(f"Error processing {nii_file}: {e}")
+
+            if fc_matrices:
+                print(f"Successfully processed {valid_files} out of {total_files} files")
+
+                # Ensure all matrices have the same dimensions
+                dims = [m.shape[0] for m in fc_matrices]
+                if len(set(dims)) > 1:
+                    print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
+                    # Use the most common dimension
+                    from collections import Counter
+                    most_common_dim = Counter(dims).most_common(1)[0][0]
+                    print(f"Using most common dimension: {most_common_dim}")
+                    fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
+
+                X = np.array(fc_matrices)
+
+                # Normalize the FC data
+                mean_x = np.mean(X, axis=0)
+                std_x = np.std(X, axis=0)
+
+                # Handle zero standard deviation
+                std_x[std_x == 0] = 1.0
+
+                X = (X - mean_x) / std_x
+                print(f"Created FC matrices with shape {X.shape}")
+
+                # Make sure demo_data matches the number of FC matrices
+                if len(demo_data[0]) != X.shape[0]:
+                    print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
+                          f"doesn't match number of FC matrices ({X.shape[0]})")
+                    # Adjust demo_data to match FC matrices
+                    indices = list(range(min(len(demo_data[0]), X.shape[0])))
+                    X = X[indices]
+                    demo_data = [d[indices] for d in demo_data]
+
+                return X, demo_data, demo_types
+
         print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
         # Return a placeholder with the right demographics but empty FC
         n_subjects = len(dataset)
@@ -247,5 +554,4 @@ def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
     # Normalize the FC data
     X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
 
-    return X, demo_data, demo_types
-
+    return X, demo_data, demo_types
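
Editor's note: the fallback chain above can also be written as a loop over keyword-argument sets, which avoids building lambdas (and the precedence pitfall that a trailing "if ... else None" binds inside the lambda body rather than selecting the list element). A minimal sketch using only the public datasets.load_dataset API; the helper name is illustrative:

    from datasets import load_dataset

    def load_first_working(name):
        # Progressively more permissive load_dataset invocations.
        attempts = [
            dict(split="train"),
            dict(),                                      # let datasets resolve the splits
            dict(split="train", trust_remote_code=True),
        ]
        last_error = None
        for kwargs in attempts:
            try:
                return load_dataset(name, **kwargs)
            except Exception as e:                       # keep falling through
                last_error = e
        raise ValueError(f"Could not load dataset {name}") from last_error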
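Editor's note on the clip-before-arctanh step above: the Fisher z-transform is z = arctanh(r) = 0.5 * ln((1 + r) / (1 - r)), which diverges as r approaches ±1, and self- or perfectly collinear ROI pairs yield exactly r = ±1. Clipping to ±0.999 bounds |z| at roughly 3.8. A minimal demonstration:

    import numpy as np

    r = np.array([-1.0, 0.0, 0.9, 1.0])
    print(np.arctanh(r))                           # [-inf  0.  1.472  inf]
    print(np.arctanh(np.clip(r, -0.999, 0.999)))   # bounded, max |z| ≈ 3.8
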
main.py CHANGED
@@ -25,6 +25,23 @@ def train_fc_vae(X, demo_data, demo_types, model_config):
 
     print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
 
+    # Ensure X is a numpy array with correct data type
+    if not isinstance(X, np.ndarray):
+        print(f"Converting X from {type(X)} to numpy array")
+        X = np.array(X, dtype=np.float32)
+
+    # Ensure demo_data contains numpy arrays
+    for i, d in enumerate(demo_data):
+        if not isinstance(d, np.ndarray):
+            print(f"Converting demographic {i} from {type(d)} to numpy array")
+            demo_data[i] = np.array(d)
+
+    # Check for NaN or Inf values
+    if np.isnan(X).any() or np.isinf(X).any():
+        print("Warning: X contains NaN or Inf values. Replacing with zeros.")
+        X = np.nan_to_num(X)
+
+    # Create the VAE model
     vae = DemoVAE(
         latent_dim=model_config['latent_dim'],
         nepochs=model_config['nepochs'],
@@ -194,16 +211,32 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
     # Generate new FC matrix
     print("Generating new FC matrices...")
 
+    # Get data types from original demographic data for proper conversion
+    demo_dtypes = [type(d[0]) if len(d) > 0 else float for d in demo_data]
+
     # Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
     new_demographics = [
-        np.array([60.0]),   # age
-        np.array(['M']),    # gender
-        np.array([12.0]),   # months post onset
-        np.array([80.0])    # wab score
+        np.array([60.0], dtype=np.float64),  # age
+        np.array(['M'], dtype=np.str_),      # gender
+        np.array([12.0], dtype=np.float64),  # months post onset
+        np.array([80.0], dtype=np.float64)   # wab score
     ]
 
+    # Verify the demographic data arrays match the expected types
+    print("Demographic data types:")
+    for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
+        print(f"  {name}: shape={data.shape}, dtype={data.dtype}")
+
     print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
-    generated_fc = vae.transform(1, new_demographics, demo_types)
+    try:
+        generated_fc = vae.transform(1, new_demographics, demo_types)
+    except Exception as e:
+        print(f"Error generating new FC matrix: {e}")
+        # Try with a fallback approach
+        print("Trying alternative generation approach...")
+        # If the specific gender is causing issues, try the first gender from the training data
+        new_demographics[1] = np.array([demo_data[1][0]])
+        generated_fc = vae.transform(1, new_demographics, demo_types)
     reconstructed_fc = vae.transform(X, demo_data, demo_types)
 
     # Visualize results
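
Editor's note on the NaN/Inf guard in train_fc_vae above: with default arguments, np.nan_to_num replaces NaN with 0 but maps ±inf to the largest finite values of the dtype (about ±3.4e38 for float32), not to zero. If zeros are wanted for infinities as well, pass the limits explicitly:

    import numpy as np

    X = np.array([[1.0, np.nan], [np.inf, -np.inf]], dtype=np.float32)
    print(np.nan_to_num(X))                                   # inf -> ±3.4e38
    print(np.nan_to_num(X, nan=0.0, posinf=0.0, neginf=0.0))  # everything -> 0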
src/.DS_Store CHANGED
Binary files a/src/.DS_Store and b/src/.DS_Store differ
 
src/demovae/model.py CHANGED
@@ -14,7 +14,11 @@ def to_torch(x):
 
 def to_cuda(x, use_cuda):
     if use_cuda:
-        return x.cuda()
+        try:
+            return x.cuda()
+        except (RuntimeError, AssertionError) as e:
+            print(f"Warning: CUDA error: {e}. Falling back to CPU.")
+            return x
     else:
         return x
 
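Editor's note: the try/except above recovers when the CUDA runtime is broken at call time; a common complementary pattern is to gate use_cuda on torch.cuda.is_available() so the exception path is rarely taken. A sketch of that variant:

    import torch

    def to_cuda(x, use_cuda):
        # Move to the GPU only when CUDA is actually usable; otherwise stay on CPU.
        if use_cuda and torch.cuda.is_available():
            return x.cuda()
        return x
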
src/demovae/sklearn.py CHANGED
@@ -1,5 +1,6 @@
 
-from demovae.model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
+from .model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
+import numpy as np
 
 from sklearn.base import BaseEstimator
 
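Editor's note: switching from "from demovae.model import ..." to "from .model import ..." resolves the import relative to the package itself, so it works no matter what name or path the package is installed under. This assumes the usual layout:

    src/demovae/
        __init__.py
        model.py
        sklearn.py    # from .model import VAE, train_vae, ...
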
test_hf_download.py ADDED
@@ -0,0 +1,237 @@
+import os
+import sys
+import tempfile
+from huggingface_hub import hf_hub_download, list_repo_files, HfApi
+from datasets import load_dataset
+import numpy as np
+import pandas as pd
+
+def test_huggingface_download(dataset_name="SreekarB/OSFData", revision=None, auth_token=None):
+    """
+    Test script to verify downloading NIfTI files from HuggingFace Datasets
+    """
+    print(f"Testing download from HuggingFace dataset: {dataset_name}")
+
+    # Create a temporary directory for downloads
+    temp_dir = tempfile.mkdtemp(prefix="hf_test_")
+    print(f"Created temporary directory: {temp_dir}")
+
+    try:
+        # Step 1: Load the dataset and check its structure
+        print("\n1. Loading dataset...")
+        try:
+            print(f"Attempting to load dataset: {dataset_name}, revision: {revision}, with auth token: {'Yes' if auth_token else 'No'}")
+
+            # Try multiple approaches to load the dataset
+            approaches = [
+                lambda: load_dataset(dataset_name, split="train", revision=revision, token=auth_token),
+                lambda: load_dataset(dataset_name, revision=revision, token=auth_token),  # Try without split
+                lambda: load_dataset(dataset_name, split="train", trust_remote_code=True, revision=revision, token=auth_token),  # Try with trust_remote_code
+                # If the dataset name has a slash, try just the second part
+                (lambda: load_dataset(dataset_name.split("/")[-1], split="train", revision=revision, token=auth_token)) if "/" in dataset_name else None
+            ]
+
+            dataset = None
+            last_error = None
+
+            for i, approach in enumerate(approaches):
+                if approach is None:
+                    continue
+
+                try:
+                    print(f"Attempt {i+1}...")
+                    dataset = approach()
+                    print(f"Attempt {i+1} succeeded!")
+                    break
+                except Exception as e:
+                    print(f"Attempt {i+1} failed: {e}")
+                    last_error = e
+
+            if dataset is None:
+                print(f"All attempts to load dataset failed. Last error: {last_error}")
+
+                # Try direct API approach
+                print("Trying direct HF API approach...")
+                try:
+                    api = HfApi()
+                    repo_info = api.repo_info(repo_id=dataset_name, repo_type="dataset")
+                    print(f"Dataset exists on HF: {repo_info.id}")
+                    print(f"Dataset info: Private: {repo_info.private}, Size: {getattr(repo_info, 'size', 'unknown')}")
+                except Exception as e:
+                    print(f"API check failed: {e}")
+                return
+
+            print(f"Successfully loaded dataset with {len(dataset)} items")
+            print(f"Dataset structure: {type(dataset)}")
+
+            try:
+                print(f"Dataset columns: {dataset.column_names}")
+                if len(dataset) > 0:
+                    print(f"First item keys: {list(dataset[0].keys())}")
+            except Exception as e:
+                print(f"Error accessing dataset structure: {e}")
+        except Exception as e:
+            print(f"Error loading dataset: {e}")
+            return
+
+        # Step 2: Look for columns that might contain NIfTI files
+        print("\n2. Searching for columns with NIfTI files...")
+        nii_columns = []
+        for col in dataset.column_names:
+            if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
+                nii_columns.append(col)
+                continue
+
+            if len(dataset) > 0:
+                try:
+                    first_val = dataset[0][col]
+                    if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
+                        nii_columns.append(col)
+                except Exception:
+                    pass
+
+        if nii_columns:
+            print(f"Found {len(nii_columns)} columns that may contain NIfTI files: {nii_columns}")
+
+            # Step 3: Try to download a file from each column
+            print("\n3. Attempting to download from columns...")
+            downloaded = False
+
+            for col in nii_columns:
+                print(f"\nTrying column '{col}'...")
+
+                # Get the first non-empty value
+                file_path = None
+                for i, item in enumerate(dataset[col]):
+                    if isinstance(item, str) and (item.endswith('.nii') or item.endswith('.nii.gz')):
+                        file_path = item
+                        print(f"Found NIfTI file path at index {i}: {file_path}")
+                        break
+
+                if not file_path:
+                    print(f"No valid NIfTI file paths found in column {col}")
+                    continue
+
+                # Try to download the file
+                try:
+                    downloaded_path = hf_hub_download(
+                        repo_id=dataset_name,
+                        filename=file_path,
+                        repo_type="dataset",
+                        cache_dir=temp_dir
+                    )
+                    print(f"✓ Successfully downloaded to: {downloaded_path}")
+                    downloaded = True
+                    break
+                except Exception as e:
+                    print(f"× Failed to download {file_path}: {e}")
+
+                    # Try with basename
+                    try:
+                        basename = os.path.basename(file_path)
+                        print(f"Trying with basename: {basename}")
+                        downloaded_path = hf_hub_download(
+                            repo_id=dataset_name,
+                            filename=basename,
+                            repo_type="dataset",
+                            cache_dir=temp_dir
+                        )
+                        print(f"✓ Successfully downloaded to: {downloaded_path}")
+                        downloaded = True
+                        break
+                    except Exception as e:
+                        print(f"× Failed to download {basename}: {e}")
+
+            if downloaded:
+                print("\nSuccessfully downloaded a file from column data!")
+            else:
+                print("\nFailed to download any files from columns.")
+        else:
+            print("No columns found that might contain NIfTI files.")
+
+        # Step 4: Try to list and download from repository directly
+        print("\n4. Searching repository files directly...")
+        try:
+            all_files = list_repo_files(dataset_name, repo_type="dataset")
+            print(f"Found {len(all_files)} files in repository")
+
+            # Look for NIfTI files
+            nii_files = [f for f in all_files if f.endswith('.nii') or f.endswith('.nii.gz')]
+            rs_files = [f for f in all_files if 'rs.nii' in f.lower()]
+
+            print(f"Found {len(nii_files)} .nii/.nii.gz files")
+            print(f"Found {len(rs_files)} files matching 'rs.nii' pattern")
+
+            if nii_files:
+                # Try to download the first file
+                test_file = nii_files[0]
+                print(f"Attempting to download: {test_file}")
+
+                try:
+                    downloaded_path = hf_hub_download(
+                        repo_id=dataset_name,
+                        filename=test_file,
+                        repo_type="dataset",
+                        cache_dir=temp_dir
+                    )
+                    print(f"✓ Successfully downloaded to: {downloaded_path}")
+                except Exception as e:
+                    print(f"× Failed to download {test_file}: {e}")
+            else:
+                print("No NIfTI files found in repository listing.")
+
+            # Step 5: Try any P01_rs.nii pattern files specifically
+            if rs_files:
+                print("\n5. Trying P01_rs.nii pattern files...")
+                test_file = rs_files[0]
+                print(f"Attempting to download: {test_file}")
+
+                try:
+                    downloaded_path = hf_hub_download(
+                        repo_id=dataset_name,
+                        filename=test_file,
+                        repo_type="dataset",
+                        cache_dir=temp_dir
+                    )
+                    print(f"✓ Successfully downloaded to: {downloaded_path}")
+                except Exception as e:
+                    print(f"× Failed to download {test_file}: {e}")
+
+        except Exception as e:
+            print(f"Error listing repository files: {e}")
+
+        # Step 6: Check if dataset is accessible through HF API
+        print("\n6. Checking dataset through HF API...")
+        try:
+            api = HfApi()
+            repo_info = api.repo_info(repo_id=dataset_name, repo_type="dataset")
+            print(f"Repository info: {repo_info.sha}, {repo_info.lastModified}")
+
+            # Check which files are stored via Git LFS (typically the large binary ones)
+            lfs_files = [p.path for p in api.get_paths_info(dataset_name, paths=all_files, repo_type="dataset") if getattr(p, 'lfs', None)]
+            print(f"Found {len(lfs_files)} LFS (potentially binary) files")
+            if lfs_files and len(lfs_files) > 0:
+                print(f"First LFS file: {lfs_files[0]}")
+        except Exception as e:
+            print(f"Error accessing dataset through API: {e}")
+
+    except Exception as e:
+        print(f"Unexpected error during testing: {e}")
+        import traceback
+        traceback.print_exc()
+    finally:
+        print(f"\nTest completed. Temporary directory: {temp_dir}")
+        # Uncomment to clean up: import shutil; shutil.rmtree(temp_dir)
+
+if __name__ == "__main__":
+    # Process command line arguments
+    import argparse
+    parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
+    parser.add_argument('--dataset', type=str, default="SreekarB/OSFData", help='HuggingFace dataset name')
+    parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
+    parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
+
+    args = parser.parse_args()
+
+    # Use command line arguments
+    test_huggingface_download(args.dataset, args.revision, args.token)
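
Editor's note: example invocations of the new test script (run from the repository root; the token is only needed for private datasets, and the value shown is a placeholder):

    python test_hf_download.py
    python test_hf_download.py --dataset SreekarB/OSFData --revision main --token hf_XXXX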