File size: 29,608 Bytes
ef677f1
 
1c47445
b32645b
ef677f1
 
1c47445
dfe19ad
1c47445
dfe19ad
1c47445
dbe81c1
1c47445
 
 
 
dbe81c1
1c47445
 
 
 
dbe81c1
1c47445
dbe81c1
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a4c8f0c
b32645b
 
a4c8f0c
1c47445
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
import numpy as np
import pandas as pd
from datasets import load_dataset
from nilearn import input_data, connectome
from nilearn.image import load_img
import nibabel as nib
import os

def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
    """
    Process fMRI data to generate functional connectivity (FC) features.

    Each 4D NIfTI image is reduced to ROI time series on the Power (2011)
    coordinates, correlated, reduced to the strict upper triangle,
    Fisher z-transformed and finally z-scored per edge across subjects.

    Parameters:
    - dataset_or_niifiles: Either a Hugging Face dataset name (str) or a
      non-empty list of NIfTI files
    - demo_data: Optional demographic data, required if providing NIfTI files.
      For a dataset name it is built from the 'age', 'gender', 'mpo' and
      'wab_aq' columns when omitted.
    - demo_types: Optional demographic data types, required if providing NIfTI files

    Returns:
    - X: 2D array (subjects x edges) of normalized FC values. For a dataset
      in which no usable NIfTI file is found this is an all-zero placeholder
      sized for 264 ROIs.
    - demo_data: Demographic data
    - demo_types: Demographic data types

    Raises:
    - ValueError: if the input is neither a dataset name nor a non-empty list
      of NIfTI files accompanied by demo_data and demo_types, or if the
      dataset cannot be loaded.
    """
    print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")

    if isinstance(dataset_or_niifiles, str):
        return _preprocess_from_hf_dataset(dataset_or_niifiles, demo_data, demo_types)

    if isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
        if not dataset_or_niifiles:
            # Nothing to normalize; fail before fetching the atlas.
            raise ValueError("Invalid input. The list of NIfTI files is empty.")

        print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
        masker = _build_spheres_masker(_power_264_coords(), radius=5)

        # Errors are deliberately NOT swallowed here: every file must yield a
        # row so that X stays aligned with the caller-supplied demographics.
        fc_vectors = []
        for nii_file in dataset_or_niifiles:
            time_series = masker.fit_transform(load_img(nii_file))
            fc_vectors.append(_fisher_z_upper_triangle(_correlation_matrix(time_series)))

        X = _zscore_columns(np.array(fc_vectors))
        return X, demo_data, demo_types

    raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")


def _preprocess_from_hf_dataset(dataset_name, demo_data, demo_types):
    """Load a Hugging Face dataset, locate its NIfTI files and build FC features."""
    import shutil
    import tempfile

    print(f"Loading data from dataset: {dataset_name}")
    dataset = _load_hf_dataset(dataset_name)

    # Prepare demographics data from the dataset
    if demo_data is None:
        demo_df = pd.DataFrame({
            'age': dataset['age'],
            'gender': dataset['gender'],
            'mpo': dataset['mpo'],
            'wab_aq': dataset['wab_aq']
        })
        demo_data = [
            demo_df['age'].values,
            demo_df['gender'].values,
            demo_df['mpo'].values,
            demo_df['wab_aq'].values
        ]
        demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    print("Searching for NIfTI files in dataset columns...")
    temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
    print(f"Created temporary directory for NIfTI files: {temp_dir}")

    try:
        nii_files = _find_nifti_files(dataset, dataset_name, temp_dir)
        fc_vectors = _fc_vectors_lenient(nii_files) if nii_files else []
    finally:
        # Only the extracted vectors are needed from here on; the downloads
        # would otherwise accumulate in the system temp directory.
        shutil.rmtree(temp_dir, ignore_errors=True)

    if fc_vectors:
        print(f"Successfully processed {len(fc_vectors)} out of {len(nii_files)} files")

        # Ensure all vectors have the same length
        dims = [v.shape[0] for v in fc_vectors]
        if len(set(dims)) > 1:
            from collections import Counter
            print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
            most_common_dim = Counter(dims).most_common(1)[0][0]
            print(f"Using most common dimension: {most_common_dim}")
            fc_vectors = [v for v in fc_vectors if v.shape[0] == most_common_dim]

        X = _zscore_columns(np.array(fc_vectors))
        print(f"Created FC matrices with shape {X.shape}")

        # Make sure demo_data matches the number of FC rows.
        # NOTE(review): skipped/failed files are not tracked per subject, so
        # truncation keeps the counts equal but may pair rows with the wrong
        # subject -- confirm this is acceptable for the caller.
        n_demo = len(demo_data[0])
        if n_demo != X.shape[0]:
            print(f"Warning: Number of subjects in demographic data ({n_demo}) " +
                  f"doesn't match number of FC matrices ({X.shape[0]})")
            n_keep = min(n_demo, X.shape[0])
            X = X[:n_keep]
            # np.asarray lets caller-supplied plain lists be sliced as well.
            demo_data = [np.asarray(d)[:n_keep] for d in demo_data]

        return X, demo_data, demo_types

    print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
    # Return a placeholder with the right demographics but empty FC
    n_subjects = len(dataset)
    n_rois = 264
    fc_dim = (n_rois * (n_rois - 1)) // 2
    X = np.zeros((n_subjects, fc_dim))
    print(f"Created placeholder FC matrices with shape {X.shape}")
    return X, demo_data, demo_types


def _load_hf_dataset(dataset_name):
    """Load *dataset_name* with several fallbacks; raise ValueError if all fail."""
    approaches = [
        lambda: load_dataset(dataset_name, split="train"),
        lambda: load_dataset(dataset_name),  # Try without split
        lambda: load_dataset(dataset_name, split="train", trust_remote_code=True),
    ]
    if "/" in dataset_name:
        # Retry with the bare repository name (namespace stripped).
        short_name = dataset_name.split("/")[-1]
        approaches.append(lambda: load_dataset(short_name, split="train"))

    last_error = None
    for attempt, approach in enumerate(approaches, start=1):
        try:
            print(f"Attempt {attempt} to load dataset...")
            dataset = approach()
            if isinstance(dataset, dict):
                # Without split= a dict of splits comes back; the rest of the
                # pipeline needs a single split with column access.
                if not dataset:
                    raise ValueError("dataset contains no splits")
                split_name = "train" if "train" in dataset else next(iter(dataset))
                print(f"Using split '{split_name}'")
                dataset = dataset[split_name]
        except Exception as e:
            print(f"Attempt {attempt} failed: {e}")
            last_error = e
            continue
        print(f"Successfully loaded dataset with approach {attempt}!")
        return dataset

    print(f"All attempts to load dataset failed. Last error: {last_error}")
    error = ValueError(f"Could not load dataset {dataset_name}")
    print(f"Error during dataset loading: {error}")
    raise error from last_error


def _is_nifti_name(name):
    """Return True when *name* ends with .nii or .nii.gz."""
    return name.endswith(('.nii', '.nii.gz'))


def _find_nifti_files(dataset, dataset_name, temp_dir):
    """Return local paths of NIfTI files found via dataset columns or the repo."""
    # Imported outside the try below so a missing hub client fails loudly.
    from huggingface_hub import hf_hub_download, list_repo_files

    def download(filename):
        return hf_hub_download(
            repo_id=dataset_name,
            filename=filename,
            repo_type="dataset",
            cache_dir=temp_dir
        )

    nii_files = []
    try:
        # First approach: columns that name or contain NIfTI paths
        _collect_from_columns(dataset, dataset_name, temp_dir, download, nii_files)

        # Second approach: look in the dataset repository directly
        if not nii_files:
            print("No NIfTI files found in dataset columns. Trying direct repository search...")
            _collect_from_repository(dataset_name, download, list_repo_files, nii_files)

        print(f"Found total of {len(nii_files)} NIfTI files")
    except Exception as e:
        # Best effort: keep whatever was collected before the failure.
        print(f"Unexpected error during NIfTI file search: {e}")
        import traceback
        traceback.print_exc()
    return nii_files


def _collect_from_columns(dataset, dataset_name, temp_dir, download, nii_files):
    """Append to *nii_files* every NIfTI path referenced by a dataset column."""
    nii_columns = []
    for col in dataset.column_names:
        lowered = col.lower()
        # Column name suggests NIfTI content ...
        if 'nii' in lowered or 'nifti' in lowered or 'fmri' in lowered:
            nii_columns.append(col)
        # ... or its first value looks like a NIfTI path
        elif len(dataset) > 0:
            first_val = dataset[0][col]
            if isinstance(first_val, str) and _is_nifti_name(first_val):
                nii_columns.append(col)

    if not nii_columns:
        return

    print(f"Found columns that may contain NIfTI files: {nii_columns}")
    for col in nii_columns:
        print(f"Processing column '{col}'...")
        for i, item in enumerate(dataset[col]):
            if not isinstance(item, str):
                print(f"Item {i} in column {col} is not a string but {type(item)}")
                continue
            if not _is_nifti_name(item):
                print(f"Item {i} in column {col} is not a NIfTI file: {item}")
                continue

            print(f"Downloading {item} from dataset {dataset_name}...")
            local_path = _resolve_column_item(dataset, i, item, temp_dir, download)
            if local_path is not None:
                nii_files.append(local_path)


def _resolve_column_item(dataset, row_index, item, temp_dir, download):
    """Resolve one column entry to a local file path, or None if not found.

    Tries, in order: the path as given, its basename, a 'bytes' entry in the
    dataset row, and finally a file relative to the working directory. The
    first strategy that succeeds wins, so a file is never added twice.
    """
    try:
        file_path = download(item)
        print(f"✓ Successfully downloaded {item}")
        return file_path
    except Exception as e1:
        print(f"Error downloading with explicit filename: {e1}")

    basename = os.path.basename(item)
    try:
        print(f"Trying with basename: {basename}")
        file_path = download(basename)
        print(f"✓ Successfully downloaded {basename}")
        return file_path
    except Exception as e2:
        print(f"Error downloading with basename: {e2}")

    try:
        row = dataset[row_index]
        if hasattr(row, 'keys') and 'bytes' in row:
            print("Found binary data in dataset, saving to temporary file...")
            temp_file = os.path.join(temp_dir, basename)
            with open(temp_file, 'wb') as f:
                f.write(row['bytes'])
            print(f"✓ Saved binary data to {temp_file}")
            return temp_file
    except Exception as e3:
        print(f"Error handling binary data: {e3}")

    # Last resort: look for the file locally
    local_path = os.path.join(os.getcwd(), item)
    if os.path.exists(local_path):
        print(f"✓ Found {item} locally")
        return local_path

    print(f"❌ Warning: Could not find {item} anywhere")
    return None


def _collect_from_repository(dataset_name, download, list_repo_files, nii_files):
    """Append to *nii_files* NIfTI files found by exploring the dataset repo."""
    fetched = set()        # repo paths already downloaded, to avoid duplicates
    all_repo_files = None  # None means the listing is unavailable

    def fetch(repo_path):
        nii_files.append(download(repo_path))
        fetched.add(repo_path)

    try:
        # Step 1: list the repository and download every NIfTI file in it
        try:
            print("Listing all repository files...")
            all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
            print(f"Found {len(all_repo_files)} files in repository")

            # P*_rs.nii files first, then every other NIfTI file
            p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
            other_nii_files = [f for f in all_repo_files if _is_nifti_name(f) and f not in p_rs_files]
            nii_repo_files = p_rs_files + other_nii_files

            if nii_repo_files:
                print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5]}...")
                for nii_file in nii_repo_files:
                    try:
                        fetch(nii_file)
                        print(f"✓ Downloaded {nii_file}")
                    except Exception as e:
                        print(f"Error downloading {nii_file}: {e}")
        except Exception as e:
            print(f"Error listing repository files: {e}")
            print("Will try alternative approaches...")

        # When the listing is known, names absent from it cannot be
        # downloaded, so probing them would only waste network requests.
        known_files = set(all_repo_files) if all_repo_files is not None else None

        def may_exist(repo_path):
            return known_files is None or repo_path in known_files

        # P01_rs.nii .. P30_rs.nii, their .nii.gz variants, plus a few
        # other common names as fallbacks
        patterns = [f"P{i:02d}_rs.nii" for i in range(1, 31)]
        patterns += [f"P{i:02d}_rs.nii.gz" for i in range(1, 31)]
        patterns += [
            "sub-01_task-rest_bold.nii.gz",  # BIDS format
            "fmri.nii.gz", "bold.nii.gz",
            "rest.nii.gz"
        ]

        # Step 2: probe common file names at the repository root
        if not nii_files:
            print("Trying common NIfTI file patterns...")
            for pattern in patterns:
                if not may_exist(pattern):
                    continue
                try:
                    print(f"Trying to download {pattern}...")
                    fetch(pattern)
                    print(f"✓ Successfully downloaded {pattern}")
                except Exception:
                    print(f"× Failed to download {pattern}")

        # Step 3: probe the same names inside common sub-directories
        if not nii_files:
            print("Checking for nested data files...")
            nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]

            for path in nested_paths:
                for pattern in patterns:
                    nested_file = f"{path}{pattern}"
                    if nested_file in fetched or not may_exist(nested_file):
                        continue
                    try:
                        print(f"Trying to download {nested_file}...")
                        fetch(nested_file)
                        print(f"✓ Successfully downloaded {nested_file}")
                    except Exception:
                        # Misses are expected while probing; stay quiet.
                        continue

                    # One hit in this directory: also take its other NIfTI
                    # files, which needs the repository listing.
                    if all_repo_files is None:
                        continue
                    siblings = [f for f in all_repo_files if f.startswith(path) and _is_nifti_name(f)]
                    print(f"Found {len(siblings)} additional NIfTI files in {path}")
                    for nii_file in siblings:
                        if nii_file in fetched:
                            continue
                        try:
                            fetch(nii_file)
                            print(f"✓ Downloaded {nii_file}")
                        except Exception as e:
                            print(f"Error downloading {nii_file}: {e}")
    except Exception as e:
        print(f"Error during repository exploration: {e}")

    # Step 4: last attempt, anything in the listing that looks like *rs.nii*
    if nii_files:
        return

    print("Trying to find files matching P*_rs.nii pattern specifically...")
    try:
        if all_repo_files is None:
            try:
                all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
            except Exception as e:
                print(f"Error listing repo files: {e}")
                all_repo_files = []

        pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
        if not pattern_files:
            # No strict match: relax to any name containing rs.nii
            pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]

        if pattern_files:
            print(f"Found {len(pattern_files)} files matching rs.nii pattern")
            for pattern_file in pattern_files:
                if pattern_file in fetched:
                    continue
                try:
                    fetch(pattern_file)
                    print(f"✓ Downloaded {pattern_file}")
                except Exception as e:
                    print(f"Error downloading {pattern_file}: {e}")
    except Exception as e:
        print(f"Error searching for pattern files: {e}")


def _power_264_coords():
    """Return the Power (2011) ROI coordinates as an (n_rois, 3) array."""
    from nilearn import datasets
    power = datasets.fetch_coords_power_2011()
    return np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T


def _build_spheres_masker(coords, radius):
    """Create the sphere masker shared by every processing path."""
    return input_data.NiftiSpheresMasker(
        coords, radius=radius,
        standardize=True,
        memory='nilearn_cache', memory_level=1,
        verbose=0,
        detrend=True,
        low_pass=0.1,
        high_pass=0.01,
        t_r=2.0  # Adjust TR according to your data
    )


def _correlation_matrix(time_series):
    """Return the ROI-by-ROI correlation matrix of one subject's time series."""
    correlation_measure = connectome.ConnectivityMeasure(
        kind='correlation',
        vectorize=False,
        discard_diagonal=False
    )
    return correlation_measure.fit_transform([time_series])[0]


def _fisher_z_upper_triangle(fc_matrix):
    """Return the Fisher z-transformed strict upper triangle of *fc_matrix*."""
    upper = fc_matrix[np.triu_indices_from(fc_matrix, k=1)]
    # arctanh diverges at +/-1, so clip first to keep every value finite.
    return np.arctanh(np.clip(upper, -0.999, 0.999))


def _zscore_columns(X):
    """Z-score each column of *X*; constant columns are only centred."""
    mean_x = np.mean(X, axis=0)
    std_x = np.std(X, axis=0)
    # A zero standard deviation would turn the whole column into NaN.
    std_x = np.where(std_x == 0, 1.0, std_x)
    return (X - mean_x) / std_x


def _extract_time_series(fmri_img, masker, coords, nii_file):
    """Extract ROI time series, retrying with larger spheres; None on failure."""
    import re
    import warnings

    try:
        # Silence warnings about empty spheres; hard errors are handled below
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', message='.*empty.*')
            return masker.fit_transform(fmri_img)
    except Exception as e:
        if "empty" not in str(e):
            print(f"Unknown error in masker: {e}")
            return None

        print(f"Warning: Some spheres are empty in {nii_file}. Using a different sphere radius.")
        # Log the sphere list embedded in the error message, if any
        empty_spheres = re.findall(r"\[(.*?)\]", str(e))
        if empty_spheres:
            print(f"Empty spheres: {empty_spheres[0]}")

        alternate_masker = _build_spheres_masker(coords, radius=8)  # Larger radius
        try:
            time_series = alternate_masker.fit_transform(fmri_img)
            print(f"Successfully extracted time series with larger radius")
            return time_series
        except Exception as e2:
            print(f"Error with alternate masker: {e2}")
            print(f"Skipping this file due to empty spheres")
            return None


def _fc_vectors_lenient(nii_files):
    """Convert NIfTI files to FC vectors, skipping (and logging) bad files."""
    print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")

    coords = _power_264_coords()
    masker = _build_spheres_masker(coords, radius=5)

    fc_vectors = []
    for nii_file in nii_files:
        try:
            print(f"Processing {nii_file}...")
            fmri_img = load_img(nii_file)

            # Need a 4D image with at least 10 volumes
            if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
                print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
                continue

            time_series = _extract_time_series(fmri_img, masker, coords, nii_file)
            if time_series is None:
                continue

            if np.isnan(time_series).any() or np.isinf(time_series).any():
                print(f"Warning: {nii_file} contains NaN or Inf values after masking")
                time_series = np.nan_to_num(time_series)

            fc_matrix = _correlation_matrix(time_series)
            if np.isnan(fc_matrix).any():
                print(f"Warning: {nii_file} produced NaN correlation values")
                continue

            fc_vectors.append(_fisher_z_upper_triangle(fc_matrix))
            print(f"Successfully processed {nii_file} to FC matrix")
        except Exception as e:
            print(f"Error processing {nii_file}: {e}")

    return fc_vectors