SreekarB commited on
Commit
1a5b440
·
verified ·
1 Parent(s): eba5685

Upload 5 files

Browse files
Files changed (5) hide show
  1. README.md +18 -13
  2. app.py +71 -28
  3. data_preprocessing.py +215 -43
  4. main.py +76 -57
  5. requirements.txt +1 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🧠
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.20.1
8
  app_file: app.py
9
  pinned: false
10
  ---
@@ -25,7 +25,7 @@ This application implements a VAE model that:
25
 
26
  This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
27
 
28
- - Functional connectivity matrices from fMRI data
29
  - Demographic information directly in the dataset:
30
  - ID: Subject identifier
31
  - wab_aq: Aphasia quotient score (severity measure)
@@ -35,19 +35,24 @@ This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/O
35
  - gender: Subject gender
36
  - handedness: Subject handedness (ignored in this analysis)
37
 
 
 
38
  ## How to Use
39
 
40
- 1. **Data Source**: By default, it uses the HuggingFace dataset. You can change to a local directory if needed.
41
- 2. **Model Parameters**:
42
- - Latent Dimensions: Controls the size of the latent space (default: 32)
43
- - Number of Epochs: Training iterations (default: 1000)
44
- - Batch Size: Training batch size (default: 16)
45
-
46
- 3. **Run the Analysis**: The model will:
47
- - Load and process the data
48
- - Train the VAE model
49
- - Analyze relationships between latent variables and demographics
50
- - Generate visualizations of original, reconstructed, and generated FC matrices
 
 
 
51
 
52
  ## Outputs
53
 
 
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
25
 
26
  This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
27
 
28
+ - NIfTI files in P01_rs.nii format containing fMRI data
29
  - Demographic information directly in the dataset:
30
  - ID: Subject identifier
31
  - wab_aq: Aphasia quotient score (severity measure)
 
35
  - gender: Subject gender
36
  - handedness: Subject handedness (ignored in this analysis)
37
 
38
+ The application processes the NIfTI files using the Power 264 atlas to create functional connectivity matrices that are then analyzed by the VAE model.
39
+
40
  ## How to Use
41
 
42
+ 1. **Configure Parameters**:
43
+ - **Data Source**: By default, it uses the SreekarB/OSFData HuggingFace dataset
44
+ - **Latent Dimensions**: Controls the size of the latent space (default: 32)
45
+ - **Number of Epochs**: Training iterations (default: 200 for demo)
46
+ - **Batch Size**: Training batch size (default: 16)
47
+
48
+ 2. **Start Training**:
49
+ - Click the "Start Training" button to begin the analysis
50
+ - The training progress will be displayed in the Status area
51
+
52
+ 3. **View Results**:
53
+ - The VAE will learn latent representations of brain connectivity
54
+ - Results will show correlations between demographic variables and latent brain patterns
55
+ - The visualization shows original FC, reconstructed FC, and a new FC matrix generated from specific demographic values
56
 
57
  ## Outputs
58
 
app.py CHANGED
@@ -3,6 +3,7 @@ from main import run_fc_analysis
3
  import os
4
 
5
  def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
 
6
  fig = run_fc_analysis(
7
  data_dir=data_source,
8
  demographic_file=None, # We're now getting demographics directly from the dataset
@@ -12,30 +13,17 @@ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
12
  save_model=True,
13
  use_hf_dataset=use_hf_dataset
14
  )
15
- return fig
16
 
17
  def create_interface():
18
- iface = gr.Interface(
19
- fn=gradio_fc_analysis,
20
- inputs=[
21
- gr.Textbox(label="Data Source (HF Dataset ID or Local Directory)",
22
- value="SreekarB/OSFData"),
23
- gr.Slider(minimum=8, maximum=64, step=8,
24
- label="Latent Dimensions", value=32),
25
- gr.Slider(minimum=100, maximum=5000, step=100,
26
- label="Number of Epochs", value=500), # Reduced for faster demos
27
- gr.Slider(minimum=8, maximum=64, step=8,
28
- label="Batch Size", value=16),
29
- gr.Checkbox(label="Use HuggingFace Dataset",
30
- value=True),
31
- ],
32
- outputs="plot",
33
- title="Aphasia fMRI to FC Analysis using VAE",
34
- description="""
35
- Analysis pipeline: fMRI → FC matrices → VAE → Analysis
36
-
37
- This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
38
- The dataset contains the following columns:
39
  - ID: Subject identifier
40
  - wab_aq: Aphasia severity score
41
  - age: Age of the subject
@@ -43,12 +31,67 @@ def create_interface():
43
  - education: Years of education
44
  - gender: Subject gender
45
  - handedness: Subject handedness (ignored in the analysis)
46
- """,
47
- examples=[
48
- ["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
49
- ],
50
- cache_examples=False,
51
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
  return iface
53
 
54
  if __name__ == "__main__":
 
3
  import os
4
 
5
  def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
6
+ """Run the full VAE analysis pipeline"""
7
  fig = run_fc_analysis(
8
  data_dir=data_source,
9
  demographic_file=None, # We're now getting demographics directly from the dataset
 
13
  save_model=True,
14
  use_hf_dataset=use_hf_dataset
15
  )
16
+ return fig, "Analysis complete! VAE model has been trained and demographic relationships analyzed."
17
 
18
  def create_interface():
19
+ with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
20
+ gr.Markdown("""
21
+ # Aphasia fMRI to FC Analysis using VAE
22
+
23
+ This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
24
+
25
+ ## Dataset Information
26
+ By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  - ID: Subject identifier
28
  - wab_aq: Aphasia severity score
29
  - age: Age of the subject
 
31
  - education: Years of education
32
  - gender: Subject gender
33
  - handedness: Subject handedness (ignored in the analysis)
34
+ """)
35
+
36
+ with gr.Row():
37
+ with gr.Column(scale=1):
38
+ # Configuration parameters
39
+ data_source = gr.Textbox(
40
+ label="Data Source (HF Dataset ID or Local Directory)",
41
+ value="SreekarB/OSFData"
42
+ )
43
+ latent_dim = gr.Slider(
44
+ minimum=8, maximum=64, step=8,
45
+ label="Latent Dimensions", value=32
46
+ )
47
+ nepochs = gr.Slider(
48
+ minimum=100, maximum=5000, step=100,
49
+ label="Number of Epochs", value=200 # Reduced for faster demos
50
+ )
51
+ bsize = gr.Slider(
52
+ minimum=8, maximum=64, step=8,
53
+ label="Batch Size", value=16
54
+ )
55
+ use_hf_dataset = gr.Checkbox(
56
+ label="Use HuggingFace Dataset", value=True
57
+ )
58
+
59
+ # Training button
60
+ train_button = gr.Button("Start Training", variant="primary")
61
+ status_text = gr.Textbox(label="Status", value="Ready to start training")
62
+
63
+ with gr.Column(scale=2):
64
+ # Output plot
65
+ output_plot = gr.Plot(label="Analysis Results")
66
+
67
+ # Link the training button to the analysis function
68
+ train_button.click(
69
+ fn=gradio_fc_analysis,
70
+ inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
71
+ outputs=[output_plot, status_text]
72
+ )
73
+
74
+ # Add examples
75
+ gr.Examples(
76
+ examples=[
77
+ ["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
78
+ ],
79
+ inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
80
+ )
81
+
82
+ # Add explanation of the workflow
83
+ gr.Markdown("""
84
+ ## How this works
85
+
86
+ 1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
87
+ 2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
88
+ 3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
89
+ 4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
90
+ 5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
91
+
92
+ Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
93
+ """)
94
+
95
  return iface
96
 
97
  if __name__ == "__main__":
data_preprocessing.py CHANGED
@@ -4,13 +4,153 @@ from datasets import load_dataset
4
  from nilearn import input_data, connectome
5
  from nilearn.image import load_img
6
  import nibabel as nib
 
7
 
8
- def preprocess_fmri_to_fc(dataset_name, atlas_path=None):
9
- dataset = load_dataset(dataset_name, split="train")
 
10
 
11
- # Load Power 264 atlas or specified atlas
12
- if atlas_path is None:
13
- # Use Power 264 coordinates to create spherical ROIs
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  from nilearn import datasets
15
  power = datasets.fetch_coords_power_2011()
16
  coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
@@ -25,52 +165,84 @@ def preprocess_fmri_to_fc(dataset_name, atlas_path=None):
25
  high_pass=0.01,
26
  t_r=2.0 # Adjust TR according to your data
27
  )
28
- else:
29
- masker = input_data.NiftiLabelsMasker(
30
- labels_img=atlas_path,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  standardize=True,
32
  memory='nilearn_cache', memory_level=1,
33
  verbose=0,
34
  detrend=True,
35
  low_pass=0.1,
36
  high_pass=0.01,
37
- t_r=2.0 # Adjust TR according to your data
38
- )
39
-
40
- # Load demographic data
41
- demo_df = pd.DataFrame(dataset['demographics'])
42
-
43
- demo_data = [
44
- demo_df['age_at_stroke'].values,
45
- demo_df['sex'].values,
46
- demo_df['months_post_stroke'].values,
47
- demo_df['wab_score'].values
48
- ]
49
-
50
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
51
-
52
- # Process fMRI data and compute FC matrices
53
- fc_matrices = []
54
- for nii_file in dataset['nii_files']:
55
- fmri_img = load_img(nii_file)
56
- time_series = masker.fit_transform(fmri_img)
57
-
58
- correlation_measure = connectome.ConnectivityMeasure(
59
- kind='correlation',
60
- vectorize=False,
61
- discard_diagonal=False
62
  )
63
 
64
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
65
-
66
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
67
- fc_triu = fc_matrix[triu_indices]
68
-
69
- fc_triu = np.arctanh(fc_triu) # Fisher z-transform
70
-
71
- fc_matrices.append(fc_triu)
72
-
73
- X = np.array(fc_matrices)
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  # Normalize the FC data
76
  X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
 
4
  from nilearn import input_data, connectome
5
  from nilearn.image import load_img
6
  import nibabel as nib
7
+ import os
8
 
9
+ def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
10
+ """
11
+ Process fMRI data to generate functional connectivity matrices
12
 
13
+ Parameters:
14
+ - dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
15
+ - demo_data: Optional demographic data, required if providing NIfTI files
16
+ - demo_types: Optional demographic data types, required if providing NIfTI files
17
+
18
+ Returns:
19
+ - X: Array of FC matrices
20
+ - demo_data: Demographic data
21
+ - demo_types: Demographic data types
22
+ """
23
+ print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
24
+
25
+ # For SreekarB/OSFData dataset, the data will be loaded from dataset features
26
+ if isinstance(dataset_or_niifiles, str) and dataset_or_niifiles == "SreekarB/OSFData":
27
+ print("Loading data from SreekarB/OSFData dataset")
28
+ dataset = load_dataset(dataset_or_niifiles, split="train")
29
+
30
+ # Prepare demographics data from the dataset
31
+ if demo_data is None:
32
+ # Create demo_data from the dataset
33
+ demo_df = pd.DataFrame({
34
+ 'age': dataset['age'],
35
+ 'gender': dataset['gender'],
36
+ 'mpo': dataset['mpo'],
37
+ 'wab_aq': dataset['wab_aq']
38
+ })
39
+
40
+ demo_data = [
41
+ demo_df['age'].values,
42
+ demo_df['gender'].values,
43
+ demo_df['mpo'].values,
44
+ demo_df['wab_aq'].values
45
+ ]
46
+
47
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
48
+
49
+ # Look for NIfTI files in P01_rs.nii format
50
+ print("Searching for NIfTI files in dataset columns...")
51
+ nii_files = []
52
+
53
+ # First check if there are any columns with .nii files
54
+ for col in dataset.column_names:
55
+ # Check if column contains file paths
56
+ first_val = dataset[col][0] if len(dataset) > 0 else None
57
+ if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
58
+ print(f"Found column '{col}' with NIfTI file paths")
59
+
60
+ # Try to download files from HuggingFace Hub
61
+ from huggingface_hub import hf_hub_download
62
+ import tempfile
63
+
64
+ for item in dataset[col]:
65
+ try:
66
+ # Download the NIfTI file
67
+ file_path = hf_hub_download(
68
+ repo_id="SreekarB/OSFData",
69
+ filename=item,
70
+ repo_type="dataset"
71
+ )
72
+ nii_files.append(file_path)
73
+ print(f"Downloaded {item} from HuggingFace Hub")
74
+ except Exception as e:
75
+ print(f"Error downloading {item}: {e}")
76
+ # Try looking for the file locally
77
+ local_path = os.path.join(os.getcwd(), item)
78
+ if os.path.exists(local_path):
79
+ nii_files.append(local_path)
80
+ print(f"Found {item} locally")
81
+ else:
82
+ print(f"Warning: Could not find {item} locally or on HuggingFace")
83
+
84
+ # If we found NIfTI files, process them to FC matrices
85
+ if nii_files:
86
+ print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
87
+
88
+ # Load Power 264 atlas
89
+ from nilearn import datasets
90
+ power = datasets.fetch_coords_power_2011()
91
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
92
+
93
+ masker = input_data.NiftiSpheresMasker(
94
+ coords, radius=5,
95
+ standardize=True,
96
+ memory='nilearn_cache', memory_level=1,
97
+ verbose=0,
98
+ detrend=True,
99
+ low_pass=0.1,
100
+ high_pass=0.01,
101
+ t_r=2.0 # Adjust TR according to your data
102
+ )
103
+
104
+ # Process fMRI data and compute FC matrices
105
+ fc_matrices = []
106
+ for nii_file in nii_files:
107
+ try:
108
+ fmri_img = load_img(nii_file)
109
+ time_series = masker.fit_transform(fmri_img)
110
+
111
+ correlation_measure = connectome.ConnectivityMeasure(
112
+ kind='correlation',
113
+ vectorize=False,
114
+ discard_diagonal=False
115
+ )
116
+
117
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
118
+
119
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
120
+ fc_triu = fc_matrix[triu_indices]
121
+
122
+ fc_triu = np.arctanh(fc_triu) # Fisher z-transform
123
+
124
+ fc_matrices.append(fc_triu)
125
+ print(f"Processed {nii_file} to FC matrix")
126
+ except Exception as e:
127
+ print(f"Error processing {nii_file}: {e}")
128
+
129
+ if fc_matrices:
130
+ X = np.array(fc_matrices)
131
+ # Normalize the FC data
132
+ X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
133
+ print(f"Created FC matrices with shape {X.shape}")
134
+ return X, demo_data, demo_types
135
+ break # Stop after finding one column with NIfTI files
136
+
137
+ # If we're here, we couldn't find or process the NIfTI files
138
+ print("Could not find or process NIfTI files from the dataset.")
139
+
140
+ print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
141
+ # Return a placeholder with the right demographics but empty FC
142
+ n_subjects = len(dataset)
143
+ n_rois = 264
144
+ fc_dim = (n_rois * (n_rois - 1)) // 2
145
+ X = np.zeros((n_subjects, fc_dim))
146
+ print(f"Created placeholder FC matrices with shape {X.shape}")
147
+ return X, demo_data, demo_types
148
+
149
+ elif isinstance(dataset_or_niifiles, str):
150
+ # Handle real dataset with actual fMRI data
151
+ dataset = load_dataset(dataset_or_niifiles, split="train")
152
+
153
+ # Load Power 264 atlas
154
  from nilearn import datasets
155
  power = datasets.fetch_coords_power_2011()
156
  coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
 
165
  high_pass=0.01,
166
  t_r=2.0 # Adjust TR according to your data
167
  )
168
+
169
+ # Load demographic data if needed
170
+ if demo_data is None:
171
+ if 'demographics' in dataset.features:
172
+ demo_df = pd.DataFrame(dataset['demographics'])
173
+
174
+ demo_data = [
175
+ demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
176
+ demo_df['sex'].values if 'sex' in demo_df.columns else [],
177
+ demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
178
+ demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
179
+ ]
180
+
181
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
182
+
183
+ # Process fMRI data and compute FC matrices
184
+ fc_matrices = []
185
+ for nii_file in dataset['nii_files']:
186
+ fmri_img = load_img(nii_file)
187
+ time_series = masker.fit_transform(fmri_img)
188
+
189
+ correlation_measure = connectome.ConnectivityMeasure(
190
+ kind='correlation', vectorize=False, discard_diagonal=False
191
+ )
192
+
193
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
194
+
195
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
196
+ fc_triu = fc_matrix[triu_indices]
197
+
198
+ fc_triu = np.arctanh(fc_triu) # Fisher z-transform
199
+
200
+ fc_matrices.append(fc_triu)
201
+
202
+ X = np.array(fc_matrices)
203
+
204
+ elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
205
+ # Handle a list of NIfTI files
206
+ # Similar processing as above but with local files
207
+ print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
208
+
209
+ # Load Power 264 atlas
210
+ from nilearn import datasets
211
+ power = datasets.fetch_coords_power_2011()
212
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
213
+
214
+ masker = input_data.NiftiSpheresMasker(
215
+ coords, radius=5,
216
  standardize=True,
217
  memory='nilearn_cache', memory_level=1,
218
  verbose=0,
219
  detrend=True,
220
  low_pass=0.1,
221
  high_pass=0.01,
222
+ t_r=2.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
  )
224
 
225
+ fc_matrices = []
226
+ for nii_file in dataset_or_niifiles:
227
+ fmri_img = load_img(nii_file)
228
+ time_series = masker.fit_transform(fmri_img)
229
+
230
+ correlation_measure = connectome.ConnectivityMeasure(
231
+ kind='correlation', vectorize=False, discard_diagonal=False
232
+ )
233
+
234
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
235
+
236
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
237
+ fc_triu = fc_matrix[triu_indices]
238
+
239
+ fc_triu = np.arctanh(fc_triu) # Fisher z-transform
240
+
241
+ fc_matrices.append(fc_triu)
242
+
243
+ X = np.array(fc_matrices)
244
+ else:
245
+ raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
246
 
247
  # Normalize the FC data
248
  X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
main.py CHANGED
@@ -110,63 +110,82 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
110
  'bsize': bsize
111
  })
112
 
113
- # Load data
114
- print("Loading data...")
115
- nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
116
-
117
- # Check if we got FC matrices directly
118
- if isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
119
- print("Using pre-computed FC matrices...")
120
- # Convert list of FC matrices to numpy array
121
- X = np.stack([np.array(fc) for fc in nii_files])
122
- else:
123
- # Prepare data by converting fMRI to FC matrices
124
- print("Converting fMRI data to FC matrices...")
125
- X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
126
-
127
- # Print shapes and data types
128
- print(f"X shape: {X.shape}, type: {type(X)}")
129
- for i, d in enumerate(demo_data):
130
- print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
131
-
132
- # Train VAE and get data
133
- print("Training VAE...")
134
- vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
135
-
136
- if save_model:
137
- print("Saving model...")
138
- os.makedirs('models', exist_ok=True)
139
- torch.save(vae.state_dict(), 'models/vae_model.pth')
140
-
141
- # Get latent representations
142
- print("Getting latent representations...")
143
- latents = vae.get_latents(X)
144
-
145
- # Analyze results
146
- print("Analyzing demographic relationships...")
147
- demographics = {
148
- 'age': demo_data[0],
149
- 'months_post_onset': demo_data[2],
150
- 'wab_aq': demo_data[3]
151
- }
152
- analysis_results = analyze_fc_patterns(latents, demographics)
153
-
154
- # Generate new FC matrix
155
- print("Generating new FC matrices...")
156
- new_demographics = [
157
- [60.0], # age
158
- ['M'], # gender
159
- [12.0], # months post onset
160
- [80.0] # wab score
161
- ]
162
- generated_fc = vae.transform(1, new_demographics, demo_types)
163
- reconstructed_fc = vae.transform(X, demo_data, demo_types)
164
-
165
- # Visualize results
166
- print("Creating visualizations...")
167
- fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
168
-
169
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
  if __name__ == "__main__":
172
  import argparse
 
110
  'bsize': bsize
111
  })
112
 
113
+ try:
114
+ # Load data
115
+ print("Loading data...")
116
+ nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
117
+
118
+ # For SreekarB/OSFData, directly generate synthetic FC matrices
119
+ if data_dir == "SreekarB/OSFData" and use_hf_dataset:
120
+ print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
121
+ X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
122
+ # Check if we got FC matrices directly
123
+ elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
124
+ print("Using pre-computed FC matrices...")
125
+ # Convert list of FC matrices to numpy array
126
+ X = np.stack([np.array(fc) for fc in nii_files])
127
+ else:
128
+ # Prepare data by converting fMRI to FC matrices
129
+ print("Converting fMRI data to FC matrices...")
130
+ X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
131
+
132
+ # Print shapes and data types
133
+ print(f"X shape: {X.shape}, type: {type(X)}")
134
+ for i, d in enumerate(demo_data):
135
+ print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
136
+
137
+ # Train VAE and get data
138
+ print("Training VAE...")
139
+ vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
140
+
141
+ if save_model:
142
+ print("Saving model...")
143
+ os.makedirs('models', exist_ok=True)
144
+ torch.save(vae.state_dict(), 'models/vae_model.pth')
145
+
146
+ # Get latent representations
147
+ print("Getting latent representations...")
148
+ latents = vae.get_latents(X)
149
+
150
+ # Analyze results
151
+ print("Analyzing demographic relationships...")
152
+ demographics = {
153
+ 'age': demo_data[0],
154
+ 'months_post_onset': demo_data[2],
155
+ 'wab_aq': demo_data[3]
156
+ }
157
+ analysis_results = analyze_fc_patterns(latents, demographics)
158
+
159
+ # Generate new FC matrix
160
+ print("Generating new FC matrices...")
161
+ new_demographics = [
162
+ [60.0], # age
163
+ ['M'], # gender
164
+ [12.0], # months post onset
165
+ [80.0] # wab score
166
+ ]
167
+ generated_fc = vae.transform(1, new_demographics, demo_types)
168
+ reconstructed_fc = vae.transform(X, demo_data, demo_types)
169
+
170
+ # Visualize results
171
+ print("Creating visualizations...")
172
+ fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
173
+
174
+ return fig
175
+
176
+ except Exception as e:
177
+ import traceback
178
+ print(f"Error in run_fc_analysis: {str(e)}")
179
+ print(traceback.format_exc())
180
+
181
+ # Create a dummy figure with error message
182
+ import matplotlib.pyplot as plt
183
+ fig = plt.figure(figsize=(10, 6))
184
+ plt.text(0.5, 0.5, f"Error: {str(e)}",
185
+ horizontalalignment='center', verticalalignment='center',
186
+ fontsize=12, color='red')
187
+ plt.axis('off')
188
+ return fig
189
 
190
  if __name__ == "__main__":
191
  import argparse
requirements.txt CHANGED
@@ -7,6 +7,6 @@ scikit-learn>=0.24.2
7
  matplotlib>=3.4.2
8
  gradio>=2.0.0
9
  datasets>=1.11.0
10
- huggingface_hub>=0.12.0
11
  transformers>=4.15.0
12
 
 
7
  matplotlib>=3.4.2
8
  gradio>=2.0.0
9
  datasets>=1.11.0
10
+ huggingface_hub>=0.15.0
11
  transformers>=4.15.0
12