SreekarB commited on
Commit
b32645b
·
verified ·
1 Parent(s): e4a8a19

Upload 12 files

Browse files
Files changed (11) hide show
  1. README .md +70 -0
  2. app.py +93 -258
  3. config.py +3 -16
  4. data_preprocessing.py +483 -1138
  5. gitattributes +35 -0
  6. main.py +233 -459
  7. requirements.txt +6 -2
  8. test_hf_download.py +2 -2
  9. utils.py +61 -78
  10. vae_model.py +131 -336
  11. visualization.py +32 -509
README .md ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Aphasia fMRI VAE Analysis
3
+ emoji: 🧠
4
+ colorFrom: blue
5
+ colorTo: pink
6
+ sdk: gradio
7
+ sdk_version: 5.20.1
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Aphasia fMRI to FC Analysis using VAE
13
+
14
+ This demo performs functional connectivity analysis on fMRI data using a Variational Autoencoder (VAE) approach. It's designed to work with aphasia patient data, analyzing brain connectivity patterns and their relationship to demographic variables.
15
+
16
+ ## About the Model
17
+
18
+ This application implements a VAE model that:
19
+ 1. Takes functional connectivity (FC) matrices derived from fMRI data
20
+ 2. Learns a lower-dimensional latent representation of brain connectivity
21
+ 3. Conditions the generation process on demographic variables (age, sex, time post-stroke, WAB scores)
22
+ 4. Allows analysis of relationships between brain connectivity patterns and demographic variables
23
+
24
+ ## Dataset
25
+
26
+ This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
27
+
28
+ - NIfTI files in P01_rs.nii format containing fMRI data
29
+ - Demographic information directly in the dataset:
30
+ - ID: Subject identifier
31
+ - wab_aq: Aphasia quotient score (severity measure)
32
+ - age: Subject age
33
+ - mpo: Months post onset
34
+ - education: Years of education
35
+ - gender: Subject gender
36
+ - handedness: Subject handedness (ignored in this analysis)
37
+
38
+ The application processes the NIfTI files using the Power 264 atlas to create functional connectivity matrices that are then analyzed by the VAE model.
39
+
40
+ ## How to Use
41
+
42
+ 1. **Configure Parameters**:
43
+ - **Data Source**: By default, it uses the SreekarB/OSFData HuggingFace dataset
44
+ - **Latent Dimensions**: Controls the size of the latent space (default: 32)
45
+ - **Number of Epochs**: Training iterations (default: 200 for demo)
46
+ - **Batch Size**: Training batch size (default: 16)
47
+
48
+ 2. **Start Training**:
49
+ - Click the "Start Training" button to begin the analysis
50
+ - The training progress will be displayed in the Status area
51
+
52
+ 3. **View Results**:
53
+ - The VAE will learn latent representations of brain connectivity
54
+ - Results will show correlations between demographic variables and latent brain patterns
55
+ - The visualization shows original FC, reconstructed FC, and a new FC matrix generated from specific demographic values
56
+
57
+ ## Outputs
58
+
59
+ The application produces visualizations showing:
60
+ - Original FC matrix
61
+ - Reconstructed FC matrix
62
+ - Generated FC matrix (based on specific demographic inputs)
63
+ - Correlation plots between latent variables and demographic features
64
+
65
+ ## Technical Details
66
+
67
+ - Framework: PyTorch
68
+ - Interface: Gradio
69
+ - Dataset: HuggingFace Datasets API
70
+ - Analysis: Custom implementation of conditional VAE with demographic conditioning
app.py CHANGED
@@ -1,265 +1,100 @@
1
- """
2
- Simplified app for Huggingface Spaces.
3
- Provides a simple UI for VAE training and visualization.
4
- """
5
- import os
6
  import gradio as gr
7
- import numpy as np
8
- import pandas as pd
9
- import matplotlib.pyplot as plt
10
- from vae_model import DemoVAE, plot_learning_curves
11
- import time
12
- import tempfile
13
- import logging
14
-
15
- # Set up logging
16
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
17
- logger = logging.getLogger(__name__)
18
-
19
- # Make sure directories exist
20
- os.makedirs('models', exist_ok=True)
21
- os.makedirs('results', exist_ok=True)
22
-
23
- # Global app state
24
- app_state = {
25
- 'vae': None,
26
- 'latents': None,
27
- 'demographics': None,
28
- 'fc_data': None,
29
- 'vae_trained': False
30
- }
31
 
32
- # Function to convert vector to matrix for visualization
33
- def vector_to_matrix(vector, size=10):
34
- """Convert a vector to a square matrix for visualization"""
35
- matrix = np.zeros((size, size))
36
- idx = 0
37
- # Fill upper triangle and mirror
38
- for i in range(size):
39
- for j in range(i+1, size):
40
- matrix[i, j] = matrix[j, i] = vector[idx % len(vector)]
41
- idx += 1
42
- # Set diagonal to 1.0
43
- np.fill_diagonal(matrix, 1.0)
44
- return matrix
45
 
46
- def train_vae(fc_file, demo_file, epochs=20, latent_dim=16, batch_size=8, progress=gr.Progress()):
47
- """Train a VAE model on uploaded data"""
48
- try:
49
- # Reset state
50
- app_state['vae_trained'] = False
51
- app_state['vae'] = None
52
- app_state['latents'] = None
53
-
54
- # Ensure uploaded files exist
55
- if not fc_file or not os.path.exists(fc_file.name):
56
- return "Error: Missing FC matrix file", None, None
57
-
58
- # Load FC data
59
- try:
60
- progress(0.1, "Loading FC data...")
61
- if fc_file.name.endswith('.npy'):
62
- X = np.load(fc_file.name)
63
- elif fc_file.name.endswith('.csv'):
64
- X = pd.read_csv(fc_file.name).values
65
- else:
66
- # Try to interpret as text
67
- X = np.loadtxt(fc_file.name)
68
-
69
- logger.info(f"Loaded FC data with shape: {X.shape}")
70
- app_state['fc_data'] = X
71
- except Exception as e:
72
- logger.error(f"Error loading FC data: {e}")
73
- return f"Error loading FC data: {str(e)}", None, None
74
-
75
- # Load demographic data if provided
76
- try:
77
- progress(0.2, "Loading demographic data...")
78
- if demo_file and os.path.exists(demo_file.name):
79
- demo_df = pd.read_csv(demo_file.name)
80
- logger.info(f"Loaded demographics with shape: {demo_df.shape}")
 
 
 
 
 
81
 
82
- # Try to extract standard demographics
83
- demographics = []
 
84
 
85
- # Age
86
- if 'age' in demo_df.columns:
87
- age = demo_df['age'].values
88
- elif 'age_at_stroke' in demo_df.columns:
89
- age = demo_df['age_at_stroke'].values
90
- else:
91
- age = np.random.normal(60, 10, len(X))
92
- logger.warning("Age column not found, using synthetic data")
93
- demographics.append(age)
94
-
95
- # Sex
96
- if 'sex' in demo_df.columns:
97
- sex = demo_df['sex'].values
98
- elif 'gender' in demo_df.columns:
99
- sex = demo_df['gender'].values
100
- else:
101
- sex = np.random.choice(['M', 'F'], len(X))
102
- logger.warning("Sex column not found, using synthetic data")
103
- demographics.append(sex)
104
-
105
- # Months post stroke
106
- if 'months_post_stroke' in demo_df.columns:
107
- mps = demo_df['months_post_stroke'].values
108
- elif 'mpo' in demo_df.columns:
109
- mps = demo_df['mpo'].values
110
- else:
111
- mps = np.random.normal(24, 12, len(X))
112
- logger.warning("Months post stroke column not found, using synthetic data")
113
- demographics.append(mps)
114
-
115
- # WAB score
116
- if 'wab_score' in demo_df.columns:
117
- wab = demo_df['wab_score'].values
118
- elif 'wab_aq' in demo_df.columns:
119
- wab = demo_df['wab_aq'].values
120
- else:
121
- wab = np.random.normal(65, 15, len(X))
122
- logger.warning("WAB score column not found, using synthetic data")
123
- demographics.append(wab)
124
-
125
- else:
126
- logger.info("No demographics file provided, using synthetic data")
127
- demographics = [
128
- np.random.normal(60, 10, len(X)), # age
129
- np.random.choice(['M', 'F'], len(X)), # sex
130
- np.random.normal(24, 12, len(X)), # months post stroke
131
- np.random.normal(65, 15, len(X)) # WAB score
132
- ]
133
-
134
- app_state['demographics'] = demographics
135
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
136
-
137
- except Exception as e:
138
- logger.error(f"Error processing demographics: {e}")
139
- return f"Error processing demographics: {str(e)}", None, None
140
-
141
- # Initialize model
142
- progress(0.3, "Initializing model...")
143
- model = DemoVAE(nepochs=epochs, batch_size=batch_size, latent_dim=latent_dim)
144
-
145
- # Train model
146
- progress(0.4, "Training VAE model...")
147
- train_losses, val_losses = model.fit(X, demographics, demo_types)
148
-
149
- # Save model
150
- progress(0.7, "Saving model...")
151
- model.save('models/vae_model.pt')
152
- app_state['vae'] = model
153
- app_state['vae_trained'] = True
154
-
155
- # Generate latent representations
156
- progress(0.8, "Generating latent representations...")
157
- latents = model.get_latents(X)
158
- app_state['latents'] = latents
159
- np.save('results/latents.npy', latents)
160
-
161
- # Create visualizations
162
- progress(0.9, "Creating visualizations...")
163
-
164
- # Learning curves
165
- learning_fig = plot_learning_curves(model.train_losses, model.val_losses)
166
- learning_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
167
- learning_fig.savefig(learning_img.name)
168
- plt.close(learning_fig)
169
-
170
- # FC visualization
171
- progress(0.95, "Creating FC visualizations...")
172
- reconstructed = model.transform(X, demographics, demo_types)
173
- np.save('results/reconstructed.npy', reconstructed)
174
-
175
- generated = model.transform(1, [d[0] for d in demographics], demo_types)
176
- np.save('results/generated.npy', generated)
177
-
178
- fc_fig, axes = plt.subplots(1, 3, figsize=(15, 5))
179
- original_matrix = vector_to_matrix(X[0])
180
- recon_matrix = vector_to_matrix(reconstructed[0])
181
- gen_matrix = vector_to_matrix(generated[0])
182
-
183
- # Plot matrices
184
- titles = ['Original', 'Reconstructed', 'Generated']
185
- for i, matrix in enumerate([original_matrix, recon_matrix, gen_matrix]):
186
- im = axes[i].imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
187
- axes[i].set_title(titles[i])
188
- axes[i].axis('off')
189
-
190
- fc_fig.subplots_adjust(right=0.8)
191
- cbar_ax = fc_fig.add_axes([0.85, 0.15, 0.05, 0.7])
192
- fc_fig.colorbar(im, cax=cbar_ax)
193
-
194
- fc_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
195
- fc_fig.savefig(fc_img.name)
196
- plt.close(fc_fig)
197
-
198
- progress(1.0, "Training complete!")
199
- return "Training completed successfully!", learning_img.name, fc_img.name
200
-
201
- except Exception as e:
202
- logger.error(f"Error in VAE training: {str(e)}")
203
- return f"Error: {str(e)}", None, None
204
-
205
- def create_ui():
206
- """Create the Gradio UI"""
207
- with gr.Blocks(title="FC Matrix VAE Demo") as app:
208
- gr.Markdown("# Functional Connectivity VAE Demo")
209
- gr.Markdown("Upload FC matrices and train a VAE model to analyze them.")
210
-
211
- with gr.Tab("Train VAE"):
212
- with gr.Row():
213
- with gr.Column():
214
- fc_file = gr.File(label="FC Matrix File (CSV or NPY)")
215
- demo_file = gr.File(label="Demographics File (CSV, optional)")
216
-
217
- with gr.Row():
218
- epochs = gr.Slider(5, 100, 20, step=5, label="Training Epochs")
219
- latent_dim = gr.Slider(8, 64, 16, step=4, label="Latent Dimension")
220
- batch_size = gr.Slider(4, 32, 8, step=4, label="Batch Size")
221
-
222
- train_btn = gr.Button("Train VAE Model")
223
- status = gr.Textbox(label="Status")
224
-
225
- with gr.Column():
226
- learning_plot = gr.Image(label="Learning Curves")
227
- fc_plot = gr.Image(label="FC Matrices")
228
-
229
- train_btn.click(
230
- fn=train_vae,
231
- inputs=[fc_file, demo_file, epochs, latent_dim, batch_size],
232
- outputs=[status, learning_plot, fc_plot]
233
- )
234
-
235
- with gr.Tab("About"):
236
- gr.Markdown("""
237
- ## About this App
238
-
239
- This app trains a Variational Autoencoder (VAE) on functional connectivity (FC) matrices.
240
-
241
- ### Features:
242
- * Load FC matrices from CSV or NPY files
243
- * Incorporate demographic data (age, sex, etc.)
244
- * Visualize learning curves
245
- * Compare original, reconstructed and generated FC matrices
246
-
247
- ### Input Format:
248
- * FC matrices should be provided as vectors (flattened upper triangular portion of symmetric matrices)
249
- * Demographics file should be CSV with columns for age, sex, months_post_stroke, and wab_score
250
-
251
- ### Model Architecture:
252
- * Simple feedforward VAE with demographic conditioning
253
- * Latent space can be specified (default 16 dimensions)
254
- * MSE reconstruction loss
255
- """)
256
-
257
- return app
258
 
259
- # For local testing
260
  if __name__ == "__main__":
261
- app = create_ui()
262
- app.launch()
263
-
264
- # For Huggingface Spaces
265
- demo = create_ui()
 
 
 
 
 
 
1
  import gradio as gr
2
+ from main import run_fc_analysis
3
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
+ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
6
+ """Run the full VAE analysis pipeline"""
7
+ fig = run_fc_analysis(
8
+ data_dir=data_source,
9
+ demographic_file=None, # We're now getting demographics directly from the dataset
10
+ latent_dim=latent_dim,
11
+ nepochs=nepochs,
12
+ bsize=bsize,
13
+ save_model=True,
14
+ use_hf_dataset=use_hf_dataset
15
+ )
16
+ return fig, "Analysis complete! VAE model has been trained and demographic relationships analyzed."
 
17
 
18
+ def create_interface():
19
+ with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
20
+ gr.Markdown("""
21
+ # Aphasia fMRI to FC Analysis using VAE
22
+
23
+ This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
24
+
25
+ ## Dataset Information
26
+ By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
27
+ - ID: Subject identifier
28
+ - wab_aq: Aphasia severity score
29
+ - age: Age of the subject
30
+ - mpo: Months post onset
31
+ - education: Years of education
32
+ - gender: Subject gender
33
+ - handedness: Subject handedness (ignored in the analysis)
34
+ """)
35
+
36
+ with gr.Row():
37
+ with gr.Column(scale=1):
38
+ # Configuration parameters
39
+ data_source = gr.Textbox(
40
+ label="Data Source (HF Dataset ID or Local Directory)",
41
+ value="SreekarB/OSFData"
42
+ )
43
+ latent_dim = gr.Slider(
44
+ minimum=8, maximum=64, step=8,
45
+ label="Latent Dimensions", value=32
46
+ )
47
+ nepochs = gr.Slider(
48
+ minimum=100, maximum=5000, step=100,
49
+ label="Number of Epochs", value=200 # Reduced for faster demos
50
+ )
51
+ bsize = gr.Slider(
52
+ minimum=8, maximum=64, step=8,
53
+ label="Batch Size", value=16
54
+ )
55
+ use_hf_dataset = gr.Checkbox(
56
+ label="Use HuggingFace Dataset", value=True
57
+ )
58
 
59
+ # Training button
60
+ train_button = gr.Button("Start Training", variant="primary")
61
+ status_text = gr.Textbox(label="Status", value="Ready to start training")
62
 
63
+ with gr.Column(scale=2):
64
+ # Output plot
65
+ output_plot = gr.Plot(label="Analysis Results")
66
+
67
+ # Link the training button to the analysis function
68
+ train_button.click(
69
+ fn=gradio_fc_analysis,
70
+ inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
71
+ outputs=[output_plot, status_text]
72
+ )
73
+
74
+ # Add examples
75
+ gr.Examples(
76
+ examples=[
77
+ ["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
78
+ ],
79
+ inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
80
+ )
81
+
82
+ # Add explanation of the workflow
83
+ gr.Markdown("""
84
+ ## How this works
85
+
86
+ 1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
87
+ 2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
88
+ 3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
89
+ 4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
90
+ 5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
91
+
92
+ Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
93
+ """)
94
+
95
+ return iface
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
96
 
 
97
  if __name__ == "__main__":
98
+ iface = create_interface()
99
+ iface.launch(share=True)
100
+
 
 
config.py CHANGED
@@ -1,8 +1,8 @@
1
  # Model configuration
2
  MODEL_CONFIG = {
3
  'latent_dim': 32,
4
- 'nepochs': 100, # Changed from 1000 to 100 for faster testing
5
- 'bsize': 5, # Changed from 16 to 5 for small sample sizes
6
  'loss_rec_mult': 100,
7
  'loss_decor_mult': 10,
8
  'lr': 1e-4
@@ -18,20 +18,7 @@ PREPROCESS_CONFIG = {
18
 
19
  # Dataset configuration
20
  DATASET_CONFIG = {
21
- 'name': 'SreekarB/OSFData1',
22
  'split': 'train'
23
  }
24
 
25
- # Prediction configuration
26
- PREDICTION_CONFIG = {
27
- 'n_estimators': 100,
28
- 'max_depth': None,
29
- 'cv_folds': 5,
30
- 'default_outcome': 'wab_aq',
31
- 'save_path': 'results/treatment_predictor.joblib',
32
- 'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
33
- 'use_synthetic_nifti': False, # Set to False to use only real NIfTI data
34
- 'use_synthetic_fc': False, # Set to False to use only real FC matrices
35
- 'strict_real_data': True, # Set to True to strictly use real data only
36
- 'no_mock_data': True # Set to True to prevent using any mock or synthetic data
37
- }
 
1
  # Model configuration
2
  MODEL_CONFIG = {
3
  'latent_dim': 32,
4
+ 'nepochs': 1000,
5
+ 'bsize': 16,
6
  'loss_rec_mult': 100,
7
  'loss_decor_mult': 10,
8
  'lr': 1e-4
 
18
 
19
  # Dataset configuration
20
  DATASET_CONFIG = {
21
+ 'name': 'SreekarB/OSFData',
22
  'split': 'train'
23
  }
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
data_preprocessing.py CHANGED
@@ -1,1212 +1,557 @@
1
  import numpy as np
2
  import pandas as pd
3
- import os
4
- import json
5
- import pickle
6
- import hashlib
7
- import warnings
8
- import re
9
- from nilearn import input_data, connectome, datasets
10
  from nilearn.image import load_img
11
  import nibabel as nib
12
- from pathlib import Path
13
- from config import PREPROCESS_CONFIG, PREDICTION_CONFIG
14
-
15
- # Create cache directory if it doesn't exist
16
- CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache')
17
- os.makedirs(CACHE_DIR, exist_ok=True)
18
- os.makedirs(os.path.join(CACHE_DIR, 'time_series'), exist_ok=True)
19
- os.makedirs(os.path.join(CACHE_DIR, 'fc_matrices'), exist_ok=True)
20
- os.makedirs(os.path.join(CACHE_DIR, 'latents'), exist_ok=True)
21
- os.makedirs(os.path.join(CACHE_DIR, 'maskers'), exist_ok=True)
22
- os.makedirs(os.path.join(CACHE_DIR, 'atlas'), exist_ok=True)
23
-
24
- # Cache the atlas coordinates globally for efficient access
25
- REGIONAL_COORDS = None
26
-
27
- # Initialize the Power atlas coordinates
28
- def _init_atlas_coords():
29
- global REGIONAL_COORDS
30
- if REGIONAL_COORDS is None:
31
- try:
32
- atlas_path = os.path.join(CACHE_DIR, 'atlas', 'power_2011_coords.npy')
33
- if os.path.exists(atlas_path):
34
- REGIONAL_COORDS = np.load(atlas_path)
35
- else:
36
- from nilearn import datasets
37
- power = datasets.fetch_coords_power_2011()
38
- REGIONAL_COORDS = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
39
- # Save for future use
40
- np.save(atlas_path, REGIONAL_COORDS)
41
- print(f"Initialized Power atlas coordinates with {len(REGIONAL_COORDS)} ROIs")
42
- except Exception as e:
43
- print(f"Error initializing atlas coordinates: {e}")
44
- # Fallback to a simple set of coordinates if needed
45
- REGIONAL_COORDS = np.array([
46
- [0, 0, 0], [10, 0, 0], [0, 10, 0], [0, 0, 10]
47
- ])
48
- print("WARNING: Using fallback coordinates due to initialization error")
49
-
50
- # Initialize atlas coordinates at module load time
51
- _init_atlas_coords()
52
-
53
- def get_file_hash(file_path):
54
- """Generate a hash for a file to use as a cache key"""
55
- try:
56
- hasher = hashlib.md5()
57
- with open(file_path, 'rb') as f:
58
- # Read in chunks to handle large files
59
- for chunk in iter(lambda: f.read(4096), b""):
60
- hasher.update(chunk)
61
- return hasher.hexdigest()
62
- except Exception as e:
63
- print(f"Error hashing file {file_path}: {e}")
64
- # Fallback to filename-based hash if file reading fails
65
- return hashlib.md5(os.path.basename(file_path).encode()).hexdigest()
66
 
67
- def get_cached_atlas_coords(atlas_name="power_2011", use_cache=True):
68
  """
69
- Get atlas coordinates, using cache if available
70
 
71
- Args:
72
- atlas_name: Name of the atlas (currently only power_2011 is supported)
73
- use_cache: Whether to use/create cache
74
-
75
- Returns:
76
- coords: Array of coordinates for the atlas
77
- """
78
- global REGIONAL_COORDS
79
 
80
- # If we have already initialized the coordinates, use them
81
- if REGIONAL_COORDS is not None:
82
- return REGIONAL_COORDS
83
-
84
- # Otherwise, use the initialization function
85
- _init_atlas_coords()
86
- return REGIONAL_COORDS
87
-
88
- def get_cached_masker(radius, use_cache=True):
89
- """
90
- Get a NiftiSpheresMasker with the specified radius, using cache if available
91
-
92
- Args:
93
- radius: Sphere radius in mm
94
- use_cache: Whether to use/create cache
95
-
96
  Returns:
97
- masker: NiftiSpheresMasker object
 
 
98
  """
99
- if not use_cache:
100
- return None
101
-
102
- # Create a cache key for this masker configuration
103
- # We use radius and other PREPROCESS_CONFIG values that affect the masker
104
- config_str = (f"radius={radius},"
105
- f"tr={PREPROCESS_CONFIG['t_r']},"
106
- f"high_pass={PREPROCESS_CONFIG['high_pass']},"
107
- f"low_pass={PREPROCESS_CONFIG['low_pass']}")
108
-
109
- masker_key = hashlib.md5(config_str.encode()).hexdigest()
110
- masker_path = os.path.join(CACHE_DIR, 'maskers', f"{masker_key}.pkl")
111
 
112
- # Check if we have a cached masker
113
- if os.path.exists(masker_path):
 
 
114
  try:
115
- print(f"Loading cached masker for radius {radius}mm")
116
- with open(masker_path, 'rb') as f:
117
- masker = pickle.load(f)
118
- print(f"Successfully loaded cached masker for radius {radius}mm")
119
- return masker
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
120
  except Exception as e:
121
- print(f"Error loading cached masker: {e}, creating new one")
122
-
123
- # No valid cache, return None to indicate a new masker should be created
124
- return None
125
-
126
- def process_single_fmri(fmri_file, allow_synthetic=False, use_cache=True, try_preprocessing=True):
127
- """
128
- Process a single fMRI file to FC matrix
129
-
130
- Args:
131
- fmri_file: Path to the fMRI .nii or .nii.gz file
132
- allow_synthetic: If True, generate synthetic FC matrix on error (disabled by default)
133
- use_cache: If True, try to load cached data instead of reprocessing
 
 
 
 
 
 
 
 
134
 
135
- Returns:
136
- fc_triu: Upper triangular FC matrix values
137
- """
138
- print(f"Processing fMRI file: {fmri_file}")
139
-
140
- # Make sure os is imported to avoid reference error
141
- import os
142
-
143
- # Check if cached FC matrix exists
144
- if use_cache:
145
- file_hash = get_file_hash(fmri_file)
146
- fc_cache_path = os.path.join(CACHE_DIR, 'fc_matrices', f"{file_hash}.npy")
147
 
148
- if os.path.exists(fc_cache_path):
149
- print(f"Loading cached FC matrix for {os.path.basename(fmri_file)}")
150
- try:
151
- fc_triu = np.load(fc_cache_path)
152
- print(f"Successfully loaded cached FC matrix, shape: {fc_triu.shape}")
153
- return fc_triu
154
- except Exception as e:
155
- print(f"Error loading cached FC matrix: {e}, recalculating...")
156
-
157
- # Use Power 264 atlas with caching
158
- coords = get_cached_atlas_coords(use_cache=use_cache)
159
-
160
- # FIRST: Try to normalize the NIfTI file to MNI space for better compatibility
161
- try:
162
- print("First attempting to register NIfTI file to MNI space...")
163
- from nilearn import image
164
  import tempfile
 
 
165
 
166
- # Load the original image
167
- orig_img = load_img(fmri_file)
168
-
169
- # Check if it's a 4D file with sufficient time points
170
- if len(orig_img.shape) < 4:
171
- print("Cannot preprocess: Not a 4D file (no time dimension)")
172
- elif orig_img.shape[3] < 20:
173
- print(f"Warning: Very few time points ({orig_img.shape[3]}), results may be unreliable")
174
-
175
- # Create a preprocessing directory
176
- preproc_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'preproc')
177
- os.makedirs(preproc_dir, exist_ok=True)
178
-
179
- # Generate a filename for the preprocessed file
180
- basename = os.path.basename(fmri_file)
181
- preproc_file = os.path.join(preproc_dir, f"mni_registered_{basename}")
182
 
183
- print(f"MNI registration steps for {basename}:")
184
-
185
- # Step 1: Get the MNI152 template for reference
186
- from nilearn.datasets import load_mni152_template
187
- template = load_mni152_template()
188
- print("1. Loaded MNI152 template as reference")
189
-
190
- # Step 2: For 4D data, we'll work with the mean image for registration
191
- if len(orig_img.shape) == 4:
192
- mean_img = image.mean_img(orig_img)
193
- print("2. Extracted mean image from 4D volume for registration")
194
- else:
195
- mean_img = orig_img
196
-
197
- # Step 3: Register to MNI space (target resolution of 3mm)
198
- print("3. Registering to MNI space with 3mm resolution...")
199
- reg_img = image.resample_to_img(orig_img, template, interpolation='linear')
200
- print(f" Original dimensions: {orig_img.shape}, New dimensions: {reg_img.shape}")
201
-
202
- # Step 4: Save the preprocessed image
203
- print(f"4. Saving MNI-registered image to {preproc_file}...")
204
- reg_img.to_filename(preproc_file)
205
- print("MNI registration complete")
206
-
207
- # Now try to process this MNI-registered file
208
- mni_fmri_file = preproc_file
209
- except Exception as reg_err:
210
- print(f"Error during MNI registration: {reg_err}")
211
- print("Continuing with original NIfTI file")
212
- mni_fmri_file = fmri_file
213
-
214
- # Try different atlas radiuses if the default one has issues
215
- # Include more radius options and make sure they're unique and sorted
216
- radius_options = list(set([
217
- PREPROCESS_CONFIG['radius'],
218
- 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, # Add more options
219
- 4, 3 # Smaller options as last resort
220
- ]))
221
- radius_options.sort() # Sort for consistent attempts
222
- print(f"Will try these radius options in order: {radius_options}")
223
-
224
- # Try each radius option
225
- for radius in radius_options:
226
  try:
227
- print(f"Trying with radius {radius}mm...")
228
-
229
- # Check if we have cached time series for this file and radius
230
- if use_cache:
231
- ts_cache_key = f"{file_hash}_r{radius}"
232
- ts_cache_path = os.path.join(CACHE_DIR, 'time_series', f"{ts_cache_key}.npy")
 
 
 
 
 
 
 
 
233
 
234
- if os.path.exists(ts_cache_path):
235
- print(f"Loading cached time series for radius {radius}mm")
236
- try:
237
- time_series = np.load(ts_cache_path)
238
- print(f"Successfully loaded cached time series, shape: {time_series.shape}")
239
- except Exception as e:
240
- print(f"Error loading cached time series: {e}, recalculating...")
241
- time_series = None
242
- else:
243
- time_series = None
244
- else:
245
- time_series = None
246
-
247
- # If no cached time series, calculate it
248
- if time_series is None:
249
- # Try to get a cached masker first
250
- masker = get_cached_masker(radius, use_cache)
251
-
252
- # If no cached masker, create a new one
253
- if masker is None:
254
- print(f"Creating new masker with radius {radius}mm")
255
- # Create masker with allow_empty=True to handle empty spheres
256
- masker = input_data.NiftiSpheresMasker(
257
- coords,
258
- radius=radius,
259
- standardize=True,
260
- memory='nilearn_cache',
261
- memory_level=1,
262
- verbose=1, # Increase verbosity for debugging
263
- detrend=True,
264
- low_pass=PREPROCESS_CONFIG['low_pass'],
265
- high_pass=PREPROCESS_CONFIG['high_pass'],
266
- t_r=PREPROCESS_CONFIG['t_r'],
267
- allow_empty=True # Allow empty spheres
268
- )
269
 
270
- # Cache the masker if caching is enabled
271
- if use_cache:
 
 
 
 
 
 
 
 
 
272
  try:
273
- config_str = (f"radius={radius},"
274
- f"tr={PREPROCESS_CONFIG['t_r']},"
275
- f"high_pass={PREPROCESS_CONFIG['high_pass']},"
276
- f"low_pass={PREPROCESS_CONFIG['low_pass']}")
277
- masker_key = hashlib.md5(config_str.encode()).hexdigest()
278
- masker_path = os.path.join(CACHE_DIR, 'maskers', f"{masker_key}.pkl")
 
 
 
 
 
279
 
280
- with open(masker_path, 'wb') as f:
281
- pickle.dump(masker, f)
282
- print(f"Saved masker to cache: {masker_path}")
283
- except Exception as e:
284
- print(f"Error saving masker to cache: {e}")
285
-
286
- # Load and process fMRI - use the MNI-registered file if available
287
- print(f"Loading NIfTI file: {mni_fmri_file}...")
288
- fmri_img = load_img(mni_fmri_file)
289
- print(f"NIfTI file loaded, shape: {fmri_img.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
290
 
291
- # Check for insufficient time points
292
- if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 20: # Assuming we need at least 20 time points
293
- print(f"Warning: {mni_fmri_file} has insufficient time points: {fmri_img.shape}")
294
- continue
295
-
296
- # Transform to time series with explicit warning handling
297
- print(f"Extracting time series...")
298
  try:
299
- # Explicitly handle warnings about empty spheres
300
- with warnings.catch_warnings():
301
- warnings.filterwarnings('ignore', message='.*empty.*')
302
- time_series = masker.fit_transform(fmri_img)
303
- except Exception as e:
304
- if "empty" in str(e):
305
- print(f"Warning: Some spheres are empty in {mni_fmri_file}. Using a different sphere radius.")
306
 
307
- # Extract the list of empty spheres for logging
308
- empty_spheres = re.findall(r"\[(.*?)\]", str(e))
309
- if empty_spheres:
310
- print(f"Empty spheres: {empty_spheres[0]}")
311
 
312
- # Continue to next radius option
313
- continue
314
- else:
315
- print(f"Unknown error in masker: {e}")
316
- continue # Skip this radius if there's any other error
317
-
318
- print(f"Time series extracted, shape: {time_series.shape}")
319
-
320
- # Cache the time series if successful
321
- if use_cache and time_series is not None:
322
- try:
323
- np.save(ts_cache_path, time_series)
324
- print(f"Saved time series to cache: {ts_cache_path}")
 
 
 
 
 
 
 
 
 
325
  except Exception as e:
326
- print(f"Error saving time series to cache: {e}")
327
-
328
- print(f"Time series processed, shape: {time_series.shape}")
329
-
330
- # Validate time series data
331
- if np.isnan(time_series).any() or np.isinf(time_series).any():
332
- print(f"Warning: {mni_fmri_file} contains NaN or Inf values after masking")
333
- continue
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
334
 
335
- # Check if any ROIs couldn't be extracted (column of zeros)
336
- zero_cols = np.where(np.all(np.abs(time_series) < 1e-10, axis=0))[0]
337
- if len(zero_cols) > 0:
338
- print(f"Warning: {len(zero_cols)} ROIs have zero/empty time series")
339
-
340
- # If too many are empty (>50%), try the next radius
341
- if len(zero_cols) > 0.5 * time_series.shape[1]:
342
- print(f"Too many empty ROIs ({len(zero_cols)}), trying different radius")
343
- continue
344
 
345
- # Replace empty ROIs with the mean of non-empty ROIs
346
- non_zero_cols = [i for i in range(time_series.shape[1]) if i not in zero_cols]
347
- if non_zero_cols:
348
- mean_timeseries = np.mean(time_series[:, non_zero_cols], axis=1)
349
- for col in zero_cols:
350
- # Add very small variation to the mean
351
- time_series[:, col] = mean_timeseries + np.random.randn(time_series.shape[0]) * 1e-5
352
-
353
- # Compute FC matrix
354
- print(f"Computing FC matrix...")
355
- correlation_measure = connectome.ConnectivityMeasure(
356
- kind='correlation',
357
- vectorize=False,
358
- discard_diagonal=False
359
- )
360
-
361
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
362
- print(f"FC matrix computed, shape: {fc_matrix.shape}")
363
-
364
- # Check for NaN values in the FC matrix
365
- if np.any(np.isnan(fc_matrix)):
366
- print(f"Warning: NaN values in FC matrix, replacing with zeros")
367
- fc_matrix = np.nan_to_num(fc_matrix)
368
-
369
- # Get upper triangular part
370
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
371
- fc_triu = fc_matrix[triu_indices]
372
-
373
- # Fisher z-transform
374
- fc_triu = np.arctanh(np.clip(fc_triu, -0.99, 0.99)) # Clip to avoid infinite values
375
-
376
- print(f"Processing complete. FC features shape: {fc_triu.shape}")
377
-
378
- # Cache the successful FC matrix
379
- if use_cache:
380
  try:
381
- fc_cache_path = os.path.join(CACHE_DIR, 'fc_matrices', f"{file_hash}.npy")
382
- np.save(fc_cache_path, fc_triu)
383
- print(f"Saved FC matrix to cache: {fc_cache_path}")
384
- except Exception as e:
385
- print(f"Error saving FC matrix to cache: {e}")
 
 
 
386
 
387
- return fc_triu
388
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
389
  except Exception as e:
390
- print(f"Error with radius {radius}mm: {e}")
391
- # Continue to next radius option
392
- continue
393
-
394
- # If we get here, all radius options failed
395
- print(f"Failed to process {mni_fmri_file} with all radius options")
396
-
397
- # If preprocessing is enabled, try more advanced preprocessing
398
- if try_preprocessing:
399
- try:
400
- print(f"Attempting advanced preprocessing of {fmri_file}...")
401
-
402
- # Import nilearn preprocessing
403
- from nilearn import image
404
- import os # Ensure os is imported again here
405
-
406
- # Load the image
407
- orig_img = load_img(fmri_file)
408
-
409
- # Check if it's a 4D file with sufficient time points
410
- if len(orig_img.shape) < 4:
411
- print("Cannot preprocess: Not a 4D file (no time dimension)")
412
- elif orig_img.shape[3] < 30:
413
- print(f"Warning: Very few time points ({orig_img.shape[3]}), results may be unreliable")
414
-
415
- # Create a preprocessing directory
416
- preproc_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'preproc')
417
- os.makedirs(preproc_dir, exist_ok=True)
418
-
419
- # Generate a filename for the preprocessed file
420
- basename = os.path.basename(fmri_file)
421
- preproc_file = os.path.join(preproc_dir, f"preproc_{basename}")
422
-
423
- print(f"Advanced preprocessing steps for {basename}:")
424
-
425
- # Step 1: Get MNI template and try a more aggressive registration
426
- from nilearn.datasets import load_mni152_template
427
- template = load_mni152_template()
428
- print("1. Loading MNI152 template as target space")
429
-
430
- # Step 2: Resampling to MNI space with higher resolution (2mm)
431
- print("2. Resampling to 2mm isotropic voxels in MNI space...")
432
- try:
433
- # Try resampling to MNI template
434
- resampled_img = image.resample_to_img(
435
- orig_img,
436
- template,
437
- interpolation='linear'
438
- )
439
- except Exception as resample_err:
440
- print(f"Error resampling to template: {resample_err}")
441
- # Fallback to standard resampling with affine
442
- resampled_img = image.resample_img(
443
- orig_img,
444
- target_affine=np.diag([2, 2, 2, 1])
445
- )
446
-
447
- # Step 3: Motion correction and filtering
448
- print("3. Applying robust temporal filtering...")
449
- filtered_img = image.clean_img(
450
- resampled_img,
451
  detrend=True,
452
- standardize='zscore',
453
  low_pass=0.1,
454
  high_pass=0.01,
455
- t_r=2.0 # Assuming TR=2s if not specified
456
  )
457
 
458
- # Step 4: Spatial smoothing
459
- print("4. Applying spatial smoothing...")
460
- smoothed_img = image.smooth_img(filtered_img, fwhm=6)
461
-
462
- # Save the preprocessed image
463
- print(f"5. Saving preprocessed image to {preproc_file}...")
464
- smoothed_img.to_filename(preproc_file)
465
-
466
- print("Advanced preprocessing complete, attempting extraction with processed file...")
467
-
468
- # SPECIAL APPROACH: Try using different coordinates
469
- # If the atlas doesn't align with the image, we can try to adjust the coordinates
470
- print("Trying with transformed coordinates...")
471
-
472
- # Load the preprocessed image to get its dimensions
473
- preproc_img = load_img(preproc_file)
474
- preproc_shape = preproc_img.shape
475
- preproc_affine = preproc_img.affine
476
-
477
- # Get original Power coordinates
478
- orig_coords = get_cached_atlas_coords()
479
-
480
- # Calculate the coordinate ranges in the preprocessed image
481
- img_mins = [0, 0, 0]
482
- img_maxs = [preproc_shape[0]-1, preproc_shape[1]-1, preproc_shape[2]-1]
483
-
484
- # Convert to world coordinates
485
- from nibabel.affines import apply_affine
486
- world_mins = apply_affine(preproc_affine, img_mins)
487
- world_maxs = apply_affine(preproc_affine, img_maxs)
488
-
489
- # Scale the coordinates to fit within the image bounds
490
- coord_mins = orig_coords.min(axis=0)
491
- coord_maxs = orig_coords.max(axis=0)
492
 
493
- # Calculate scale factors
494
- scale_x = (world_maxs[0] - world_mins[0]) / (coord_maxs[0] - coord_mins[0])
495
- scale_y = (world_maxs[1] - world_mins[1]) / (coord_maxs[1] - coord_mins[1])
496
- scale_z = (world_maxs[2] - world_mins[2]) / (coord_maxs[2] - coord_mins[2])
497
-
498
- # Calculate offsets
499
- offset_x = world_mins[0] - coord_mins[0] * scale_x
500
- offset_y = world_mins[1] - coord_mins[1] * scale_y
501
- offset_z = world_mins[2] - coord_mins[2] * scale_z
502
-
503
- # Apply transformation to coordinates
504
- adjusted_coords = np.copy(orig_coords)
505
- adjusted_coords[:, 0] = orig_coords[:, 0] * scale_x + offset_x
506
- adjusted_coords[:, 1] = orig_coords[:, 1] * scale_y + offset_y
507
- adjusted_coords[:, 2] = orig_coords[:, 2] * scale_z + offset_z
508
-
509
- print(f"Adjusted coordinates to fit within image bounds")
510
- print(f"Original coord range: X({coord_mins[0]:.1f}-{coord_maxs[0]:.1f}), Y({coord_mins[1]:.1f}-{coord_maxs[1]:.1f}), Z({coord_mins[2]:.1f}-{coord_maxs[2]:.1f})")
511
- print(f"Adjusted coord range: X({adjusted_coords[:,0].min():.1f}-{adjusted_coords[:,0].max():.1f}), Y({adjusted_coords[:,1].min():.1f}-{adjusted_coords[:,1].max():.1f}), Z({adjusted_coords[:,2].min():.1f}-{adjusted_coords[:,2].max():.1f})")
512
-
513
- # Try to process with adjusted coordinates
514
- for radius in radius_options:
515
  try:
516
- print(f"Trying with adjusted coordinates and radius {radius}mm...")
 
517
 
518
- # Create spherical masker with adjusted coordinates
519
- masker = input_data.NiftiSpheresMasker(
520
- seeds=adjusted_coords,
521
- radius=radius,
522
- allow_overlap=True,
523
- standardize=True,
524
- memory='nilearn_cache',
525
- memory_level=1,
526
- verbose=1,
527
- allow_empty=True # Allow empty spheres
528
- )
529
-
530
- # Extract time series from preprocessed file
531
- time_series = masker.fit_transform(preproc_file)
532
-
533
- # Check for too many empty ROIs
534
- zero_cols = np.where(np.all(np.abs(time_series) < 1e-10, axis=0))[0]
535
- if len(zero_cols) > 0.5 * time_series.shape[1]:
536
- print(f"Too many empty ROIs ({len(zero_cols)}), trying different radius")
537
  continue
 
 
538
 
539
- # Create correlation matrix
540
- correlation_measure = connectome.ConnectivityMeasure(kind='correlation')
541
- correlation_matrix = correlation_measure.fit_transform([time_series])[0]
 
 
542
 
543
- # Convert to z-scores (Fisher's transform)
544
- z_matrix = np.arctanh(np.clip(correlation_matrix, -0.99, 0.99))
 
 
 
545
 
546
- # Replace infinite values
547
- np.fill_diagonal(z_matrix, 0)
548
 
549
- # Extract upper triangle (excluding diagonal)
550
- n_rois = len(adjusted_coords)
551
- triu_indices = np.triu_indices(n_rois, k=1)
552
- fc_triu = z_matrix[triu_indices]
 
 
 
553
 
554
- # Check for NaN values
555
- if np.any(np.isnan(fc_triu)):
556
- raise ValueError(f"NaN values found in FC matrix with radius {radius}")
 
557
 
558
- # Successfully processed
559
- print(f"Successfully processed with adjusted coordinates and radius {radius}mm")
 
560
 
561
- # Cache the successful FC matrix
562
- if use_cache:
563
- try:
564
- fc_cache_path = os.path.join(CACHE_DIR, 'fc_matrices', f"{file_hash}.npy")
565
- np.save(fc_cache_path, fc_triu)
566
- print(f"Saved FC matrix to cache: {fc_cache_path}")
567
- except Exception as e:
568
- print(f"Error saving FC matrix to cache: {e}")
569
-
570
- return fc_triu
571
-
572
  except Exception as e:
573
- print(f"Failed with adjusted coordinates and radius {radius}mm: {e}")
574
- # Try the next radius option
575
- continue
576
-
577
- print("Advanced preprocessing and coordinate adjustment failed")
578
- except Exception as preproc_err:
579
- print(f"Error during advanced preprocessing: {preproc_err}")
580
-
581
- # Try to diagnose the issue
582
- try:
583
- # Check if the file exists and is readable
584
- if not os.path.exists(fmri_file):
585
- error_msg = f"File does not exist: {fmri_file}"
586
- else:
587
- # Try to get more information about the file
588
- fmri_img = load_img(fmri_file)
589
 
590
- # Get detailed information about the NIfTI file
591
- affine = fmri_img.affine
592
- header = fmri_img.header
593
- zooms = header.get_zooms() # voxel dimensions
594
-
595
- # Calculate the range of coordinates in the image
596
- shape = fmri_img.shape
597
- img_mins = [0, 0, 0]
598
- img_maxs = [shape[0]-1, shape[1]-1, shape[2]-1]
599
-
600
- # Convert to world coordinates
601
- from nibabel.affines import apply_affine
602
- world_mins = apply_affine(affine, img_mins)
603
- world_maxs = apply_affine(affine, img_maxs)
604
-
605
- # Get atlas coordinates for comparison
606
- try:
607
- coords = get_cached_atlas_coords()
608
- coord_mins = coords.min(axis=0)
609
- coord_maxs = coords.max(axis=0)
610
 
611
- # Check if atlas coordinates are within the image bounds
612
- coord_in_img = all([
613
- coord_mins[0] >= world_mins[0] and coord_maxs[0] <= world_maxs[0],
614
- coord_mins[1] >= world_mins[1] and coord_maxs[1] <= world_maxs[1],
615
- coord_mins[2] >= world_mins[2] and coord_maxs[2] <= world_maxs[2]
616
- ])
 
 
 
617
 
618
- atlas_info = (f"Atlas coords range: X({coord_mins[0]:.1f} to {coord_maxs[0]:.1f}), "
619
- f"Y({coord_mins[1]:.1f} to {coord_maxs[1]:.1f}), "
620
- f"Z({coord_mins[2]:.1f} to {coord_maxs[2]:.1f})")
621
 
622
- img_info = (f"Image world coords: X({world_mins[0]:.1f} to {world_maxs[0]:.1f}), "
623
- f"Y({world_mins[1]:.1f} to {world_maxs[1]:.1f}), "
624
- f"Z({world_mins[2]:.1f} to {world_maxs[2]:.1f})")
625
 
626
- alignment = "Atlas coordinates are within image bounds" if coord_in_img else "ISSUE: Atlas coordinates outside image bounds!"
627
- except Exception as atlas_err:
628
- atlas_info = f"Error getting atlas coords: {atlas_err}"
629
- img_info = ""
630
- alignment = "Unable to check atlas-image alignment"
631
-
632
- # Check for potential issues
633
- error_msg = (f"File is readable but couldn't be processed with any radius. "
634
- f"\nShape: {fmri_img.shape}, Data type: {fmri_img.get_data_dtype()}"
635
- f"\nVoxel dimensions: {zooms}"
636
- f"\n{img_info}"
637
- f"\n{atlas_info}"
638
- f"\n{alignment}")
639
-
640
- # Check if it's a 4D file with sufficient time points
641
- if len(fmri_img.shape) < 4:
642
- error_msg += "\nISSUE: Not a 4D file (no time dimension)"
643
- elif fmri_img.shape[3] < 30:
644
- error_msg += f"\nISSUE: Too few time points ({fmri_img.shape[3]}), need at least 30"
645
 
646
- # Check if the affine is reasonable
647
- determinant = np.linalg.det(affine[:3, :3])
648
- if abs(determinant) < 0.1:
649
- error_msg += f"\nISSUE: Potentially invalid affine matrix (determinant={determinant:.3f})"
650
 
651
- except Exception as diag_err:
652
- error_msg = f"Error diagnosing file: {diag_err}"
653
-
654
- print(f"Diagnosis: {error_msg}")
655
-
656
- if allow_synthetic:
657
- # Create synthetic FC matrix as fallback
658
- print(f"Creating synthetic FC matrix for {fmri_file}")
659
- # Number of ROIs in Power atlas
660
- n_rois = 264
661
- n_triu_elements = n_rois * (n_rois - 1) // 2
662
-
663
- # Create synthetic FC matrix with realistic values
664
- # Use the filename to seed random generator for consistency
665
- try:
666
- # Try to extract a patient number from the filename for seeding
667
- filename = os.path.basename(fmri_file)
668
- if 'P' in filename and '_' in filename:
669
- seed = int(filename.split('_')[0].replace('P', '')) % 1000
670
- else:
671
- # Hash the filename for a consistent seed
672
- seed = int(hashlib.md5(filename.encode()).hexdigest(), 16) % 1000
673
- except:
674
- seed = 42 # Default seed if filename parsing fails
675
-
676
- np.random.seed(seed)
677
- fc_triu = np.random.rand(n_triu_elements) * 1.6 - 0.8
678
- fc_triu = np.arctanh(np.clip(fc_triu, -0.99, 0.99))
679
- print(f"Created synthetic FC matrix with {n_triu_elements} elements, seed: {seed}")
680
- return fc_triu
681
- else:
682
- error_msg = f"Could not process {fmri_file} with any radius option. {error_msg}"
683
- print(f"ERROR: {error_msg}")
684
- print("TIP: Set allow_synthetic=True to use synthetic data as fallback")
685
- raise ValueError(error_msg)
686
-
687
def preprocess_fmri_to_fc(nii_files, demo_data, demo_types, use_synthetic_fallback=True):
    """
    Convert multiple fMRI files to FC (functional connectivity) feature matrices.

    Each file is processed via process_single_fmri() with synthetic fallback
    disabled; files that fail are skipped. The surviving FC vectors are stacked
    and z-scored feature-wise, and the demographic arrays are subset to match
    the files that actually processed.

    Args:
        nii_files: List of NIfTI file paths to process
        demo_data: List of per-subject demographic arrays (one array per variable)
        demo_types: Types of demographic variables ('continuous'/'categorical')
        use_synthetic_fallback: Unused; kept for backward compatibility with callers.

    Returns:
        Tuple of (X, demo_data, demo_types), where X has shape
        (n_processed_files, n_fc_features).

    Raises:
        ValueError: If no file could be processed, or preprocessing fails.
    """
    fc_matrices = []
    processed_files = []

    try:
        print(f"Found {len(nii_files)} fMRI files")

        # Process each NIfTI file - real data only, no synthetic fallback
        for nii_file in nii_files:
            try:
                fc_triu = process_single_fmri(nii_file, allow_synthetic=False)
                fc_matrices.append(fc_triu)
                processed_files.append(nii_file)
                print(f"Successfully processed {nii_file}")
            except Exception as e:
                print(f"Error processing {nii_file}: {e}")
                # Skip this file and continue with the next one

        # Report how many files were successfully processed
        print(f"Successfully processed {len(fc_matrices)}/{len(nii_files)} files")

        # If we couldn't process any files, raise an error
        if not fc_matrices:
            print("ERROR: No real NIfTI files could be processed.")
            detailed_error = """

The NIfTI files could not be processed. Here are possible reasons:
1. The files may be in a non-standard format
2. The coordinate system might not match the atlas
3. The image resolution or dimensions are incompatible with the processing pipeline

To fix this:
- Make sure your NIfTI files are in standard MNI space
- Check that they have sufficient time points (at least 30)
- Verify the files are valid 4D fMRI data

You can also try preprocessing them with tools like FSL or AFNI before importing.
"""
            print(detailed_error)
            raise ValueError("Could not process any NIfTI files - please check the logs for details")

        # Create the feature matrix from the successfully processed files
        X = np.array(fc_matrices)

        # Check for NaN values in X
        if np.any(np.isnan(X)):
            print(f"Warning: Found NaN values in FC matrices, replacing with 0")
            X = np.nan_to_num(X)

        # Normalize the FC data (feature-wise z-score).
        # FIX: guard against zero std (constant features, or a single surviving
        # file), which previously produced 0/0 -> NaN that was silently zeroed
        # below. A zero-variance feature now cleanly normalizes to 0.
        feature_std = np.std(X, axis=0)
        safe_std = np.where(feature_std == 0, 1.0, feature_std)
        X = (X - np.mean(X, axis=0)) / safe_std

        # Check for NaN values that might have been introduced during normalization
        if np.any(np.isnan(X)):
            print(f"Warning: Found NaN values after normalization, replacing with 0")
            X = np.nan_to_num(X)

        # If we have demographic data, adjust it to match processed files
        if demo_data and len(demo_data) > 0 and len(demo_data[0]) > 0:
            # Necessary because some files may have failed to process
            if len(X) < len(demo_data[0]):
                print(f"Adjusting demographic data to match the {len(X)} processed files")
                try:
                    # Subject index is parsed from a 'P<NN>_...' filename prefix
                    # (1-based on disk, converted to a 0-based index here)
                    file_indices = [int(os.path.basename(f).split('_')[0].replace('P', '')) - 1
                                    for f in processed_files]
                    demo_data_adjusted = []
                    for d in demo_data:
                        # Select only the demographic data for successfully processed files
                        d_adjusted = [d[i] for i in file_indices if i < len(d)]
                        demo_data_adjusted.append(d_adjusted)
                    demo_data = demo_data_adjusted
                except Exception as e:
                    print(f"Warning: Failed to adjust demographic data: {e}")
                    print("Generating synthetic demographic data instead")
                    # Generate synthetic demographics if adjustment fails
                    _, synthetic_demo = generate_synthetic_fc_matrices(len(X))
                    demo_data = synthetic_demo
        else:
            # If we don't have demographic data, generate a synthetic set
            print("No demographic data available, generating synthetic data")
            _, synthetic_demo = generate_synthetic_fc_matrices(len(X))
            demo_data = synthetic_demo

        # Print final data shapes
        print(f"Final FC matrix shape: {X.shape}")
        print(f"Final demographic data lengths: {[len(d) for d in demo_data]}")

    except Exception as e:
        print(f"Error in FC preprocessing: {e}")
        print("ERROR: Failed to process real NIfTI files.")
        # Do not fall back to synthetic FC data - surface the failure to the caller
        raise ValueError(f"Failed to process FC matrices from real NIfTI files: {e}")

    return X, demo_data, demo_types
796
-
797
def generate_synthetic_fc_matrices(num_samples=5):
    """
    Generate synthetic FC matrices and demographic data.

    Args:
        num_samples: Number of samples to generate

    Returns:
        Tuple of (fc_matrices, demographic_data), where fc_matrices holds the
        Fisher z-transformed upper-triangular entries of a 264-ROI FC matrix
        per sample, and demographic_data is [ages, sexes, months, wab_scores].
    """
    print(f"Generating {num_samples} synthetic FC matrices...")

    # Power atlas ROI count -> number of upper-triangular (off-diagonal) edges
    roi_count = 264
    n_edges = roi_count * (roi_count - 1) // 2

    np.random.seed(42)  # for reproducibility

    # Each row mimics a real FC feature vector: correlation-like values in
    # (-0.8, 0.8), clipped and passed through Fisher's z-transform.
    matrices = np.zeros((num_samples, n_edges))
    for idx in range(num_samples):
        # Per-sample seed keeps samples distinct but deterministic
        np.random.seed(42 + idx)
        correlations = np.random.rand(n_edges) * 1.6 - 0.8
        matrices[idx] = np.arctanh(np.clip(correlations, -0.99, 0.99))

    # Synthetic demographics in plausible clinical ranges
    demographics = [
        np.random.randint(30, 81, num_samples),     # age: 30-80 years
        np.random.choice(['M', 'F'], num_samples),  # sex: roughly balanced
        np.random.randint(1, 25, num_samples),      # months post stroke
        np.random.randint(20, 101, num_samples),    # WAB scores
    ]

    print(f"Generated {num_samples} synthetic FC matrices with shape {matrices.shape}")

    return matrices, demographics
849
-
850
def clear_cache(cache_type=None):
    """
    Clear all or specific types of cache.

    Args:
        cache_type: Type of cache to clear ('time_series', 'fc_matrices',
            'maskers', 'atlas', 'latents'). If None, clears all cache types.
    """
    all_types = ['time_series', 'fc_matrices', 'maskers', 'atlas', 'latents']
    targets = all_types if cache_type is None else [cache_type]

    for ctype in targets:
        subdir = os.path.join(CACHE_DIR, ctype)
        if not os.path.exists(subdir):
            print(f"Cache directory for {ctype} does not exist")
            continue

        print(f"Clearing {ctype} cache...")
        try:
            # Remove regular files only; nested directories are left in place.
            for entry in os.listdir(subdir):
                entry_path = os.path.join(subdir, entry)
                if os.path.isfile(entry_path):
                    os.remove(entry_path)
            print(f"Successfully cleared {ctype} cache")
        except Exception as e:
            print(f"Error clearing {ctype} cache: {e}")

    print("Cache clearing complete")
881
 
882
def download_and_cache_dataset(dataset_name, cache_dir=None):
    """
    Download a dataset from HuggingFace and save it to a local cache.

    Args:
        dataset_name (str): Name of the dataset on HuggingFace (e.g., 'SreekarB/OSFData1')
        cache_dir (str, optional): Directory to store the cached files. If None,
            uses the default HuggingFace cache.

    Returns:
        dataset: The loaded dataset object
        cache_path (str): Path to the cached dataset
        nii_files (list): List of NIfTI file paths found in the dataset
    """
    from datasets import load_dataset
    import os
    import tempfile

    def _scan_for_nifti(base_dir):
        # Walk the tree and collect every .nii / .nii.gz file path.
        found = []
        for root, _dirs, filenames in os.walk(base_dir):
            found.extend(
                os.path.join(root, name)
                for name in filenames
                if name.endswith('.nii') or name.endswith('.nii.gz')
            )
        return found

    print(f"Loading dataset: {dataset_name}")

    try:
        if cache_dir is None:
            # Resolve the default HuggingFace cache location: newer hub API
            # first, then the legacy API, then a temp-directory fallback.
            try:
                from huggingface_hub import constants, utils
                cache_dir = utils.get_cache_dir()
            except (ImportError, AttributeError):
                try:
                    from huggingface_hub import HfFolder
                    cache_dir = HfFolder.get_cache_dir()
                except (ImportError, AttributeError):
                    cache_dir = os.path.join(tempfile.gettempdir(), "huggingface", "datasets")
                    print(f"Using fallback cache directory: {cache_dir}")

        # Downloads on first use; subsequent calls hit the local cache.
        dataset = load_dataset(dataset_name, cache_dir=cache_dir)

        # HuggingFace naming convention: '/' in the repo id becomes '--'
        dataset_cache_path = os.path.join(cache_dir, "datasets", dataset_name.replace("/", "--"))
        print(f"Dataset cached at: {dataset_cache_path}")

        # The 'snapshots' subdirectory is where the actual files live
        snapshot_dir = None
        if os.path.exists(dataset_cache_path):
            for root, dirs, _ in os.walk(dataset_cache_path):
                if 'snapshots' in dirs:
                    snapshot_dir = os.path.join(root, 'snapshots')
                    break

        if snapshot_dir:
            dataset_cache_path = snapshot_dir
            print(f"Found snapshots directory at: {snapshot_dir}")

        nii_files = _scan_for_nifti(dataset_cache_path)
        print(f"Found {len(nii_files)} NIfTI files in dataset cache")

    except Exception as e:
        print(f"Error accessing HuggingFace dataset cache: {e}")
        print("Creating temporary cache directory...")

        # Fall back to a fresh temp cache and re-download there
        temp_cache_dir = tempfile.mkdtemp(prefix="hf_dataset_")
        dataset = load_dataset(dataset_name, cache_dir=temp_cache_dir)
        dataset_cache_path = temp_cache_dir

        nii_files = _scan_for_nifti(temp_cache_dir)
        print(f"Found {len(nii_files)} NIfTI files in temporary cache: {temp_cache_dir}")

    return dataset, dataset_cache_path, nii_files
967
 
968
- def load_and_preprocess_data(data_dir, demographic_file, use_hf_dataset=False,
969
- hf_nii_files=None, hf_demo_data=None, hf_demo_types=None,
970
- max_samples=None):
971
- """
972
- Load and preprocess both fMRI data and demographics
973
-
974
- Args:
975
- data_dir: Directory containing data files or HuggingFace dataset name
976
- demographic_file: Path to demographic CSV file (or None if using API)
977
- use_hf_dataset: Whether to use HuggingFace dataset API
978
- hf_nii_files: List of NIfTI file paths from HuggingFace dataset
979
- hf_demo_data: Demographic data from HuggingFace dataset API
980
- hf_demo_types: Types of demographic variables
981
- """
982
- if use_hf_dataset:
983
- # Handle HuggingFace dataset
984
- if hf_demo_data is not None and hf_demo_types is not None:
985
- # Use demographic data directly from API
986
- demo_data = hf_demo_data
987
- demo_types = hf_demo_types
988
- else:
989
- # Load demographics from file
990
- if demographic_file is not None:
991
- demo_df = pd.read_csv(demographic_file)
992
 
993
- # Map column names if needed (flexible column naming)
994
- column_mapping = {
995
- 'age_at_stroke': ['age_at_stroke', 'age', 'Age', 'patient_age'],
996
- 'sex': ['sex', 'gender', 'Gender', 'Sex'],
997
- 'months_post_stroke': ['months_post_stroke', 'mpo', 'MPO', 'months_post_onset'],
998
- 'wab_score': ['wab_score', 'wab_aq', 'WAB', 'WAB_AQ', 'aphasia_score']
999
- }
1000
-
1001
- # Check and map columns if necessary
1002
- for target_col, alt_cols in column_mapping.items():
1003
- if target_col not in demo_df.columns:
1004
- for alt_col in alt_cols:
1005
- if alt_col in demo_df.columns:
1006
- demo_df[target_col] = demo_df[alt_col]
1007
- print(f"Mapped {alt_col} to {target_col}")
1008
- break
1009
-
1010
- # Extract demographic data
1011
  demo_data = [
1012
- demo_df['age_at_stroke'].values,
1013
- demo_df['sex'].values,
1014
- demo_df['months_post_stroke'].values,
1015
- demo_df['wab_score'].values
1016
  ]
1017
 
1018
  demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
1019
- else:
1020
- # No demographic data provided
1021
- raise ValueError("No demographic data provided")
1022
 
1023
- # Try to properly access the HuggingFace dataset files
1024
- if not (hf_nii_files and len(hf_nii_files) > 0):
1025
- try:
1026
- print(f"Attempting to download and cache dataset: {data_dir}")
1027
- # Try to load and cache the dataset files
1028
- _, _, nii_files_from_hf = download_and_cache_dataset(data_dir)
1029
-
1030
- if nii_files_from_hf and len(nii_files_from_hf) > 0:
1031
- print(f"Successfully found {len(nii_files_from_hf)} NIfTI files in HuggingFace dataset cache")
1032
- hf_nii_files = nii_files_from_hf
1033
- except Exception as e:
1034
- print(f"Error accessing HuggingFace dataset files: {e}")
1035
- hf_nii_files = []
1036
-
1037
- # Use provided NIfTI files from HuggingFace
1038
- if hf_nii_files and len(hf_nii_files) > 0:
1039
- # Apply sample limit if specified
1040
- if max_samples is not None and len(hf_nii_files) > max_samples:
1041
- print(f"Limiting to {max_samples} NIfTI files as specified (from {len(hf_nii_files)} available)")
1042
- nii_files = hf_nii_files[:max_samples]
1043
- else:
1044
- nii_files = hf_nii_files
1045
-
1046
- print(f"Using {len(nii_files)} NIfTI files from HuggingFace dataset")
1047
- else:
1048
- # Check if we should use synthetic data
1049
- if PREDICTION_CONFIG.get('use_synthetic_nifti', True):
1050
- # Create synthetic NIfTI files as fallback
1051
- print("No NIfTI files found in HuggingFace dataset - creating synthetic data")
1052
-
1053
- try:
1054
- import tempfile
1055
- import os
1056
- import numpy as np
1057
- import nibabel as nib
1058
- from pathlib import Path
1059
-
1060
- # Create a temporary directory for our synthetic files
1061
- temp_dir = tempfile.mkdtemp(prefix="synthetic_nifti_")
1062
- print(f"Created temp directory for synthetic data: {temp_dir}")
1063
-
1064
- # How many patients do we need to simulate?
1065
- num_patients = len(demo_data[0]) if demo_data and len(demo_data) > 0 else 10
1066
- print(f"Creating synthetic data for {num_patients} patients")
1067
-
1068
- nii_files = []
1069
-
1070
- # Create synthetic NIfTI files (264x264 FC matrices)
1071
- for i in range(num_patients):
1072
- # Create random symmetric matrix
1073
- np.random.seed(i) # For reproducibility
1074
-
1075
- # Generate a 60x75x60 random volume (typical fMRI dimensions)
1076
- vol_shape = (60, 75, 60)
1077
- data = np.random.randn(*vol_shape)
1078
-
1079
- # Create the NIfTI file
1080
- img = nib.Nifti1Image(data, np.eye(4))
1081
-
1082
- # Save to temp directory
1083
- file_path = os.path.join(temp_dir, f"P{i+1:02d}_rs.nii.gz")
1084
- nib.save(img, file_path)
1085
- nii_files.append(file_path)
1086
-
1087
- print(f"Successfully created {len(nii_files)} synthetic NIfTI files")
1088
-
1089
- except Exception as e:
1090
- print(f"Error creating synthetic NIfTI data: {e}")
1091
- raise ValueError(f"No NIfTI files found in HuggingFace dataset and failed to create synthetic data: {e}")
1092
- else:
1093
- # Don't use synthetic data
1094
- raise ValueError("No NIfTI files found in HuggingFace dataset and synthetic data generation is disabled")
1095
- else:
1096
- # Standard local file loading
1097
- if demographic_file is not None:
1098
- # Load demographics
1099
- demo_df = pd.read_csv(demographic_file)
1100
 
1101
- # Map column names if needed (flexible column naming)
1102
- column_mapping = {
1103
- 'age_at_stroke': ['age_at_stroke', 'age', 'Age', 'patient_age'],
1104
- 'sex': ['sex', 'gender', 'Gender', 'Sex'],
1105
- 'months_post_stroke': ['months_post_stroke', 'mpo', 'MPO', 'months_post_onset'],
1106
- 'wab_score': ['wab_score', 'wab_aq', 'WAB', 'WAB_AQ', 'aphasia_score']
1107
- }
1108
 
1109
- # Check and map columns if necessary
1110
- for target_col, alt_cols in column_mapping.items():
1111
- if target_col not in demo_df.columns:
1112
- for alt_col in alt_cols:
1113
- if alt_col in demo_df.columns:
1114
- demo_df[target_col] = demo_df[alt_col]
1115
- print(f"Mapped {alt_col} to {target_col}")
1116
- break
1117
 
1118
- demo_data = [
1119
- demo_df['age_at_stroke'].values,
1120
- demo_df['sex'].values,
1121
- demo_df['months_post_stroke'].values,
1122
- demo_df['wab_score'].values
1123
- ]
1124
 
1125
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
1126
- else:
1127
- raise ValueError("No demographic file provided")
1128
-
1129
- # Load fMRI files from local directory
1130
- nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
1131
-
1132
- # Also look for .nii files (without .gz)
1133
- nii_files_nogz = sorted(list(Path(data_dir).glob('*.nii')))
1134
- nii_files.extend(nii_files_nogz)
1135
 
1136
- # Apply sample limit if specified
1137
- if max_samples is not None and len(nii_files) > max_samples:
1138
- print(f"Limiting to {max_samples} NIfTI files as specified (from {len(nii_files)} available)")
1139
- nii_files = nii_files[:max_samples]
1140
 
1141
- if not nii_files:
1142
- print(f"No NIfTI files (*.nii or *.nii.gz) found in {data_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1143
 
1144
- # Check if we should use synthetic data
1145
- if PREDICTION_CONFIG.get('use_synthetic_nifti', True):
1146
- print("Creating synthetic NIfTI data as fallback")
1147
-
1148
- try:
1149
- import tempfile
1150
- import os
1151
- import numpy as np
1152
- import nibabel as nib
1153
-
1154
- # Create a temporary directory for our synthetic files
1155
- temp_dir = tempfile.mkdtemp(prefix="synthetic_nifti_")
1156
- print(f"Created temp directory for synthetic data: {temp_dir}")
1157
-
1158
- # How many patients do we need to simulate?
1159
- num_patients = len(demo_data[0]) if demo_data and len(demo_data) > 0 else 10
1160
- print(f"Creating synthetic data for {num_patients} patients")
1161
-
1162
- nii_files = []
1163
-
1164
- # Create synthetic NIfTI files
1165
- for i in range(num_patients):
1166
- # Create random symmetric matrix
1167
- np.random.seed(i) # For reproducibility
1168
-
1169
- # Generate a 60x75x60 random volume (typical fMRI dimensions)
1170
- vol_shape = (60, 75, 60)
1171
- data = np.random.randn(*vol_shape)
1172
-
1173
- # Create the NIfTI file
1174
- img = nib.Nifti1Image(data, np.eye(4))
1175
-
1176
- # Save to temp directory
1177
- file_path = os.path.join(temp_dir, f"P{i+1:02d}_rs.nii.gz")
1178
- nib.save(img, file_path)
1179
- nii_files.append(file_path)
1180
-
1181
- print(f"Successfully created {len(nii_files)} synthetic NIfTI files")
1182
-
1183
- except Exception as e:
1184
- print(f"Error creating synthetic NIfTI data: {e}")
1185
- raise ValueError(f"No NIfTI files found in {data_dir} and failed to create synthetic data: {e}")
1186
- else:
1187
- # Don't use synthetic data
1188
- raise ValueError(f"No NIfTI files (*.nii or *.nii.gz) found in {data_dir} and synthetic data generation is disabled")
1189
- else:
1190
- print(f"Found {len(nii_files)} NIfTI files in {data_dir}")
1191
-
1192
- # Process fMRI files to FC matrices
1193
- X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
1194
-
1195
- # Check for sample size consistency and fix if needed
1196
- print(f"After preprocessing: X shape: {X.shape}, demo_data lengths: {[len(d) for d in demo_data]}")
1197
 
1198
- # Make sure all sample sizes match
1199
- if X.shape[0] != len(demo_data[0]):
1200
- print(f"WARNING: Sample size mismatch detected! X: {X.shape[0]}, demo: {len(demo_data[0])}")
1201
-
1202
- # Determine the smaller size
1203
- min_samples = min(X.shape[0], len(demo_data[0]))
1204
- print(f"Adjusting to {min_samples} samples")
1205
-
1206
- # Trim X and demographic data to match
1207
- X = X[:min_samples]
1208
- demo_data = [d[:min_samples] for d in demo_data]
1209
-
1210
- print(f"After adjustment: X shape: {X.shape}, demo_data lengths: {[len(d) for d in demo_data]}")
1211
 
1212
- return X, demo_data, demo_types
 
1
  import numpy as np
2
  import pandas as pd
3
+ from datasets import load_dataset
4
+ from nilearn import input_data, connectome
 
 
 
 
 
5
  from nilearn.image import load_img
6
  import nibabel as nib
7
+ import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
+ def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
10
  """
11
+ Process fMRI data to generate functional connectivity matrices
12
 
13
+ Parameters:
14
+ - dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
15
+ - demo_data: Optional demographic data, required if providing NIfTI files
16
+ - demo_types: Optional demographic data types, required if providing NIfTI files
 
 
 
 
17
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  Returns:
19
+ - X: Array of FC matrices
20
+ - demo_data: Demographic data
21
+ - demo_types: Demographic data types
22
  """
23
+ print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
 
 
 
 
 
 
 
 
 
 
 
24
 
25
+ # For SreekarB/OSFData dataset, the data will be loaded from dataset features
26
+ if isinstance(dataset_or_niifiles, str):
27
+ dataset_name = dataset_or_niifiles
28
+ print(f"Loading data from dataset: {dataset_name}")
29
  try:
30
+ # Try multiple approaches to load the dataset
31
+ approaches = [
32
+ lambda: load_dataset(dataset_name, split="train"),
33
+ lambda: load_dataset(dataset_name), # Try without split
34
+ lambda: load_dataset(dataset_name, split="train", trust_remote_code=True), # Try with trust_remote_code
35
+ lambda: load_dataset(dataset_name.split("/")[-1], split="train") if "/" in dataset_name else None
36
+ ]
37
+
38
+ dataset = None
39
+ last_error = None
40
+
41
+ for i, approach in enumerate(approaches):
42
+ if approach is None:
43
+ continue
44
+
45
+ try:
46
+ print(f"Attempt {i+1} to load dataset...")
47
+ dataset = approach()
48
+ print(f"Successfully loaded dataset with approach {i+1}!")
49
+ break
50
+ except Exception as e:
51
+ print(f"Attempt {i+1} failed: {e}")
52
+ last_error = e
53
+
54
+ if dataset is None:
55
+ print(f"All attempts to load dataset failed. Last error: {last_error}")
56
+ raise ValueError(f"Could not load dataset {dataset_name}")
57
  except Exception as e:
58
+ print(f"Error during dataset loading: {e}")
59
+ raise
60
+
61
+ # Prepare demographics data from the dataset
62
+ if demo_data is None:
63
+ # Create demo_data from the dataset
64
+ demo_df = pd.DataFrame({
65
+ 'age': dataset['age'],
66
+ 'gender': dataset['gender'],
67
+ 'mpo': dataset['mpo'],
68
+ 'wab_aq': dataset['wab_aq']
69
+ })
70
+
71
+ demo_data = [
72
+ demo_df['age'].values,
73
+ demo_df['gender'].values,
74
+ demo_df['mpo'].values,
75
+ demo_df['wab_aq'].values
76
+ ]
77
+
78
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
79
 
80
+ # Look for NIfTI files in P01_rs.nii format
81
+ print("Searching for NIfTI files in dataset columns...")
82
+ nii_files = []
 
 
 
 
 
 
 
 
 
83
 
84
+ # Create a temp directory for downloads
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  import tempfile
86
+ from huggingface_hub import hf_hub_download
87
+ import shutil
88
 
89
+ temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
90
+ print(f"Created temporary directory for NIfTI files: {temp_dir}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  try:
93
+ # First approach: Check if there are any columns containing file paths
94
+ nii_columns = []
95
+ for col in dataset.column_names:
96
+ # Check if column name suggests NIfTI files
97
+ if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
98
+ nii_columns.append(col)
99
+ # Or check if column contains file paths
100
+ elif len(dataset) > 0:
101
+ first_val = dataset[0][col]
102
+ if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
103
+ nii_columns.append(col)
104
+
105
+ if nii_columns:
106
+ print(f"Found columns that may contain NIfTI files: {nii_columns}")
107
 
108
+ for col in nii_columns:
109
+ print(f"Processing column '{col}'...")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
+ for i, item in enumerate(dataset[col]):
112
+ if not isinstance(item, str):
113
+ print(f"Item {i} in column {col} is not a string but {type(item)}")
114
+ continue
115
+
116
+ if not (item.endswith('.nii') or item.endswith('.nii.gz')):
117
+ print(f"Item {i} in column {col} is not a NIfTI file: {item}")
118
+ continue
119
+
120
+ print(f"Downloading {item} from dataset {dataset_name}...")
121
+
122
  try:
123
+ # Attempt to download with explicit filename
124
+ file_path = hf_hub_download(
125
+ repo_id=dataset_name,
126
+ filename=item,
127
+ repo_type="dataset",
128
+ cache_dir=temp_dir
129
+ )
130
+ nii_files.append(file_path)
131
+ print(f"✓ Successfully downloaded {item}")
132
+ except Exception as e1:
133
+ print(f"Error downloading with explicit filename: {e1}")
134
 
135
+ # Second attempt: try with the item's basename
136
+ try:
137
+ basename = os.path.basename(item)
138
+ print(f"Trying with basename: {basename}")
139
+ file_path = hf_hub_download(
140
+ repo_id=dataset_name,
141
+ filename=basename,
142
+ repo_type="dataset",
143
+ cache_dir=temp_dir
144
+ )
145
+ nii_files.append(file_path)
146
+ print(f"✓ Successfully downloaded {basename}")
147
+ except Exception as e2:
148
+ print(f"Error downloading with basename: {e2}")
149
+
150
+ # Third attempt: check if it's a binary blob in the dataset
151
+ try:
152
+ if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
153
+ print("Found binary data in dataset, saving to temporary file...")
154
+ binary_data = dataset[i]['bytes']
155
+ temp_file = os.path.join(temp_dir, basename)
156
+ with open(temp_file, 'wb') as f:
157
+ f.write(binary_data)
158
+ nii_files.append(temp_file)
159
+ print(f"✓ Saved binary data to {temp_file}")
160
+ except Exception as e3:
161
+ print(f"Error handling binary data: {e3}")
162
+
163
+ # Last resort: look for the file locally
164
+ local_path = os.path.join(os.getcwd(), item)
165
+ if os.path.exists(local_path):
166
+ nii_files.append(local_path)
167
+ print(f"✓ Found {item} locally")
168
+ else:
169
+ print(f"❌ Warning: Could not find {item} anywhere")
170
+
171
+ # Second approach: Try to find NIfTI files in dataset repository directly
172
+ if not nii_files:
173
+ print("No NIfTI files found in dataset columns. Trying direct repository search...")
174
 
 
 
 
 
 
 
 
175
  try:
176
+ from huggingface_hub import list_repo_files, hf_hub_download
177
+
178
+ # Try to list all files in the repository
179
+ try:
180
+ print("Listing all repository files...")
181
+ all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
182
+ print(f"Found {len(all_repo_files)} files in repository")
183
 
184
+ # First prioritize P*_rs.nii files
185
+ p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
 
 
186
 
187
+ # Then include all other NIfTI files
188
+ other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
189
+
190
+ # Combine, with P*_rs.nii files first
191
+ nii_repo_files = p_rs_files + other_nii_files
192
+
193
+ if nii_repo_files:
194
+ print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
195
+
196
+ # Download each file
197
+ for nii_file in nii_repo_files:
198
+ try:
199
+ file_path = hf_hub_download(
200
+ repo_id=dataset_name,
201
+ filename=nii_file,
202
+ repo_type="dataset",
203
+ cache_dir=temp_dir
204
+ )
205
+ nii_files.append(file_path)
206
+ print(f"✓ Downloaded {nii_file}")
207
+ except Exception as e:
208
+ print(f"Error downloading {nii_file}: {e}")
209
  except Exception as e:
210
+ print(f"Error listing repository files: {e}")
211
+ print("Will try alternative approaches...")
212
+
213
+ # If repo listing fails, try with common NIfTI file patterns directly
214
+ if not nii_files:
215
+ print("Trying common NIfTI file patterns...")
216
+
217
+ # Focus specifically on P*_rs.nii pattern
218
+ patterns = []
219
+
220
+ # Generate P01_rs.nii through P30_rs.nii
221
+ for i in range(1, 31): # Try subjects 1-30
222
+ patterns.append(f"P{i:02d}_rs.nii")
223
+
224
+ # Also try with .nii.gz extension
225
+ for i in range(1, 31):
226
+ patterns.append(f"P{i:02d}_rs.nii.gz")
227
+
228
+ # Include a few other common patterns as fallbacks
229
+ patterns.extend([
230
+ "sub-01_task-rest_bold.nii.gz", # BIDS format
231
+ "fmri.nii.gz", "bold.nii.gz",
232
+ "rest.nii.gz"
233
+ ])
234
+
235
+ for pattern in patterns:
236
+ try:
237
+ print(f"Trying to download {pattern}...")
238
+ file_path = hf_hub_download(
239
+ repo_id=dataset_name,
240
+ filename=pattern,
241
+ repo_type="dataset",
242
+ cache_dir=temp_dir
243
+ )
244
+ nii_files.append(file_path)
245
+ print(f"✓ Successfully downloaded {pattern}")
246
+ except Exception as e:
247
+ print(f"× Failed to download {pattern}")
248
+
249
+ # If we still couldn't find any files, check if data files are nested
250
+ if not nii_files:
251
+ print("Checking for nested data files...")
252
+ nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
253
+
254
+ for path in nested_paths:
255
+ for pattern in patterns:
256
+ nested_file = f"{path}{pattern}"
257
+ try:
258
+ print(f"Trying to download {nested_file}...")
259
+ file_path = hf_hub_download(
260
+ repo_id=dataset_name,
261
+ filename=nested_file,
262
+ repo_type="dataset",
263
+ cache_dir=temp_dir
264
+ )
265
+ nii_files.append(file_path)
266
+ print(f"✓ Successfully downloaded {nested_file}")
267
+ # If we found one file in this directory, try to find all files in it
268
+ try:
269
+ all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
270
+ nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
271
+ print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
272
+
273
+ for nii_file in nii_files_in_dir:
274
+ if nii_file != nested_file: # Skip the one we already downloaded
275
+ try:
276
+ file_path = hf_hub_download(
277
+ repo_id=dataset_name,
278
+ filename=nii_file,
279
+ repo_type="dataset",
280
+ cache_dir=temp_dir
281
+ )
282
+ nii_files.append(file_path)
283
+ print(f"✓ Downloaded {nii_file}")
284
+ except Exception as e:
285
+ print(f"Error downloading {nii_file}: {e}")
286
+ except Exception as e:
287
+ print(f"Error finding additional files in {path}: {e}")
288
+ except Exception as e:
289
+ pass
290
+
291
+ except Exception as e:
292
+ print(f"Error during repository exploration: {e}")
293
 
294
+ # If we still don't have any files, try to search for P*_rs.nii pattern specifically
295
+ if not nii_files:
296
+ print("Trying to find files matching P*_rs.nii pattern specifically...")
 
 
 
 
 
 
297
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
298
  try:
299
+ # List all files in the repository (if we haven't already)
300
+ if not 'all_repo_files' in locals():
301
+ from huggingface_hub import list_repo_files
302
+ try:
303
+ all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
304
+ except Exception as e:
305
+ print(f"Error listing repo files: {e}")
306
+ all_repo_files = []
307
 
308
+ # Look for files matching the pattern exactly (P*_rs.nii)
309
+ pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
310
+
311
+ # If we don't find any exact matches, try a more relaxed pattern
312
+ if not pattern_files:
313
+ pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
314
+
315
+ if pattern_files:
316
+ print(f"Found {len(pattern_files)} files matching rs.nii pattern")
317
+
318
+ # Download each file
319
+ for pattern_file in pattern_files:
320
+ try:
321
+ file_path = hf_hub_download(
322
+ repo_id=dataset_name,
323
+ filename=pattern_file,
324
+ repo_type="dataset",
325
+ cache_dir=temp_dir
326
+ )
327
+ nii_files.append(file_path)
328
+ print(f"✓ Downloaded {pattern_file}")
329
+ except Exception as e:
330
+ print(f"Error downloading {pattern_file}: {e}")
331
+ except Exception as e:
332
+ print(f"Error searching for pattern files: {e}")
333
+
334
+ print(f"Found total of {len(nii_files)} NIfTI files")
335
  except Exception as e:
336
+ print(f"Unexpected error during NIfTI file search: {e}")
337
+ import traceback
338
+ traceback.print_exc()
339
+
340
+ # If we found NIfTI files, process them to FC matrices
341
+ if nii_files:
342
+ print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
343
+
344
+ # Load Power 264 atlas
345
+ from nilearn import datasets
346
+ power = datasets.fetch_coords_power_2011()
347
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
348
+
349
+ masker = input_data.NiftiSpheresMasker(
350
+ coords, radius=5,
351
+ standardize=True,
352
+ memory='nilearn_cache', memory_level=1,
353
+ verbose=0,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
354
  detrend=True,
 
355
  low_pass=0.1,
356
  high_pass=0.01,
357
+ t_r=2.0 # Adjust TR according to your data
358
  )
359
 
360
+ # Process fMRI data and compute FC matrices
361
+ fc_matrices = []
362
+ valid_files = 0
363
+ total_files = len(nii_files)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
364
 
365
+ for nii_file in nii_files:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  try:
367
+ print(f"Processing {nii_file}...")
368
+ fmri_img = load_img(nii_file)
369
 
370
+ # Check image dimensions
371
+ if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
372
+ print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
373
  continue
374
+
375
+ time_series = masker.fit_transform(fmri_img)
376
 
377
+ # Validate time series data
378
+ if np.isnan(time_series).any() or np.isinf(time_series).any():
379
+ print(f"Warning: {nii_file} contains NaN or Inf values after masking")
380
+ # Replace NaNs with zeros for this file
381
+ time_series = np.nan_to_num(time_series)
382
 
383
+ correlation_measure = connectome.ConnectivityMeasure(
384
+ kind='correlation',
385
+ vectorize=False,
386
+ discard_diagonal=False
387
+ )
388
 
389
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
 
390
 
391
+ # Check for invalid correlation values
392
+ if np.isnan(fc_matrix).any():
393
+ print(f"Warning: {nii_file} produced NaN correlation values")
394
+ continue
395
+
396
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
397
+ fc_triu = fc_matrix[triu_indices]
398
 
399
+ # Fisher z-transform with proper bounds check
400
+ # Clip correlation values to valid range for arctanh
401
+ fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
402
+ fc_triu = np.arctanh(fc_triu_clipped)
403
 
404
+ fc_matrices.append(fc_triu)
405
+ valid_files += 1
406
+ print(f"Successfully processed {nii_file} to FC matrix")
407
 
 
 
 
 
 
 
 
 
 
 
 
408
  except Exception as e:
409
+ print(f"Error processing {nii_file}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
+ if fc_matrices:
412
+ print(f"Successfully processed {valid_files} out of {total_files} files")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
413
 
414
+ # Ensure all matrices have the same dimensions
415
+ dims = [m.shape[0] for m in fc_matrices]
416
+ if len(set(dims)) > 1:
417
+ print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
418
+ # Use the most common dimension
419
+ from collections import Counter
420
+ most_common_dim = Counter(dims).most_common(1)[0][0]
421
+ print(f"Using most common dimension: {most_common_dim}")
422
+ fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
423
 
424
+ X = np.array(fc_matrices)
 
 
425
 
426
+ # Normalize the FC data
427
+ mean_x = np.mean(X, axis=0)
428
+ std_x = np.std(X, axis=0)
429
 
430
+ # Handle zero standard deviation
431
+ std_x[std_x == 0] = 1.0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
432
 
433
+ X = (X - mean_x) / std_x
434
+ print(f"Created FC matrices with shape {X.shape}")
 
 
435
 
436
+ # Make sure demo_data matches the number of FC matrices
437
+ if len(demo_data[0]) != X.shape[0]:
438
+ print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
439
+ f"doesn't match number of FC matrices ({X.shape[0]})")
440
+ # Adjust demo_data to match FC matrices
441
+ indices = list(range(min(len(demo_data[0]), X.shape[0])))
442
+ X = X[indices]
443
+ demo_data = [d[indices] for d in demo_data]
444
+
445
+ return X, demo_data, demo_types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
446
 
447
+ print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
448
+ # Return a placeholder with the right demographics but empty FC
449
+ n_subjects = len(dataset)
450
+ n_rois = 264
451
+ fc_dim = (n_rois * (n_rois - 1)) // 2
452
+ X = np.zeros((n_subjects, fc_dim))
453
+ print(f"Created placeholder FC matrices with shape {X.shape}")
454
+ return X, demo_data, demo_types
455
+
456
+ elif isinstance(dataset_or_niifiles, str):
457
+ # Handle real dataset with actual fMRI data
458
+ dataset = load_dataset(dataset_or_niifiles, split="train")
459
+
460
+ # Load Power 264 atlas
461
+ from nilearn import datasets
462
+ power = datasets.fetch_coords_power_2011()
463
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
464
+
465
+ masker = input_data.NiftiSpheresMasker(
466
+ coords, radius=5,
467
+ standardize=True,
468
+ memory='nilearn_cache', memory_level=1,
469
+ verbose=0,
470
+ detrend=True,
471
+ low_pass=0.1,
472
+ high_pass=0.01,
473
+ t_r=2.0 # Adjust TR according to your data
474
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
475
 
476
+ # Load demographic data if needed
477
+ if demo_data is None:
478
+ if 'demographics' in dataset.features:
479
+ demo_df = pd.DataFrame(dataset['demographics'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
480
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
481
  demo_data = [
482
+ demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
483
+ demo_df['sex'].values if 'sex' in demo_df.columns else [],
484
+ demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
485
+ demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
486
  ]
487
 
488
  demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
 
 
 
489
 
490
+ # Process fMRI data and compute FC matrices
491
+ fc_matrices = []
492
+ for nii_file in dataset['nii_files']:
493
+ fmri_img = load_img(nii_file)
494
+ time_series = masker.fit_transform(fmri_img)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
495
 
496
+ correlation_measure = connectome.ConnectivityMeasure(
497
+ kind='correlation', vectorize=False, discard_diagonal=False
498
+ )
 
 
 
 
499
 
500
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
 
 
 
 
 
 
 
501
 
502
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
503
+ fc_triu = fc_matrix[triu_indices]
 
 
 
 
504
 
505
+ fc_triu = np.arctanh(fc_triu) # Fisher z-transform
506
+
507
+ fc_matrices.append(fc_triu)
 
 
 
 
 
 
 
508
 
509
+ X = np.array(fc_matrices)
 
 
 
510
 
511
+ elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
512
+ # Handle a list of NIfTI files
513
+ # Similar processing as above but with local files
514
+ print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
515
+
516
+ # Load Power 264 atlas
517
+ from nilearn import datasets
518
+ power = datasets.fetch_coords_power_2011()
519
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
520
+
521
+ masker = input_data.NiftiSpheresMasker(
522
+ coords, radius=5,
523
+ standardize=True,
524
+ memory='nilearn_cache', memory_level=1,
525
+ verbose=0,
526
+ detrend=True,
527
+ low_pass=0.1,
528
+ high_pass=0.01,
529
+ t_r=2.0
530
+ )
531
+
532
+ fc_matrices = []
533
+ for nii_file in dataset_or_niifiles:
534
+ fmri_img = load_img(nii_file)
535
+ time_series = masker.fit_transform(fmri_img)
536
 
537
+ correlation_measure = connectome.ConnectivityMeasure(
538
+ kind='correlation', vectorize=False, discard_diagonal=False
539
+ )
540
+
541
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
542
+
543
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
544
+ fc_triu = fc_matrix[triu_indices]
545
+
546
+ fc_triu = np.arctanh(fc_triu) # Fisher z-transform
547
+
548
+ fc_matrices.append(fc_triu)
549
+
550
+ X = np.array(fc_matrices)
551
+ else:
552
+ raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
553
 
554
+ # Normalize the FC data
555
+ X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
 
 
 
 
 
 
 
 
 
 
 
556
 
557
+ return X, demo_data, demo_types
gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
main.py CHANGED
@@ -1,490 +1,272 @@
1
  import os
2
- import numpy as np # Make sure numpy is imported at the top level
 
 
 
 
3
  import torch
4
  from pathlib import Path
 
 
 
 
 
 
5
  import pandas as pd
 
 
6
 
7
- # Set Huggingface cache directory to avoid permission issues
8
- os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
9
- os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
10
-
11
- from data_preprocessing import load_and_preprocess_data
12
- from vae_model import DemoVAE
13
- from rcf_prediction import AphasiaTreatmentPredictor
14
- from visualization import plot_fc_matrices, plot_learning_curves
15
- from config import MODEL_CONFIG, PREDICTION_CONFIG
16
- # Configure matplotlib for headless environment
17
- import matplotlib
18
- matplotlib.use('Agg') # Use non-interactive backend
19
- import matplotlib.pyplot as plt
20
-
21
- def run_analysis(data_dir="data",
22
- demographic_file="demographics.csv",
23
- treatment_file="treatment_outcomes.csv",
24
- latent_dim=32,
25
- nepochs=1000,
26
- bsize=16,
27
- save_model=True,
28
- use_hf_dataset=False,
29
- hf_dataset=None,
30
- hf_nii_files=None,
31
- hf_demo_data=None,
32
- hf_demo_types=None,
33
- return_data=False,
34
- max_samples=None,
35
- skip_treatment_prediction=False):
36
  """
37
- Run the complete analysis pipeline
38
-
39
- Args:
40
- data_dir: Directory containing data files or HuggingFace dataset name
41
- demographic_file: Path to demographic CSV file (or None if using API)
42
- treatment_file: Path to treatment outcomes CSV file
43
- latent_dim: Dimension of VAE latent space
44
- nepochs: Number of training epochs
45
- bsize: Batch size for training
46
- save_model: Whether to save trained models
47
- use_hf_dataset: Whether to use HuggingFace dataset API
48
- hf_dataset: Pre-loaded HuggingFace dataset object
49
- hf_nii_files: List of NIfTI file paths from HuggingFace dataset
50
- hf_demo_data: Demographic data from HuggingFace dataset API
51
- hf_demo_types: Types of demographic variables (continuous/categorical)
52
- return_data: Whether to return raw data for accuracy calculations
53
- skip_treatment_prediction: Skip treatment prediction step (for FC analysis only)
54
  """
55
- # Update MODEL_CONFIG with user-specified parameters
56
- MODEL_CONFIG.update({
57
- 'latent_dim': latent_dim,
58
- 'nepochs': nepochs,
59
- 'bsize': bsize
60
- })
61
-
62
- # Create output directories
63
- os.makedirs('models', exist_ok=True)
64
- os.makedirs('results', exist_ok=True)
65
-
66
- # Load and preprocess data based on source
67
- print("Loading and preprocessing data...")
68
- if use_hf_dataset:
69
- # Use HuggingFace dataset
70
- if hf_demo_data is not None and hf_demo_types is not None:
71
- # If demographic data is provided directly from API
72
- print(f"Using demographic data from HuggingFace API")
73
- X, demo_data, demo_types = load_and_preprocess_data(
74
- data_dir,
75
- demographic_file,
76
- use_hf_dataset=True,
77
- hf_nii_files=hf_nii_files,
78
- hf_demo_data=hf_demo_data,
79
- hf_demo_types=hf_demo_types,
80
- max_samples=max_samples
81
- )
82
- else:
83
- # If demographic file is provided but still using HF for NIfTI
84
- X, demo_data, demo_types = load_and_preprocess_data(
85
- data_dir,
86
- demographic_file,
87
- use_hf_dataset=True,
88
- hf_nii_files=hf_nii_files,
89
- max_samples=max_samples
90
- )
91
- else:
92
- # Standard local file loading
93
- X, demo_data, demo_types = load_and_preprocess_data(data_dir, demographic_file, max_samples=max_samples)
94
-
95
- # If we're doing treatment prediction, load treatment outcomes
96
- treatment_outcomes = None
97
- if not skip_treatment_prediction and treatment_file:
98
- treatment_df = pd.read_csv(treatment_file)
99
- treatment_outcomes = treatment_df['outcome_score'].values
100
-
101
- # Ensure we have enough treatment outcomes based on input data
102
- if len(treatment_outcomes) < len(X):
103
- print(f"WARNING: Not enough treatment outcomes ({len(treatment_outcomes)}) for input data ({len(X)})")
104
- # Generate synthetic outcomes to match input data size
105
- synthetic_outcomes = np.random.normal(5, 2, size=(len(X) - len(treatment_outcomes)))
106
- treatment_outcomes = np.concatenate([treatment_outcomes, synthetic_outcomes])
107
- print(f"Added {len(synthetic_outcomes)} synthetic outcomes to match input data")
108
-
109
- # Ensure we don't have too many treatment outcomes
110
- if len(treatment_outcomes) > len(X):
111
- print(f"WARNING: More treatment outcomes ({len(treatment_outcomes)}) than input data ({len(X)})")
112
- treatment_outcomes = treatment_outcomes[:len(X)]
113
- print(f"Trimmed treatment outcomes to match input data size: {len(X)}")
114
-
115
- print(f"Using {len(treatment_outcomes)} treatment outcomes for {len(X)} input samples")
116
-
117
- # Initialize and train VAE
118
- print("Training VAE...")
119
- vae = DemoVAE(**MODEL_CONFIG)
120
- try:
121
- train_losses, val_losses = vae.fit(X, demo_data, demo_types)
122
- print(f"VAE training complete. Final train loss: {train_losses[-1]:.4f}, final validation loss: {val_losses[-1]:.4f}")
123
- except Exception as e:
124
- print(f"Error during VAE training: {e}")
125
- print("Using empty lists for losses as fallback")
126
- train_losses, val_losses = [], []
127
 
128
- # Get latent representations
129
- print("Extracting latent representations...")
130
- latents = vae.get_latents(X)
131
 
132
- # Save latent representations for other analyses
133
- np.save('results/latents.npy', latents)
 
 
134
 
135
- # Format demographics for predictor and results
136
- demographics = {}
 
 
 
137
 
138
- # Define both standard and alternative keys
139
- demo_keys = ['age_at_stroke', 'sex', 'months_post_stroke', 'wab_score']
140
- alternate_keys = {'age_at_stroke': 'age', 'months_post_stroke': 'mpo', 'wab_score': 'wab_aq'}
 
141
 
142
- # Map demographic data to consistent keys
143
- for i, key in enumerate(demo_keys):
144
- if i < len(demo_data):
145
- demographics[key] = demo_data[i]
146
- # Also add alternate versions of the key for compatibility
147
- if key in alternate_keys:
148
- demographics[alternate_keys[key]] = demo_data[i]
 
 
 
149
 
150
- # Print the keys available in demographics for debugging
151
- print(f"Demographics keys available: {list(demographics.keys())}")
152
 
153
- # Generate reconstructions and synthetic FC
154
- try:
155
- print("Generating reconstructed FC matrices...")
156
- reconstructed = vae.transform(X, demo_data, demo_types)
157
- print(f"Reconstructed FC shape: {reconstructed.shape}")
 
 
 
 
158
 
159
- print("Generating synthetic FC matrix...")
160
- generated = vae.transform(1,
161
- [d[:1] for d in demo_data],
162
- demo_types)
163
- print(f"Generated FC shape: {generated.shape}")
164
 
165
- # Save for other analyses
166
- print("Saving FC matrices...")
167
- np.save('results/reconstructed_fc.npy', reconstructed)
168
- np.save('results/generated_fc.npy', generated)
169
 
170
- # Also save original FC for comparison
171
- np.save('results/original_fc.npy', X)
172
- print("Saved FC matrices to results directory")
173
-
174
- # Make sure all are numpy arrays and print diagnostic info
175
- original = np.array(X[0])
176
- recon = np.array(reconstructed[0])
177
- gen = np.array(generated[0])
 
 
 
178
 
179
- print(f"FC shapes for visualization - Original: {original.shape}, Reconstructed: {recon.shape}, Generated: {gen.shape}")
180
 
181
- # Add additional type checking
182
- if len(original.shape) == 1:
183
- print("Original FC is in vector form (will be converted to matrix)")
184
- if len(recon.shape) == 1:
185
- print("Reconstructed FC is in vector form (will be converted to matrix)")
186
- if len(gen.shape) == 1:
187
- print("Generated FC is in vector form (will be converted to matrix)")
 
188
 
189
- # Create visualization
190
- print("Creating FC matrix visualization...")
191
- fc_fig = plot_fc_matrices(original, recon, gen)
192
- print("FC visualization created successfully")
193
- except Exception as e:
194
- import traceback
195
- print(f"Error creating FC visualization: {e}")
196
- print(f"Detailed error: {traceback.format_exc()}")
197
- fc_fig = plt.figure(figsize=(15, 5))
198
- plt.text(0.5, 0.5, f"FC visualization unavailable: {str(e)}",
199
- ha='center', va='center', transform=plt.gca().transAxes)
200
- plt.tight_layout()
201
-
202
- # Learning curves
203
- try:
204
- print("Creating learning curve visualization...")
205
-
206
- # Check if losses are stored in the VAE object first (most reliable source)
207
- train_data = []
208
- val_data = []
209
 
210
- # Only use real data from VAE object or training results
211
- if hasattr(vae, 'train_losses') and len(getattr(vae, 'train_losses', [])) > 0:
212
- train_data = vae.train_losses
213
- print(f"Found {len(train_data)} real training loss points in VAE object")
214
- elif train_losses and len(train_losses) > 0:
215
- train_data = train_losses
216
- print(f"Using {len(train_data)} real training loss points from fit return value")
217
- else:
218
- # Instead of synthetic data, provide empty list and warning
219
- print("WARNING: No real training loss data found")
220
- train_data = []
221
 
222
- # Do the same for validation data
223
- if hasattr(vae, 'val_losses') and len(getattr(vae, 'val_losses', [])) > 0:
224
- val_data = vae.val_losses
225
- print(f"Found {len(val_data)} real validation loss points in VAE object")
226
- elif val_losses and len(val_losses) > 0:
227
- val_data = val_losses
228
- print(f"Using {len(val_data)} real validation loss points from fit return value")
229
- else:
230
- # Instead of synthetic data, provide empty list and warning
231
- print("WARNING: No real validation loss data found")
232
- val_data = []
233
-
234
- # If we get here, we have some training data (real or synthetic)
235
- # Store the data in the VAE object for future use
236
- if not hasattr(vae, 'train_losses') or len(getattr(vae, 'train_losses', [])) == 0:
237
- print("Storing training loss data in VAE object")
238
- vae.train_losses = train_data
239
 
240
- if not hasattr(vae, 'val_losses') or len(getattr(vae, 'val_losses', [])) == 0:
241
- print("Storing validation loss data in VAE object")
242
- vae.val_losses = val_data
 
 
243
 
244
- # Now create the visualization using the data we collected
245
- print(f"Creating learning curve with {len(train_data)} training and {len(val_data)} validation points")
246
- learning_fig = plot_learning_curves(train_data, val_data)
247
- except Exception as e:
248
- import traceback
249
- print(f"Error creating learning curve plot: {e}")
250
- print(f"Traceback: {traceback.format_exc()}")
251
-
252
- # Create a more informative error display
253
- learning_fig = plt.figure(figsize=(10, 6))
254
- plt.text(0.5, 0.5, f"Error creating learning curves: {str(e)}",
255
- ha='center', va='center', transform=plt.gca().transAxes,
256
- fontsize=12, color='darkred')
257
- plt.axis('off')
258
- plt.tight_layout()
259
-
260
- # Check if we should use strict real data mode
261
- use_strict_real_data = PREDICTION_CONFIG.get('strict_real_data', False)
262
- no_mock_data = PREDICTION_CONFIG.get('no_mock_data', False)
263
-
264
- if use_strict_real_data or no_mock_data:
265
- print("Using strict real data mode - only including real data in results")
266
- # Only include figures if they contain real data
267
- figures = {}
268
- if hasattr(vae, 'train_losses') and len(vae.train_losses) > 0:
269
- figures['learning_curves'] = learning_fig
270
- print("Including real learning curves")
271
  else:
272
- print("WARNING: No real learning curve data available")
 
273
 
274
- # Only include FC analysis if it's based on real data
275
- if len(np.array(X).shape) > 0 and len(X) > 0:
276
- figures['vae'] = fc_fig
277
- figures['fc_analysis'] = fc_fig
278
- print("Including real FC analysis")
279
- else:
280
- print("WARNING: No real FC data available")
281
  else:
282
- # Include all figures, even if based on synthetic data
283
- figures = {
284
- 'vae': fc_fig,
285
- 'fc_analysis': fc_fig,
286
- 'learning_curves': learning_fig
287
- }
 
 
 
 
 
 
 
288
 
289
- # Initialize results dictionary
290
- results = {
291
- 'vae': vae,
292
- 'latents': latents,
293
- 'demographics': demographics,
294
- 'figures': figures
295
- }
 
 
 
296
 
297
- # Add reconstructed and generated FC if available
298
- if return_data:
299
- results.update({
300
- 'X': X,
301
- 'reconstructed_fc': reconstructed,
302
- 'generated_fc': generated
303
- })
304
 
305
- # Treatment prediction is optional
306
- if not skip_treatment_prediction and treatment_outcomes is not None:
307
- # Initialize and train treatment predictor
308
- print("Training treatment predictor...")
309
- predictor = AphasiaTreatmentPredictor(
310
- n_estimators=PREDICTION_CONFIG.get('n_estimators', 100),
311
- max_depth=PREDICTION_CONFIG.get('max_depth', None)
312
- )
313
 
314
- # Cross-validate the predictor
315
- print("Performing cross-validation...")
316
- cv_results = predictor.cross_validate(
317
- latents=latents,
318
- demographics=demographics,
319
- treatment_outcomes=treatment_outcomes,
320
- n_splits=PREDICTION_CONFIG.get('cv_folds', 5)
321
- )
 
 
 
 
 
322
 
323
- # Extract results from CV
324
- mean_metrics = cv_results.get("mean_metrics", {})
325
- fold_metrics = cv_results.get("fold_metrics", [])
 
326
 
327
- # Handle zeros_like fallback
 
328
  try:
329
- predictions = cv_results.get("predictions")
330
- if predictions is None:
331
- predictions = np.zeros_like(treatment_outcomes)
332
-
333
- prediction_stds = cv_results.get("prediction_stds")
334
- if prediction_stds is None:
335
- prediction_stds = np.zeros_like(treatment_outcomes)
 
 
336
  except Exception as e:
337
- print(f"Error getting predictions from CV results: {e}")
338
- # Create simple arrays as fallback
339
- predictions = np.zeros(len(treatment_outcomes))
340
- prediction_stds = np.zeros(len(treatment_outcomes))
341
 
342
- # For regression, get R2 metrics, otherwise use accuracy
343
- try:
344
- cv_mean = mean_metrics.get("r2", 0.0)
345
- if fold_metrics and "r2" in fold_metrics[0]:
346
- cv_std = np.std([fold.get("r2", 0.0) for fold in fold_metrics])
347
- else:
348
- cv_std = 0.0
349
- except Exception as e:
350
- print(f"Error calculating CV metrics: {e}")
351
- cv_mean, cv_std = 0.0, 0.0
352
 
353
- # Fit final predictor model
354
- predictor.fit(latents, demographics, treatment_outcomes)
 
 
 
 
 
 
355
 
356
- # Feature importance
357
- try:
358
- importance_fig = predictor.plot_feature_importance()
359
- except Exception as e:
360
- print(f"Error creating feature importance plot: {e}")
361
- importance_fig = plt.figure(figsize=(8, 6))
362
- plt.text(0.5, 0.5, "Feature importance unavailable",
363
- ha='center', va='center', transform=plt.gca().transAxes)
364
- plt.tight_layout()
365
 
366
- # Prediction performance
367
- performance_fig = plt.figure(figsize=(8, 6))
368
 
369
- # Check if we have valid predictions
370
- if len(treatment_outcomes) > 0 and len(predictions) == len(treatment_outcomes):
371
- try:
372
- # Only create scatter plot if we have matching data
373
- plt.scatter(treatment_outcomes, predictions)
374
-
375
- # Reference line
376
- min_val = min(np.min(treatment_outcomes), np.min(predictions))
377
- max_val = max(np.max(treatment_outcomes), np.max(predictions))
378
- plt.plot([min_val, max_val], [min_val, max_val], 'r--')
379
-
380
- # Confidence band
381
- plt.fill_between(treatment_outcomes,
382
- predictions - 2*prediction_stds,
383
- predictions + 2*prediction_stds,
384
- alpha=0.2, color='gray')
385
-
386
- # Labels
387
- plt.xlabel('Actual Outcome')
388
- plt.ylabel('Predicted Outcome')
389
-
390
- # Title with metrics
391
- if predictor.prediction_type == "regression":
392
- plt.title(f'Treatment Outcome Prediction\nR² = {cv_mean:.3f} ± {cv_std:.3f}')
393
- else:
394
- plt.title(f'Treatment Outcome Prediction\nAccuracy = {cv_mean:.3f} ± {cv_std:.3f}')
395
- except Exception as e:
396
- print(f"Error creating performance plot: {e}")
397
- plt.text(0.5, 0.5, "Error creating plot",
398
- ha='center', va='center', transform=plt.gca().transAxes)
399
- else:
400
- # Handle case with no data
401
- plt.text(0.5, 0.5, "No prediction data available",
402
- ha='center', va='center', transform=plt.gca().transAxes)
403
-
404
- plt.tight_layout()
405
 
406
- # Save results
407
- print("Saving prediction results...")
408
- np.save('results/predictions.npy', predictions)
409
- np.save('results/prediction_stds.npy', prediction_stds)
410
 
411
- # Update results dictionary with prediction information
412
- predictor_cv_results = {
413
- 'mean_metrics': mean_metrics if mean_metrics else {},
414
- 'fold_metrics': fold_metrics if fold_metrics else [],
415
- 'predictions': predictions if len(predictions) > 0 else np.zeros(0),
416
- 'prediction_stds': prediction_stds if len(prediction_stds) > 0 else np.zeros(0)
417
- }
 
 
 
 
418
 
419
- # Add prediction results to main results dictionary
420
- results.update({
421
- 'predictor': predictor,
422
- 'cv_scores': (cv_mean, cv_std),
423
- 'predictions': predictions,
424
- 'prediction_stds': prediction_stds,
425
- 'predictor_cv_results': predictor_cv_results,
426
- })
427
 
428
- # Add prediction figures to results dictionary
429
- results['figures'].update({
430
- 'importance': importance_fig,
431
- 'performance': performance_fig
432
- })
433
-
434
- # Save models if requested
435
- if save_model:
436
- print("Saving models...")
437
- vae.save('models/vae_model.pt')
438
 
439
- # Save predictor model if it exists
440
- if not skip_treatment_prediction and 'predictor' in results:
441
- torch.save({
442
- 'predictor_state': results['predictor'].model,
443
- 'feature_importance': results['predictor'].feature_importance
444
- }, 'models/predictor_model.pt')
445
-
446
- print("Analysis complete!")
447
- return results
448
-
449
- # Alias for backwards compatibility - simplified version of run_analysis just for FC
450
- def run_fc_analysis(data_dir="data",
451
- demographic_file=None,
452
- latent_dim=32,
453
- nepochs=100,
454
- bsize=16,
455
- save_model=True,
456
- use_hf_dataset=True,
457
- return_data=True):
458
- """
459
- Run only the FC analysis portion without prediction
460
-
461
- This is a simplified version of run_analysis focused on VAE training
462
- and FC matrix visualization for the Gradio interface.
463
- """
464
- # Call the main function with skip_treatment_prediction=True
465
- return run_analysis(
466
- data_dir=data_dir,
467
- demographic_file=demographic_file,
468
- treatment_file=None, # No treatment file needed
469
- latent_dim=latent_dim,
470
- nepochs=nepochs,
471
- bsize=bsize,
472
- save_model=save_model,
473
- use_hf_dataset=use_hf_dataset,
474
- return_data=return_data,
475
- skip_treatment_prediction=True
476
- )
477
 
478
  if __name__ == "__main__":
479
  import argparse
480
 
481
- parser = argparse.ArgumentParser(description='Run Aphasia Treatment Analysis')
482
- parser.add_argument('--data_dir', type=str, default='data',
483
- help='Directory containing fMRI data')
484
- parser.add_argument('--demographic_file', type=str, default='demographics.csv',
485
  help='Path to demographic data CSV file')
486
- parser.add_argument('--treatment_file', type=str, default='treatment_outcomes.csv',
487
- help='Path to treatment outcomes CSV file')
488
  parser.add_argument('--latent_dim', type=int, default=32,
489
  help='Dimension of latent space')
490
  parser.add_argument('--nepochs', type=int, default=1000,
@@ -492,28 +274,20 @@ if __name__ == "__main__":
492
  parser.add_argument('--bsize', type=int, default=16,
493
  help='Batch size for training')
494
  parser.add_argument('--no_save', action='store_false',
495
- help='Do not save the models')
496
- parser.add_argument('--fc_only', action='store_true',
497
- help='Run only the FC analysis without treatment prediction')
498
 
499
  args = parser.parse_args()
500
 
501
- if args.fc_only:
502
- results = run_fc_analysis(
503
- data_dir=args.data_dir,
504
- demographic_file=args.demographic_file,
505
- latent_dim=args.latent_dim,
506
- nepochs=args.nepochs,
507
- bsize=args.bsize,
508
- save_model=args.no_save
509
- )
510
- else:
511
- results = run_analysis(
512
- data_dir=args.data_dir,
513
- demographic_file=args.demographic_file,
514
- treatment_file=args.treatment_file,
515
- latent_dim=args.latent_dim,
516
- nepochs=args.nepochs,
517
- bsize=args.bsize,
518
- save_model=args.no_save
519
- )
 
1
  import os
2
+ import sys
3
+ # Add the src directory to the path so we can import from demovae
4
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
5
+
6
+ import numpy as np
7
  import torch
8
  from pathlib import Path
9
+ import nibabel as nib
10
+ from data_preprocessing import preprocess_fmri_to_fc
11
+ from src.demovae.sklearn import DemoVAE
12
+ from analysis import analyze_fc_patterns
13
+ from visualization import visualize_fc_analysis
14
+ from config import MODEL_CONFIG, DATASET_CONFIG
15
  import pandas as pd
16
+ import io
17
+ from typing import List, Dict, Union, Tuple, Any
18
 
19
+ def train_fc_vae(X, demo_data, demo_types, model_config):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  """
21
+ Train a VAE model on functional connectivity matrices
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  """
23
+ n_rois = 264
24
+ input_dim = (n_rois * (n_rois - 1)) // 2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
 
 
27
 
28
+ # Ensure X is a numpy array with correct data type
29
+ if not isinstance(X, np.ndarray):
30
+ print(f"Converting X from {type(X)} to numpy array")
31
+ X = np.array(X, dtype=np.float32)
32
 
33
+ # Ensure demo_data contains numpy arrays
34
+ for i, d in enumerate(demo_data):
35
+ if not isinstance(d, np.ndarray):
36
+ print(f"Converting demographic {i} from {type(d)} to numpy array")
37
+ demo_data[i] = np.array(d)
38
 
39
+ # Check for NaN or Inf values
40
+ if np.isnan(X).any() or np.isinf(X).any():
41
+ print("Warning: X contains NaN or Inf values. Replacing with zeros.")
42
+ X = np.nan_to_num(X)
43
 
44
+ # Create the VAE model
45
+ vae = DemoVAE(
46
+ latent_dim=model_config['latent_dim'],
47
+ nepochs=model_config['nepochs'],
48
+ bsize=model_config['bsize'],
49
+ loss_rec_mult=model_config.get('loss_rec_mult', 100),
50
+ loss_decor_mult=model_config.get('loss_decor_mult', 10),
51
+ lr=model_config.get('lr', 1e-4),
52
+ use_cuda=torch.cuda.is_available()
53
+ )
54
 
55
+ print("Fitting VAE model...")
56
+ vae.fit(X, demo_data, demo_types)
57
 
58
+ return vae, X, demo_data, demo_types
59
+
60
+ def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
61
+ """
62
+ Load fMRI data and demographics from HuggingFace dataset or local files
63
+ """
64
+ if use_hf_dataset:
65
+ # Load from HuggingFace Datasets
66
+ from datasets import load_dataset
67
 
68
+ print(f"Loading dataset from HuggingFace: {data_dir}")
69
+ dataset = load_dataset(data_dir)
 
 
 
70
 
71
+ print(f"Dataset columns: {dataset['train'].column_names}")
 
 
 
72
 
73
+ # Get demographics directly from the dataset
74
+ # Create a DataFrame from the dataset features
75
+ demo_df = pd.DataFrame({
76
+ 'ID': dataset['train']['ID'],
77
+ 'wab_aq': dataset['train']['wab_aq'],
78
+ 'age': dataset['train']['age'],
79
+ 'mpo': dataset['train']['mpo'],
80
+ 'education': dataset['train']['education'],
81
+ 'gender': dataset['train']['gender'],
82
+ 'handedness': dataset['train']['handedness']
83
+ })
84
 
85
+ print(f"Loaded demographic data with {len(demo_df)} subjects")
86
 
87
+ # Extract demographic data matching our expected format
88
+ # Map the dataset columns to our expected format
89
+ demo_data = [
90
+ demo_df['age'].values, # age at stroke -> age
91
+ demo_df['gender'].values, # sex -> gender
92
+ demo_df['mpo'].values, # months post stroke -> mpo
93
+ demo_df['wab_aq'].values # wab score -> wab_aq
94
+ ]
95
 
96
+ # Check for FC matrices in the dataset
97
+ fc_columns = []
98
+ for col in dataset['train'].column_names:
99
+ if col.startswith("fc_") or "_fc" in col:
100
+ fc_columns.append(col)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
 
102
+ if fc_columns:
103
+ print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
104
+ # Extract FC matrices
105
+ fc_matrices = []
106
+ for fc_col in fc_columns:
107
+ fc_matrices.append(dataset['train'][fc_col])
 
 
 
 
 
108
 
109
+ # If we have FC matrices, return them directly
110
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
111
+ return fc_matrices, demo_data, demo_types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
112
 
113
+ # If no FC matrices, look for .nii files
114
+ nii_files = []
115
+ for col in dataset['train'].column_names:
116
+ if col.endswith(".nii.gz") or col.endswith(".nii"):
117
+ nii_files.append(dataset['train'][col])
118
 
119
+ if nii_files:
120
+ print(f"Found {len(nii_files)} .nii files")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  else:
122
+ print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
123
+ # If no structured data is found, we can try to download raw files later
124
 
 
 
 
 
 
 
 
125
  else:
126
+ # Original local file loading
127
+ # Load demographics
128
+ demo_df = pd.read_csv(demographic_file)
129
+
130
+ demo_data = [
131
+ demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
132
+ demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
133
+ demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
134
+ demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
135
+ ]
136
+
137
+ # Load fMRI files
138
+ nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
139
 
140
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
141
+ return nii_files, demo_data, demo_types
142
+
143
+ def run_fc_analysis(data_dir="SreekarB/OSFData",
144
+ demographic_file=None,
145
+ latent_dim=32,
146
+ nepochs=1000,
147
+ bsize=16,
148
+ save_model=True,
149
+ use_hf_dataset=True):
150
 
151
+ # Update MODEL_CONFIG with user-specified parameters
152
+ MODEL_CONFIG.update({
153
+ 'latent_dim': latent_dim,
154
+ 'nepochs': nepochs,
155
+ 'bsize': bsize
156
+ })
 
157
 
158
+ try:
159
+ # Load data
160
+ print("Loading data...")
161
+ nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
 
 
 
 
162
 
163
+ # For SreekarB/OSFData, directly generate synthetic FC matrices
164
+ if data_dir == "SreekarB/OSFData" and use_hf_dataset:
165
+ print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
166
+ X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
167
+ # Check if we got FC matrices directly
168
+ elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
169
+ print("Using pre-computed FC matrices...")
170
+ # Convert list of FC matrices to numpy array
171
+ X = np.stack([np.array(fc) for fc in nii_files])
172
+ else:
173
+ # Prepare data by converting fMRI to FC matrices
174
+ print("Converting fMRI data to FC matrices...")
175
+ X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
176
 
177
+ # Print shapes and data types
178
+ print(f"X shape: {X.shape}, type: {type(X)}")
179
+ for i, d in enumerate(demo_data):
180
+ print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
181
 
182
+ # Train VAE and get data
183
+ print("Training VAE...")
184
  try:
185
+ # Use the proper DemoVAE implementation from src/demovae/sklearn.py
186
+ vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
187
+
188
+ if save_model:
189
+ print("Saving model...")
190
+ os.makedirs('models', exist_ok=True)
191
+ # Use the save method from DemoVAE
192
+ vae.save('models/vae_model.pth')
193
+ print("Model saved successfully.")
194
  except Exception as e:
195
+ print(f"Error during VAE training: {e}")
196
+ raise
 
 
197
 
198
+ # Get latent representations
199
+ print("Getting latent representations...")
200
+ latents = vae.get_latents(X)
 
 
 
 
 
 
 
201
 
202
+ # Analyze results
203
+ print("Analyzing demographic relationships...")
204
+ demographics = {
205
+ 'age': demo_data[0],
206
+ 'months_post_onset': demo_data[2],
207
+ 'wab_aq': demo_data[3]
208
+ }
209
+ analysis_results = analyze_fc_patterns(latents, demographics)
210
 
211
+ # Generate new FC matrix
212
+ print("Generating new FC matrices...")
 
 
 
 
 
 
 
213
 
214
+ # Get data types from original demographic data for proper conversion
215
+ demo_dtypes = [type(d[0]) if len(d) > 0 else float for d in demo_data]
216
 
217
+ # Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
218
+ new_demographics = [
219
+ np.array([60.0], dtype=np.float64), # age
220
+ np.array(['M'], dtype=np.str_), # gender
221
+ np.array([12.0], dtype=np.float64), # months post onset
222
+ np.array([80.0], dtype=np.float64) # wab score
223
+ ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
224
 
225
+ # Verify the demographic data arrays match the expected types
226
+ print("Demographic data types:")
227
+ for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
228
+ print(f" {name}: shape={data.shape}, dtype={data.dtype}")
229
 
230
+ print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
231
+ try:
232
+ generated_fc = vae.transform(1, new_demographics, demo_types)
233
+ except Exception as e:
234
+ print(f"Error generating new FC matrix: {e}")
235
+ # Try with a fallback approach
236
+ print("Trying alternative generation approach...")
237
+ # If specific gender is causing issues, try the first gender from training data
238
+ new_demographics[1] = np.array([demo_data[1][0]])
239
+ generated_fc = vae.transform(1, new_demographics, demo_types)
240
+ reconstructed_fc = vae.transform(X, demo_data, demo_types)
241
 
242
+ # Visualize results
243
+ print("Creating visualizations...")
244
+ fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
 
 
 
 
 
245
 
246
+ return fig
 
 
 
 
 
 
 
 
 
247
 
248
+ except Exception as e:
249
+ import traceback
250
+ print(f"Error in run_fc_analysis: {str(e)}")
251
+ print(traceback.format_exc())
252
+
253
+ # Create a dummy figure with error message
254
+ import matplotlib.pyplot as plt
255
+ fig = plt.figure(figsize=(10, 6))
256
+ plt.text(0.5, 0.5, f"Error: {str(e)}",
257
+ horizontalalignment='center', verticalalignment='center',
258
+ fontsize=12, color='red')
259
+ plt.axis('off')
260
+ return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  if __name__ == "__main__":
263
  import argparse
264
 
265
+ parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
266
+ parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
267
+ help='HuggingFace dataset ID or directory containing fMRI data')
268
+ parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
269
  help='Path to demographic data CSV file')
 
 
270
  parser.add_argument('--latent_dim', type=int, default=32,
271
  help='Dimension of latent space')
272
  parser.add_argument('--nepochs', type=int, default=1000,
 
274
  parser.add_argument('--bsize', type=int, default=16,
275
  help='Batch size for training')
276
  parser.add_argument('--no_save', action='store_false',
277
+ help='Do not save the model')
278
+ parser.add_argument('--use_local', action='store_true',
279
+ help='Use local data instead of HuggingFace dataset')
280
 
281
  args = parser.parse_args()
282
 
283
+ fig = run_fc_analysis(
284
+ data_dir=args.data_dir,
285
+ demographic_file=args.demographic_file,
286
+ latent_dim=args.latent_dim,
287
+ nepochs=args.nepochs,
288
+ bsize=args.bsize,
289
+ save_model=args.no_save,
290
+ use_hf_dataset=not args.use_local
291
+ )
292
+ fig.show()
293
+
 
 
 
 
 
 
 
 
requirements.txt CHANGED
@@ -1,8 +1,12 @@
1
  torch>=1.9.0
2
  numpy>=1.19.2
3
  pandas>=1.2.4
 
 
4
  scikit-learn>=0.24.2
5
  matplotlib>=3.4.2
6
- gradio>=3.0.0
7
- joblib>=1.0.1
 
 
8
 
 
1
  torch>=1.9.0
2
  numpy>=1.19.2
3
  pandas>=1.2.4
4
+ nilearn>=0.8.1
5
+ nibabel>=3.2.1
6
  scikit-learn>=0.24.2
7
  matplotlib>=3.4.2
8
+ gradio>=2.0.0
9
+ datasets>=1.11.0
10
+ huggingface_hub>=0.15.0
11
+ transformers>=4.15.0
12
 
test_hf_download.py CHANGED
@@ -6,7 +6,7 @@ from datasets import load_dataset
6
  import numpy as np
7
  import pandas as pd
8
 
9
- def test_huggingface_download(dataset_name="SreekarB/OSFData1", revision=None, auth_token=None):
10
  """
11
  Test script to verify downloading NIfTI files from HuggingFace Datasets
12
  """
@@ -227,7 +227,7 @@ if __name__ == "__main__":
227
  # Process command line arguments
228
  import argparse
229
  parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
230
- parser.add_argument('--dataset', type=str, default="SreekarB/OSFData1", help='HuggingFace dataset name')
231
  parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
232
  parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
233
 
 
6
  import numpy as np
7
  import pandas as pd
8
 
9
+ def test_huggingface_download(dataset_name="SreekarB/OSFData", revision=None, auth_token=None):
10
  """
11
  Test script to verify downloading NIfTI files from HuggingFace Datasets
12
  """
 
227
  # Process command line arguments
228
  import argparse
229
  parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
230
+ parser.add_argument('--dataset', type=str, default="SreekarB/OSFData", help='HuggingFace dataset name')
231
  parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
232
  parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
233
 
utils.py CHANGED
@@ -3,18 +3,23 @@ import numpy as np
3
  from sklearn.linear_model import Ridge, LogisticRegression
4
 
5
  def to_torch(x):
6
- if not isinstance(x, np.ndarray):
7
- x = np.array(x)
8
  return torch.from_numpy(x).float()
9
 
10
  def to_cuda(x, use_cuda):
11
- if use_cuda and torch.cuda.is_available():
12
- return x.cuda()
13
- return x
14
 
15
  def to_numpy(x):
16
  return x.detach().cpu().numpy()
17
 
 
 
 
 
 
 
 
 
 
18
  def rmse(a, b, mean=torch.mean):
19
  return mean((a-b)**2)**0.5
20
 
@@ -42,7 +47,6 @@ def decor_loss(z, demo, use_cuda=True):
42
  ps.append(p)
43
  losses = torch.stack(losses)
44
  return losses, ps
45
-
46
  def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
47
  demo_t = []
48
  demo_idx = 0
@@ -66,33 +70,10 @@ def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
66
  def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
67
  loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
68
  loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
69
-
70
  # Get linear predictors for demographics
71
  pred_w = []
72
  pred_i = []
73
  pred_stats = []
74
- train_losses = []
75
- val_losses = []
76
-
77
- # Check if sample sizes are consistent
78
- n_samples = x.shape[0]
79
- print(f"Sample sizes - X: {n_samples}, Demographics: {[len(d) for d in demo]}")
80
-
81
- # Ensure all sample sizes match
82
- if any(len(d) != n_samples for d in demo):
83
- print("WARNING: Sample size mismatch detected! Fixing...")
84
-
85
- # Trim to smallest size
86
- min_samples = min(n_samples, *[len(d) for d in demo])
87
- print(f"Adjusting to {min_samples} samples")
88
-
89
- # Adjust x and demo
90
- x = x[:min_samples]
91
- demo = [d[:min_samples] for d in demo]
92
-
93
- print(f"After adjustment - X: {x.shape[0]}, Demographics: {[len(d) for d in demo]}")
94
-
95
- print(f"Using {x.shape[0]} samples for training")
96
 
97
  for i, d, t in zip(range(len(demo)), demo, demo_types):
98
  print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
@@ -133,21 +114,7 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
133
  ce = torch.nn.CrossEntropyLoss()
134
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
135
 
136
- # Calculate initial validation loss
137
- print("Calculating initial validation metrics...")
138
- vae.eval()
139
- with torch.no_grad():
140
- z_val = vae.enc(x)
141
- y_val = vae.dec(z_val, demo_t)
142
- initial_val_loss = rmse(x, y_val).item()
143
- val_losses.append(initial_val_loss)
144
- print(f"Initial validation loss: {initial_val_loss:.4f}")
145
-
146
- # Main training loop
147
  for e in range(nepochs):
148
- epoch_losses = []
149
- vae.train()
150
-
151
  for bs in range(0, len(x), bsize):
152
  xb = x[bs:(bs+bsize)]
153
  db = demo_t[bs:(bs+bsize)]
@@ -161,43 +128,59 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
161
  loss_decor = sum(loss_decor)
162
  loss_rec = rmse(xb, y)
163
 
164
- # Calculate total loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
  total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
166
- loss_rec_mult*loss_rec + loss_decor_mult*loss_decor)
 
167
 
168
  total_loss.backward()
169
  optim.step()
170
 
171
- epoch_losses.append(total_loss.item())
172
-
173
- # Record training loss
174
- epoch_loss = np.mean(epoch_losses)
175
- train_losses.append(epoch_loss)
176
-
177
- # Print progress for every epoch
178
- print(f'Epoch {e+1}/{nepochs} - Train Loss: {epoch_loss:.4f}')
179
-
180
- # Validation step (perform at every epoch to have full data for plotting)
181
- vae.eval()
182
- with torch.no_grad():
183
- z = vae.enc(x)
184
- y = vae.dec(z, demo_t)
185
- val_loss = rmse(x, y).item()
186
- val_losses.append(val_loss)
187
-
188
- # Only print detailed validation logs at pperiod intervals
189
- if (e + 1) % pperiod == 0:
190
- print(f' Validation - Val Loss: {val_loss:.4f}')
191
-
192
- # Make sure losses are converted to regular Python lists (for serialization)
193
- train_losses = [float(loss) for loss in train_losses]
194
- val_losses = [float(loss) for loss in val_losses]
195
-
196
- print(f"Training complete - Final train loss: {train_losses[-1]:.4f}, Val loss: {val_losses[-1]:.4f}")
197
- print(f"Loss history recorded: {len(train_losses)} train points, {len(val_losses)} validation points")
198
-
199
- # Store the losses in the return object for future reference
200
- ret_obj.train_losses = train_losses
201
- ret_obj.val_losses = val_losses
202
-
203
- return train_losses, val_losses
 
3
  from sklearn.linear_model import Ridge, LogisticRegression
4
 
5
def to_torch(x):
    """Convert array-like data to a float32 torch tensor.

    `torch.from_numpy` only accepts ndarrays, so coerce lists/tuples
    with `np.asarray` first (no copy when `x` is already an ndarray).
    """
    return torch.from_numpy(np.asarray(x)).float()
7
 
8
def to_cuda(x, use_cuda):
    """Move `x` (tensor or module) to the GPU when requested and available.

    Guarding on `torch.cuda.is_available()` keeps CPU-only machines from
    crashing when callers leave `use_cuda` at its default of True.
    """
    return x.cuda() if use_cuda and torch.cuda.is_available() else x
 
 
10
 
11
def to_numpy(x):
    """Detach a tensor from the autograd graph and return it as a numpy array."""
    detached = x.detach().cpu()
    return detached.numpy()
13
 
14
def fc_matrix_from_triu(triu_values, n_rois=264):
    """Rebuild a symmetric FC matrix from its flattened upper triangle.

    Parameters
    ----------
    triu_values : array-like, length n_rois*(n_rois-1)/2
        Strict upper-triangular values (Fisher z-scores).
    n_rois : int
        Side length of the square matrix (default 264 ROIs).

    Returns
    -------
    np.ndarray of shape (n_rois, n_rois): symmetric, unit diagonal,
    off-diagonal entries mapped back to (-1, 1) via tanh (inverse of
    the Fisher z-transform).

    Raises
    ------
    ValueError if the number of values does not match `n_rois` — the
    original code would otherwise fail with an opaque broadcast error.
    """
    triu_values = np.asarray(triu_values)
    expected = n_rois * (n_rois - 1) // 2
    if triu_values.shape[-1] != expected:
        raise ValueError(
            f"Expected {expected} upper-triangular values for n_rois={n_rois}, "
            f"got {triu_values.shape[-1]}"
        )
    fc_matrix = np.zeros((n_rois, n_rois))
    rows, cols = np.triu_indices(n_rois, k=1)
    fc_matrix[rows, cols] = np.tanh(triu_values)
    fc_matrix += fc_matrix.T
    np.fill_diagonal(fc_matrix, 1)
    return fc_matrix
22
+
23
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between `a` and `b` under the given reducer."""
    sq_err = (a - b) ** 2
    return mean(sq_err) ** 0.5
25
 
 
47
  ps.append(p)
48
  losses = torch.stack(losses)
49
  return losses, ps
 
50
  def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
51
  demo_t = []
52
  demo_idx = 0
 
70
  def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
71
  loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
72
  loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
 
73
  # Get linear predictors for demographics
74
  pred_w = []
75
  pred_i = []
76
  pred_stats = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  for i, d, t in zip(range(len(demo)), demo, demo_types):
79
  print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
 
114
  ce = torch.nn.CrossEntropyLoss()
115
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
116
 
 
 
 
 
 
 
 
 
 
 
 
117
  for e in range(nepochs):
 
 
 
118
  for bs in range(0, len(x), bsize):
119
  xb = x[bs:(bs+bsize)]
120
  db = demo_t[bs:(bs+bsize)]
 
128
  loss_decor = sum(loss_decor)
129
  loss_rec = rmse(xb, y)
130
 
131
+ # Sample demographics
132
+ demo_gen = []
133
+ for s, t in zip(pred_stats, demo_types):
134
+ if t == 'continuous':
135
+ mu, std = s
136
+ dd = torch.randn(100).float()
137
+ dd = dd*std+mu
138
+ dd = to_cuda(dd, vae.use_cuda)
139
+ demo_gen.append(dd)
140
+ elif t == 'categorical':
141
+ idx = np.random.randint(0, len(s))
142
+ for i in range(len(s)):
143
+ dd = torch.ones(100).float() if idx == i else torch.zeros(100).float()
144
+ dd = to_cuda(dd, vae.use_cuda)
145
+ demo_gen.append(dd)
146
+
147
+ demo_gen = torch.stack(demo_gen).permute(1,0)
148
+
149
+ # Generate
150
+ z = vae.gen(100)
151
+ y = vae.dec(z, demo_gen)
152
+
153
+ # Regressor/classifier guidance loss
154
+ losses_pred = []
155
+ idcs = []
156
+ dg_idx = 0
157
+
158
+ for s, t in zip(pred_stats, demo_types):
159
+ if t == 'continuous':
160
+ yy = y@pred_w[dg_idx]+pred_i[dg_idx]
161
+ loss = rmse(demo_gen[:,dg_idx], yy)
162
+ losses_pred.append(loss)
163
+ idcs.append(float(demo_gen[0,dg_idx]))
164
+ dg_idx += 1
165
+ elif t == 'categorical':
166
+ loss = 0
167
+ for i in range(len(s)):
168
+ yy = y@pred_w[dg_idx]+pred_i[dg_idx]
169
+ loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
170
+ idcs.append(int(demo_gen[0,dg_idx]))
171
+ dg_idx += 1
172
+ losses_pred.append(loss)
173
+
174
  total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
175
+ loss_rec_mult*loss_rec + loss_decor_mult*loss_decor +
176
+ loss_pred_mult*sum(losses_pred))
177
 
178
  total_loss.backward()
179
  optim.step()
180
 
181
+ if e%pperiod == 0 or e == nepochs-1:
182
+ print(f'Epoch {e} ReconLoss {loss_rec:.4f} CovarianceLoss {loss_C:.4f} '
183
+ f'MeanLoss {loss_mu:.4f} DecorLoss {loss_decor:.4f}')
184
+ print(f'GuidanceTargets {idcs}')
185
+ print(f'GuidanceLosses {[f"{loss:.4f}" for loss in losses_pred]}')
186
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
vae_model.py CHANGED
@@ -1,355 +1,150 @@
1
- """
2
- Simplified VAE implementation with explicit loss tracking.
3
- """
4
  import torch
5
  import torch.nn as nn
6
  import torch.nn.functional as F
7
  import numpy as np
8
- import os
9
- import matplotlib.pyplot as plt
10
  from sklearn.base import BaseEstimator
11
 
12
- class SimpleVAE(nn.Module):
13
- def __init__(self, input_dim, latent_dim, demo_dim):
14
- super(SimpleVAE, self).__init__()
15
- # Store dimensions
16
  self.input_dim = input_dim
17
  self.latent_dim = latent_dim
18
  self.demo_dim = demo_dim
 
19
 
20
- # Encoder (FC data → latent)
21
- self.enc1 = nn.Linear(input_dim, 256)
22
- self.enc2 = nn.Linear(256, latent_dim)
23
 
24
- # Decoder (latent + demographics → FC reconstruction)
25
- self.dec1 = nn.Linear(latent_dim + demo_dim, 256)
26
- self.dec2 = nn.Linear(256, input_dim)
27
 
28
- def encode(self, x):
29
- """Encode FC data to latent space"""
30
- h = F.relu(self.enc1(x))
31
- return self.enc2(h)
32
-
33
- def decode(self, z, demo):
34
- """Decode from latent space to FC reconstruction"""
35
- # Combine latent with demographics
36
- z_combined = torch.cat([z, demo], dim=1)
37
- h = F.relu(self.dec1(z_combined))
38
- return self.dec2(h)
39
-
40
- def forward(self, x, demo):
41
- """Full forward pass"""
42
- z = self.encode(x)
43
- return self.decode(z, demo)
44
 
45
- class DemoVAE:
46
- def __init__(self, nepochs=50, batch_size=8, latent_dim=16, lr=1e-3):
47
- """Simple VAE model with demographic conditioning"""
48
- self.nepochs = nepochs
49
- self.batch_size = batch_size
50
- self.latent_dim = latent_dim
51
- self.lr = lr
52
- self.vae = None
53
- self.train_losses = []
54
- self.val_losses = []
55
- self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
56
-
57
- def preprocess_demo(self, demo_data, demo_types, n_samples=None):
58
- """Process demographic data into one-hot encoded tensors"""
59
- if n_samples is None:
60
- n_samples = len(demo_data[0])
61
-
62
- processed_demos = []
63
- total_dims = 0
64
-
65
- # Process each demographic variable
66
- for i, (data, dtype) in enumerate(zip(demo_data, demo_types)):
67
- if dtype == 'continuous':
68
- # For continuous variables, just normalize
69
- data_np = np.array(data).reshape(-1, 1)
70
- mean, std = np.mean(data_np), np.std(data_np)
71
- if std == 0: # Handle constant values
72
- normalized = np.zeros_like(data_np)
73
- else:
74
- normalized = (data_np - mean) / std
75
- processed_demos.append(normalized)
76
- total_dims += 1
77
- elif dtype == 'categorical':
78
- # For categorical, create one-hot encoding
79
- data_list = list(data)
80
- categories = sorted(list(set(data_list)))
81
-
82
- # Create one-hot vectors
83
- one_hot = np.zeros((len(data_list), len(categories)))
84
- for j, val in enumerate(data_list):
85
- idx = categories.index(val)
86
- one_hot[j, idx] = 1
87
-
88
- processed_demos.append(one_hot)
89
- total_dims += len(categories)
90
-
91
- # Combine all demographics
92
- demo_tensor = np.hstack(processed_demos)
93
- return torch.tensor(demo_tensor, dtype=torch.float32), total_dims
94
-
95
- def fit(self, X, demo_data, demo_types):
96
- """Train the VAE model"""
97
- # Convert to numpy arrays if needed
98
- X = np.array(X)
99
-
100
- # Process demographics
101
- print("Processing demographics...")
102
- demo_tensor, demo_dim = self.preprocess_demo(demo_data, demo_types)
103
-
104
- # Initialize model
105
- input_dim = X.shape[1]
106
- print(f"Creating model with input_dim={input_dim}, latent_dim={self.latent_dim}, demo_dim={demo_dim}")
107
- self.vae = SimpleVAE(input_dim, self.latent_dim, demo_dim)
108
- self.vae.to(self.device)
109
-
110
- # Convert data to tensors
111
- X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
112
- demo_tensor = demo_tensor.to(self.device)
113
-
114
- # Initialize optimizer
115
- optimizer = torch.optim.Adam(self.vae.parameters(), lr=self.lr)
116
-
117
- # Training loop
118
- n_samples = X.shape[0]
119
- batch_size = min(self.batch_size, n_samples)
120
-
121
- # Clear any old losses
122
- self.train_losses = []
123
- self.val_losses = []
124
-
125
- # Initial validation loss
126
- self.vae.eval()
127
- with torch.no_grad():
128
- reconstructed = self.vae(X_tensor, demo_tensor)
129
- init_val_loss = F.mse_loss(reconstructed, X_tensor).item()
130
- self.val_losses.append(init_val_loss)
131
- print(f"Initial validation loss: {init_val_loss:.4f}")
132
-
133
- # Main training loop
134
- for epoch in range(self.nepochs):
135
- epoch_losses = []
136
- self.vae.train()
137
-
138
- # Process in batches
139
- for i in range(0, n_samples, batch_size):
140
- # Get batch
141
- end = min(i + batch_size, n_samples)
142
- x_batch = X_tensor[i:end]
143
- demo_batch = demo_tensor[i:end]
144
-
145
- # Forward pass
146
- optimizer.zero_grad()
147
- reconstructed = self.vae(x_batch, demo_batch)
148
-
149
- # Calculate loss
150
- loss = F.mse_loss(reconstructed, x_batch)
151
-
152
- # Backward pass
153
- loss.backward()
154
- optimizer.step()
155
-
156
- # Record loss
157
- epoch_losses.append(loss.item())
158
-
159
- # End of epoch
160
- avg_loss = np.mean(epoch_losses)
161
- self.train_losses.append(avg_loss)
162
-
163
- # Validation
164
- self.vae.eval()
165
- with torch.no_grad():
166
- reconstructed = self.vae(X_tensor, demo_tensor)
167
- val_loss = F.mse_loss(reconstructed, X_tensor).item()
168
- self.val_losses.append(val_loss)
169
-
170
- # Print progress every few epochs
171
- if (epoch + 1) % 5 == 0 or epoch == 0:
172
- print(f"Epoch {epoch+1}/{self.nepochs} - "
173
- f"Train loss: {avg_loss:.4f}, Val loss: {val_loss:.4f}")
174
-
175
- print(f"Training complete! Final loss: {self.train_losses[-1]:.4f}")
176
- print(f"Loss history: {len(self.train_losses)} train, {len(self.val_losses)} validation")
177
-
178
- return self.train_losses, self.val_losses
179
-
180
- def transform(self, X, demo_data, demo_types):
181
- """Generate reconstructions or synthetic samples"""
182
- # Check if model is available
183
- if self.vae is None:
184
- raise ValueError("Model not trained or loaded yet")
185
-
186
- # Set model to evaluation mode
187
- self.vae.eval()
188
-
189
- # Check if we're generating or reconstructing
190
- if isinstance(X, int):
191
- # Generating n random samples
192
- n_samples = X
193
-
194
- # Process demo data (repeat single values if needed)
195
- demo_list = []
196
- for d in demo_data:
197
- if not isinstance(d, (list, np.ndarray)):
198
- # Single value, repeat for all samples
199
- demo_list.append([d] * n_samples)
200
- else:
201
- demo_list.append(d)
202
-
203
- print(f"Generating {n_samples} samples with demo data: {demo_list}")
204
-
205
- # Process demographics
206
- demo_tensor, demo_dim = self.preprocess_demo(demo_list, demo_types, n_samples)
207
-
208
- # Generate random latent vectors
209
- z = torch.randn(n_samples, self.latent_dim).to(self.device)
210
-
211
  else:
212
- # Reconstructing existing data
213
- X = np.array(X)
214
- n_samples = X.shape[0]
215
-
216
- # Process demo data (repeat single values if needed)
217
- demo_list = []
218
- for d in demo_data:
219
- if not isinstance(d, (list, np.ndarray)) or len(d) != n_samples:
220
- # Single value, repeat for all samples
221
- demo_list.append([d] * n_samples)
222
- else:
223
- demo_list.append(d)
224
-
225
- # Process demographics
226
- demo_tensor, demo_dim = self.preprocess_demo(demo_list, demo_types)
227
-
228
- # Encode input data
229
- X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
230
- z = self.vae.encode(X_tensor)
231
-
232
- # Print shapes for debugging
233
- print(f"Latent shape: {z.shape}, Demo tensor shape: {demo_tensor.shape}")
234
-
235
- # Decode to get output
236
- demo_tensor = demo_tensor.to(self.device)
237
- with torch.no_grad():
238
- # Make sure demo_tensor has the right dimensions
239
- if demo_tensor.shape[1] != self.vae.demo_dim:
240
- print(f"WARNING: Demo dimension mismatch. Expected {self.vae.demo_dim}, got {demo_tensor.shape[1]}")
241
- # Use demographic dimension from the model
242
- if demo_tensor.shape[1] > self.vae.demo_dim:
243
- # Trim extra dimensions
244
- demo_tensor = demo_tensor[:, :self.vae.demo_dim]
245
- else:
246
- # Pad with zeros
247
- padding = torch.zeros(demo_tensor.shape[0], self.vae.demo_dim - demo_tensor.shape[1]).to(self.device)
248
- demo_tensor = torch.cat([demo_tensor, padding], dim=1)
249
- print(f"Adjusted demo tensor shape: {demo_tensor.shape}")
250
-
251
- output = self.vae.decode(z, demo_tensor)
252
-
253
- # Convert to numpy
254
- return output.cpu().numpy()
255
-
256
- def get_latents(self, X):
257
- """Encode data to latent representations"""
258
- X = np.array(X)
259
- X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
260
-
261
- with torch.no_grad():
262
- z = self.vae.encode(X_tensor)
263
-
264
- return z.cpu().numpy()
265
-
266
  def save(self, path):
267
- """Save the model and training history"""
268
- # Ensure the directory exists
269
- os.makedirs(os.path.dirname(os.path.abspath(path)), exist_ok=True)
270
-
271
- # Create state dict with all necessary info
272
- state = {
273
- 'vae_state': self.vae.state_dict(),
274
- 'input_dim': self.vae.input_dim,
275
- 'latent_dim': self.latent_dim,
276
- 'demo_dim': self.vae.demo_dim,
277
- 'train_losses': self.train_losses,
278
- 'val_losses': self.val_losses,
279
- 'nepochs': self.nepochs,
280
- 'batch_size': self.batch_size,
281
- 'lr': self.lr
282
- }
283
-
284
- # Save the model
285
- torch.save(state, path)
286
- print(f"Model saved to {path}")
287
-
288
- # Print info about saved losses
289
- print(f"Saved loss data: {len(self.train_losses)} train, {len(self.val_losses)} validation")
290
-
291
  def load(self, path):
292
- """Load the model from a file"""
293
- if not os.path.exists(path):
294
- raise FileNotFoundError(f"Model file not found: {path}")
295
-
296
- # Load state dict
297
- state = torch.load(path, map_location=self.device)
298
-
299
- # Set attributes
300
- self.latent_dim = state['latent_dim']
301
- self.nepochs = state.get('nepochs', 50)
302
- self.batch_size = state.get('batch_size', 8)
303
- self.lr = state.get('lr', 1e-3)
304
- self.train_losses = state.get('train_losses', [])
305
- self.val_losses = state.get('val_losses', [])
306
-
307
- # Create model
308
- self.vae = SimpleVAE(
309
- input_dim=state['input_dim'],
310
- latent_dim=self.latent_dim,
311
- demo_dim=state['demo_dim']
312
- )
313
-
314
- # Load weights
315
- self.vae.load_state_dict(state['vae_state'])
316
- self.vae.to(self.device)
317
-
318
- print(f"Model loaded from {path}")
319
- print(f"Loaded loss data: {len(self.train_losses)} train, {len(self.val_losses)} validation")
320
-
321
- def plot_learning_curves(train_losses, val_losses):
322
- """Plot training and validation loss curves"""
323
- # Create figure
324
- plt.figure(figsize=(10, 6))
325
-
326
- # Check if we have loss data
327
- if not train_losses:
328
- plt.text(0.5, 0.5, "No training loss data available",
329
- ha='center', va='center', transform=plt.gca().transAxes,
330
- fontsize=14, color='red')
331
- plt.axis('off')
332
- return plt.gcf()
333
-
334
- # Plot losses
335
- epochs = range(1, len(train_losses) + 1)
336
- plt.plot(epochs, train_losses, 'b-', label='Training loss')
337
 
338
- if val_losses:
339
- # Adjust validation epochs if lengths differ
340
- if len(val_losses) == len(train_losses) + 1:
341
- # Initial validation + epoch validations
342
- val_epochs = [0] + list(epochs)
343
- else:
344
- val_epochs = epochs[:len(val_losses)]
345
-
346
- plt.plot(val_epochs, val_losses, 'r-', label='Validation loss')
347
 
348
- # Add labels
349
- plt.title('VAE Training and Validation Loss')
350
- plt.xlabel('Epoch')
351
- plt.ylabel('Loss')
352
- plt.legend()
353
- plt.grid(True, alpha=0.3)
354
 
355
- return plt.gcf()
 
 
 
 
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
  import numpy as np
5
+ from utils import to_torch, to_cuda, to_numpy, demo_to_torch
 
6
  from sklearn.base import BaseEstimator
7
 
8
class VAE(nn.Module):
    """Conditional VAE over flattened functional-connectivity vectors.

    The encoder maps FC features to a latent code; the decoder maps a
    latent code concatenated with demographic covariates back to FC space.
    Layers are moved to the GPU at construction when `use_cuda` is set.
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        # Dimensions of the FC input, the latent code, and the demographic vector.
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda

        # Encoder: input -> 1000 hidden units -> latent
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)

        # Decoder: (latent + demographics) -> 1000 hidden units -> input
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)

        # Batch normalization on the hidden activations of encoder and decoder.
        self.bn1 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
        self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)

    def enc(self, x):
        """Encode FC features `x` into latent codes."""
        hidden = self.bn1(F.relu(self.enc1(x)))
        return self.enc2(hidden)

    def gen(self, n):
        """Draw `n` latent codes from the standard-normal prior."""
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes `z` conditioned on demographic vectors `demo`."""
        joined = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        hidden = self.bn2(F.relu(self.dec1(joined)))
        return self.dec2(hidden)
41
+
42
class DemoVAE(BaseEstimator):
    """Scikit-learn style wrapper around the conditional :class:`VAE`.

    Hyperparameters are declared once in `get_default_params`; any subset
    may be overridden via keyword arguments to the constructor.
    """

    def __init__(self, **params):
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        # Default hyperparameters: architecture, training schedule, and
        # the relative weights of the individual loss terms.
        return dict(
            latent_dim=32,
            use_cuda=True,
            nepochs=1000,
            pperiod=100,
            bsize=16,
            loss_C_mult=1,
            loss_mu_mult=1,
            loss_rec_mult=100,
            loss_decor_mult=10,
            loss_pred_mult=0.001,
            alpha=100,
            LR_C=100,
            lr=1e-4,
            weight_decay=0,
        )

    def get_params(self, deep=True):
        """Return the current hyperparameter values (sklearn protocol)."""
        return {name: getattr(self, name) for name in self.get_default_params()}

    def set_params(self, **params):
        """Set hyperparameters, falling back to defaults for any not given."""
        for name, default in self.get_default_params().items():
            setattr(self, name, params.get(name, default))
        return self

    def fit(self, x, demo, demo_types):
        """Train the VAE on FC data `x` conditioned on demographics.

        `demo` is a list of per-variable value sequences; `demo_types` is a
        parallel list of 'continuous'/'categorical' labels.
        """
        from utils import train_vae

        # One conditioning unit per continuous covariate, one per level of
        # each categorical covariate.
        demo_dim = 0
        for values, kind in zip(demo, demo_types):
            if kind == 'continuous':
                demo_dim += 1
            elif kind == 'categorical':
                demo_dim += len(set(values))
            else:
                raise ValueError(f'Demographic type "{kind}" not supported')

        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)

        # train_vae receives `self` as its ret_obj argument; it is expected
        # to attach fitted state (e.g. pred_stats) back onto this estimator
        # — TODO confirm against utils.train_vae.
        train_vae(
            self.vae, x, demo, demo_types,
            self.nepochs, self.pperiod, self.bsize,
            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
            self.loss_decor_mult, self.loss_pred_mult,
            self.lr, self.weight_decay, self.alpha, self.LR_C,
            self,
        )
        return self

    def transform(self, x, demo, demo_types):
        """Decode conditioned outputs.

        When `x` is an int, generate that many samples from the prior;
        otherwise encode the given FC data and reconstruct it. Requires
        `self.pred_stats` (set during fitting) for demographic encoding.
        """
        if isinstance(x, int):
            z = self.vae.gen(x)
        else:
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        return to_numpy(self.vae.dec(z, demo_t))

    def get_latents(self, x):
        """Encode `x` and return its latent codes as a numpy array."""
        codes = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(codes)

    def save(self, path):
        """Serialize weights, hyperparameters, and fitted stats to `path`."""
        torch.save({
            'model_state_dict': self.vae.state_dict(),
            'params': self.get_params(),
            'pred_stats': self.pred_stats,
            'input_dim': self.input_dim,
            'demo_dim': self.demo_dim,
        }, path)

    def load(self, path):
        """Restore an estimator previously written by :meth:`save`."""
        checkpoint = torch.load(path)
        self.set_params(**checkpoint['params'])
        self.pred_stats = checkpoint['pred_stats']
        self.input_dim = checkpoint['input_dim']
        self.demo_dim = checkpoint['demo_dim']
        self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
        self.vae.load_state_dict(checkpoint['model_state_dict'])
132
+
133
def train_fc_vae(X, demo_data, demo_types, model_config):
    """Build and train a DemoVAE on FC feature vectors.

    Parameters
    ----------
    X : np.ndarray, shape (n_subjects, n_features)
        Flattened upper-triangular FC values. For the 264-ROI atlas used
        elsewhere in this project that is 264*263/2 = 34716 features per
        subject (the previous version computed this value but never used it).
    demo_data : list of sequences
        One sequence of covariate values per demographic variable.
    demo_types : list of str
        'continuous' or 'categorical', parallel to `demo_data`.
    model_config : dict
        Must contain: latent_dim, nepochs, bsize, loss_rec_mult,
        loss_decor_mult, lr.

    Returns
    -------
    tuple (vae, X, demo_data, demo_types)
        The fitted model plus the inputs, passed through unchanged for
        downstream analysis steps.
    """
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config['loss_rec_mult'],
        loss_decor_mult=model_config['loss_decor_mult'],
        lr=model_config['lr'],
        # Only request the GPU when one is actually present.
        use_cuda=torch.cuda.is_available(),
    )
    vae.fit(X, demo_data, demo_types)
    return vae, X, demo_data, demo_types
150
+
visualization.py CHANGED
@@ -1,521 +1,44 @@
1
- # Configure matplotlib for headless environment
2
- import matplotlib
3
- matplotlib.use('Agg') # Use non-interactive backend
4
  import matplotlib.pyplot as plt
5
  import numpy as np
 
6
 
7
- def vector_to_matrix(vector):
8
- """Convert an upper triangular vector to a symmetric matrix"""
9
- # Make sure we have a numpy array
10
- if not isinstance(vector, np.ndarray):
11
- try:
12
- vector = np.array(vector)
13
- print(f"Converted input to numpy array, shape: {vector.shape}")
14
- except Exception as e:
15
- print(f"Error converting input to numpy array: {e}")
16
- # Create a fallback empty matrix
17
- n = 264 # Standard size for brain atlas
18
- matrix = np.zeros((n, n))
19
- np.fill_diagonal(matrix, 1.0)
20
- return matrix
21
 
22
- # Handle special case: already a matrix
23
- if len(vector.shape) == 2:
24
- print(f"Input is already a matrix with shape {vector.shape}")
25
- return vector
26
 
27
- # Handle regular vector case
28
- try:
29
- print(f"Converting vector to matrix. Vector shape: {vector.shape}, length: {len(vector)}")
30
-
31
- # For a 264x264 FC matrix, we expect 34716 elements
32
- if len(vector) == 34716:
33
- print("Detected standard FC vector with 34716 elements (264x264 matrix)")
34
- n = 264
35
- else:
36
- # For other sized vectors, calculate matrix size
37
- # For a matrix of size n×n, the number of elements in the upper triangular part (excl. diagonal) is n(n-1)/2
38
- n = int(np.sqrt(2 * len(vector) + 0.25) + 0.5)
39
- print(f"Calculated matrix size: {n}x{n}")
40
-
41
- # Validate calculation
42
- expected_elements = int(n * (n-1) / 2)
43
- if expected_elements != len(vector):
44
- print(f"WARNING: Vector length {len(vector)} doesn't match expected length {expected_elements} for {n}x{n} matrix")
45
- if len(vector) < expected_elements:
46
- print(f"Padding vector with {expected_elements - len(vector)} zeros")
47
- vector = np.pad(vector, (0, expected_elements - len(vector)))
48
- else:
49
- print(f"Trimming vector to {expected_elements} elements")
50
- vector = vector[:expected_elements]
51
-
52
- # Create empty matrix
53
- matrix = np.zeros((n, n))
54
-
55
- # Get indices for upper triangle
56
- triu_indices = np.triu_indices_from(matrix, k=1)
57
-
58
- # Convert from Fisher z-transform if needed (check if values exceed [-1,1])
59
- if np.any(np.abs(vector) > 1):
60
- print("Vector contains values >1, applying inverse Fisher z-transform")
61
- values = np.tanh(vector) # Inverse Fisher z-transform
62
- else:
63
- values = vector
64
-
65
- # Check for NaN or Inf values
66
- if np.any(np.isnan(values)) or np.any(np.isinf(values)):
67
- print("WARNING: Vector contains NaN or Inf values, replacing with zeros")
68
- values = np.nan_to_num(values)
69
-
70
- # Set upper triangle values
71
- matrix[triu_indices] = values
72
-
73
- # Make symmetric
74
- matrix = matrix + matrix.T
75
-
76
- # Set diagonal to 1 (perfect correlation)
77
- np.fill_diagonal(matrix, 1.0)
78
-
79
- print(f"Successfully converted to matrix with shape {matrix.shape}")
80
- return matrix
81
-
82
- except Exception as e:
83
- import traceback
84
- print(f"Error in vector_to_matrix: {e}")
85
- print(f"Traceback: {traceback.format_exc()}")
86
- print(f"Vector stats: min={np.min(vector) if len(vector) > 0 else 'N/A'}, "
87
- f"max={np.max(vector) if len(vector) > 0 else 'N/A'}, "
88
- f"mean={np.mean(vector) if len(vector) > 0 else 'N/A'}")
89
-
90
- # Fallback 1 - check if it's already a matrix that was flattened
91
- if len(vector) > 0 and np.sqrt(len(vector)) == int(np.sqrt(len(vector))):
92
- n = int(np.sqrt(len(vector)))
93
- print(f"Trying fallback reshape to {n}x{n}")
94
- return vector.reshape(n, n)
95
-
96
- # Fallback 2 - try standard FC matrix size
97
- elif len(vector) > 30000 and len(vector) < 40000: # Close to 34716
98
- print(f"Vector length {len(vector)} is close to 34716, trying 264x264 matrix")
99
- n = 264
100
- matrix = np.zeros((n, n))
101
- np.fill_diagonal(matrix, 1.0)
102
-
103
- # Try to fill as much as possible
104
- triu_indices = np.triu_indices_from(matrix, k=1)
105
- max_idx = min(len(vector), len(triu_indices[0]))
106
-
107
- # Convert from Fisher z-transform if needed
108
- if np.any(np.abs(vector[:max_idx]) > 1):
109
- values = np.tanh(vector[:max_idx])
110
- else:
111
- values = vector[:max_idx]
112
-
113
- # Fill the upper triangle with as many values as we can
114
- for i in range(max_idx):
115
- matrix[triu_indices[0][i], triu_indices[1][i]] = values[i]
116
-
117
- # Make symmetric
118
- matrix = matrix + matrix.T
119
- np.fill_diagonal(matrix, 1.0)
120
-
121
- print(f"Created partial matrix with shape {matrix.shape}")
122
- return matrix
123
-
124
- # Fallback 3 - create a dummy identity matrix as last resort
125
- else:
126
- print("Creating fallback identity matrix")
127
- n = 264 # Standard size for brain atlas
128
- matrix = np.zeros((n, n))
129
- np.fill_diagonal(matrix, 1.0)
130
- return matrix
131
-
132
- def plot_fc_matrices(original, reconstructed, generated):
133
- """Plot FC matrices comparison with enhanced visualization of brain region connections"""
134
- try:
135
- print("Starting FC matrix visualization...")
136
- print(f"Input shapes - Original: {original.shape}, Reconstructed: {reconstructed.shape}, Generated: {generated.shape}")
137
-
138
- # Use a larger figure for more detailed visualization
139
- fig = plt.figure(figsize=(20, 12))
140
-
141
- # Create a grid layout with 3 rows
142
- gs = plt.GridSpec(3, 3, height_ratios=[1, 0.7, 0.7], figure=fig)
143
-
144
- # First row: Original matrices
145
- ax1 = fig.add_subplot(gs[0, 0])
146
- ax2 = fig.add_subplot(gs[0, 1])
147
- ax3 = fig.add_subplot(gs[0, 2])
148
-
149
- # Second row: Difference matrix and top connections
150
- ax_diff = fig.add_subplot(gs[1, 0:2])
151
- ax_top = fig.add_subplot(gs[1, 2])
152
-
153
- # Third row: Region-specific analysis and histogram
154
- ax_region = fig.add_subplot(gs[2, 0])
155
- ax_hist = fig.add_subplot(gs[2, 1])
156
- ax_metrics = fig.add_subplot(gs[2, 2])
157
-
158
- vmin, vmax = -1, 1
159
-
160
- # Convert from vector to matrix if needed
161
- print("Converting inputs to matrices if needed...")
162
- if len(original.shape) == 1:
163
- print("Converting original FC from vector to matrix...")
164
- original = vector_to_matrix(original)
165
- if len(reconstructed.shape) == 1:
166
- print("Converting reconstructed FC from vector to matrix...")
167
- reconstructed = vector_to_matrix(reconstructed)
168
- if len(generated.shape) == 1:
169
- print("Converting generated FC from vector to matrix...")
170
- generated = vector_to_matrix(generated)
171
-
172
- print(f"Matrix shapes after conversion - Original: {original.shape}, Reconstructed: {reconstructed.shape}, Generated: {generated.shape}")
173
-
174
- # Check for NaN or Inf values and handle them
175
- for name, matrix in [("Original", original), ("Reconstructed", reconstructed), ("Generated", generated)]:
176
- if np.any(np.isnan(matrix)) or np.any(np.isinf(matrix)):
177
- print(f"WARNING: {name} matrix contains NaN or Inf values, replacing with zeros")
178
- if name == "Original":
179
- original = np.nan_to_num(matrix)
180
- elif name == "Reconstructed":
181
- reconstructed = np.nan_to_num(matrix)
182
- else:
183
- generated = np.nan_to_num(matrix)
184
-
185
- # Ensure matrices have consistent dimensions
186
- print("Checking matrix dimensions...")
187
- dimensions = [original.shape[0], reconstructed.shape[0], generated.shape[0]]
188
- if len(set(dimensions)) > 1:
189
- print(f"WARNING: Matrices have inconsistent dimensions: {dimensions}")
190
- # Use smallest dimension
191
- n = min(dimensions)
192
- print(f"Resizing matrices to consistent dimension: {n}x{n}")
193
- if original.shape[0] > n:
194
- original = original[:n, :n]
195
- if reconstructed.shape[0] > n:
196
- reconstructed = reconstructed[:n, :n]
197
- if generated.shape[0] > n:
198
- generated = generated[:n, :n]
199
-
200
- # Calculate key metrics for reconstruction quality
201
- print("Calculating reconstruction quality metrics...")
202
- from sklearn.metrics import mean_squared_error, r2_score
203
-
204
- # Flatten matrices for metric calculation (excluding diagonal)
205
- mask = ~np.eye(original.shape[0], dtype=bool)
206
- orig_flat = original[mask]
207
- recon_flat = reconstructed[mask]
208
-
209
- # Calculate metrics
210
- mse = mean_squared_error(orig_flat, recon_flat)
211
- rmse = np.sqrt(mse)
212
- r2 = r2_score(orig_flat, recon_flat)
213
- corr = np.corrcoef(orig_flat, recon_flat)[0, 1]
214
-
215
- # Calculate difference matrix
216
- diff_matrix = reconstructed - original
217
- max_diff = np.max(np.abs(diff_matrix))
218
- mean_abs_diff = np.mean(np.abs(diff_matrix))
219
-
220
- # Plot original matrices
221
- print("Creating matrix plots...")
222
- im1 = ax1.imshow(original, cmap='RdBu_r', vmin=vmin, vmax=vmax)
223
- ax1.set_title('Original FC', fontsize=12, fontweight='bold')
224
-
225
- im2 = ax2.imshow(reconstructed, cmap='RdBu_r', vmin=vmin, vmax=vmax)
226
- ax2.set_title('Reconstructed FC', fontsize=12, fontweight='bold')
227
-
228
- im3 = ax3.imshow(generated, cmap='RdBu_r', vmin=vmin, vmax=vmax)
229
- ax3.set_title('Generated FC', fontsize=12, fontweight='bold')
230
-
231
- # Add colorbars
232
- for ax, im in zip([ax1, ax2, ax3], [im1, im2, im3]):
233
- plt.colorbar(im, ax=ax)
234
- # Remove axis ticks for cleaner visualization
235
- ax.set_xticks([])
236
- ax.set_yticks([])
237
-
238
- # Plot difference matrix
239
- print("Creating difference matrix visualization...")
240
- diff_vmax = max(0.5, min(1.0, max_diff)) # Adaptive range
241
- im_diff = ax_diff.imshow(diff_matrix, cmap='RdBu_r',
242
- vmin=-diff_vmax, vmax=diff_vmax)
243
- ax_diff.set_title(f'Reconstruction Difference (Mean Abs Diff: {mean_abs_diff:.3f})', fontsize=12)
244
- plt.colorbar(im_diff, ax=ax_diff)
245
-
246
- # Add axis labels to indicate this represents brain regions
247
- ax_diff.set_xlabel('Brain Region Index', fontsize=10)
248
- ax_diff.set_ylabel('Brain Region Index', fontsize=10)
249
-
250
- # Find top connections (strongest positive correlations in original)
251
- print("Finding top connections...")
252
- n_regions = original.shape[0]
253
- top_connections = []
254
-
255
- # Extract top 10 connections (excluding diagonal)
256
- for i in range(n_regions):
257
- for j in range(i+1, n_regions): # upper triangle only
258
- top_connections.append((i, j, original[i, j], reconstructed[i, j]))
259
-
260
- # Sort by strength of original connection (descending)
261
- top_connections.sort(key=lambda x: abs(x[2]), reverse=True)
262
- top_connections = top_connections[:10] # Keep top 10
263
-
264
- # Plot top connections
265
- print("Creating top connections chart...")
266
- ax_top.set_title('Top 10 Strongest Region Connections', fontsize=12)
267
- ax_top.set_xlim([-1.1, 1.1]) # Range for correlation values
268
-
269
- # Create table data
270
- y_pos = np.arange(len(top_connections))
271
- labels = [f"R{i+1}-R{j+1}" for i, j, _, _ in top_connections]
272
- orig_vals = [orig for _, _, orig, _ in top_connections]
273
- recon_vals = [recon for _, _, _, recon in top_connections]
274
-
275
- # Plot horizontal bars
276
- ax_top.barh(y_pos + 0.2, orig_vals, height=0.4, color='blue', alpha=0.6, label='Original')
277
- ax_top.barh(y_pos - 0.2, recon_vals, height=0.4, color='red', alpha=0.6, label='Reconstructed')
278
-
279
- # Add zero line
280
- ax_top.axvline(x=0, color='black', linestyle='-', alpha=0.3)
281
-
282
- # Add labels and legend
283
- ax_top.set_yticks(y_pos)
284
- ax_top.set_yticklabels(labels)
285
- ax_top.set_xlabel('Correlation Strength')
286
- ax_top.legend()
287
-
288
- # Add grid for easier reading
289
- ax_top.grid(True, axis='x', alpha=0.3)
290
-
291
- # Find largest errors per region
292
- print("Analyzing regional errors...")
293
- region_errors = np.mean(np.abs(diff_matrix), axis=1)
294
- worst_regions = np.argsort(region_errors)[-10:] # 10 worst regions
295
-
296
- # Plot region-specific error analysis
297
- region_indices = np.arange(len(worst_regions))
298
- ax_region.barh(region_indices, region_errors[worst_regions], color='red', alpha=0.7)
299
- ax_region.set_yticks(region_indices)
300
- ax_region.set_yticklabels([f"Region {r+1}" for r in worst_regions])
301
- ax_region.set_title("Regions with Highest Error", fontsize=12)
302
- ax_region.set_xlabel("Mean Absolute Error")
303
- ax_region.grid(True, axis='x', alpha=0.3)
304
-
305
- # Create histogram of differences
306
- print("Creating error distribution histogram...")
307
- ax_hist.hist(diff_matrix.flatten(), bins=50, alpha=0.7, color='purple')
308
- ax_hist.set_title("Error Distribution", fontsize=12)
309
- ax_hist.set_xlabel("Reconstruction Error")
310
- ax_hist.set_ylabel("Count")
311
-
312
- # Add vertical lines for mean, median
313
- mean_err = np.mean(diff_matrix)
314
- median_err = np.median(diff_matrix)
315
- ax_hist.axvline(mean_err, color='red', linestyle='--', label=f'Mean: {mean_err:.3f}')
316
- ax_hist.axvline(median_err, color='green', linestyle='--', label=f'Median: {median_err:.3f}')
317
- ax_hist.legend()
318
-
319
- # Display metrics as a table
320
- print("Creating metrics table...")
321
- ax_metrics.axis('tight')
322
- ax_metrics.axis('off')
323
- metrics_data = [
324
- ["MSE", f"{mse:.6f}"],
325
- ["RMSE", f"{rmse:.6f}"],
326
- ["R²", f"{r2:.6f}"],
327
- ["Correlation", f"{corr:.6f}"],
328
- ["Max Error", f"{max_diff:.6f}"],
329
- ["Mean Abs Error", f"{mean_abs_diff:.6f}"]
330
- ]
331
- table = ax_metrics.table(cellText=metrics_data, loc='center',
332
- cellLoc='left', colWidths=[0.4, 0.6])
333
- table.auto_set_font_size(False)
334
- table.set_fontsize(10)
335
- table.scale(1, 1.5)
336
- for (row, col), cell in table.get_celld().items():
337
- if row == 0 or col == 0:
338
- cell.set_text_props(fontproperties=matplotlib.font_manager.FontProperties(weight='bold'))
339
- ax_metrics.set_title("Reconstruction Quality Metrics", fontsize=12)
340
-
341
- # Overall quality score (weighted average of metrics)
342
- quality_score = (0.4 * r2 + 0.4 * corr + 0.2 * (1-rmse/2)) # Scale between 0-1
343
- quality_percent = max(0, min(100, quality_score * 100)) # Clamp to 0-100%
344
-
345
- # Add overall quality score
346
- plt.figtext(0.5, 0.01, f"Overall Reconstruction Quality: {quality_percent:.1f}%",
347
- ha="center", fontsize=14, fontweight='bold',
348
- bbox={"facecolor":"lightblue", "alpha":0.5, "pad":5})
349
-
350
- plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout to make room for the quality score
351
- print("FC matrix visualization completed successfully")
352
- return fig
353
-
354
- except Exception as e:
355
- import traceback
356
- print(f"Error in plot_fc_matrices: {e}")
357
- print(f"Traceback: {traceback.format_exc()}")
358
-
359
- # Create a simple error figure
360
- fig = plt.figure(figsize=(15, 5))
361
- plt.text(0.5, 0.5, f"FC visualization error: {str(e)}",
362
- ha='center', va='center', transform=plt.gca().transAxes,
363
- fontsize=12, color='red')
364
- plt.axis('off')
365
- plt.tight_layout()
366
- return fig
367
-
368
- def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke, prediction_std=None):
369
- """Plot predicted treatment trajectory"""
370
- fig = plt.figure(figsize=(10, 6))
371
 
372
- # Plot current and predicted points
373
- plt.scatter([0], [current_score], label='Current Status', color='blue', s=100)
374
- plt.scatter([months_post_stroke], [predicted_score],
375
- label='Predicted Outcome', color='red', s=100)
376
 
377
- # Plot trajectory
378
- plt.plot([0, months_post_stroke], [current_score, predicted_score],
379
- 'g--', label='Predicted Trajectory')
380
 
381
- # Add prediction interval if available
382
- if prediction_std is not None:
383
- plt.fill_between([months_post_stroke],
384
- [predicted_score - 2*prediction_std],
385
- [predicted_score + 2*prediction_std],
386
- color='red', alpha=0.2,
387
- label='95% Prediction Interval')
388
 
389
- plt.xlabel('Months Post Treatment')
390
- plt.ylabel('WAB Score')
391
- plt.title('Predicted Treatment Trajectory')
392
- plt.legend()
393
- plt.grid(True)
 
 
 
 
 
 
394
 
 
395
  return fig
396
 
397
- def plot_learning_curves(train_losses, val_losses):
398
- """Plot VAE learning curves with enhanced visualization"""
399
- try:
400
- # Handle empty or None inputs - only use real data
401
- if not train_losses or train_losses is None or len(train_losses) == 0:
402
- print("WARNING: No real training loss data provided")
403
- # Create placeholder figure with warning message
404
- fig = plt.figure(figsize=(10, 6))
405
- plt.text(0.5, 0.5, "No real training data available",
406
- ha='center', va='center', transform=plt.gca().transAxes,
407
- fontsize=14, color='darkred')
408
- plt.axis('off')
409
- plt.tight_layout()
410
- return fig
411
-
412
- if not val_losses or val_losses is None or len(val_losses) == 0:
413
- print("WARNING: No real validation loss data provided. Using training data only.")
414
- # Use training data for both
415
- val_losses = train_losses
416
-
417
- # Convert to numpy arrays for safe handling
418
- train_np = np.array(train_losses)
419
- val_np = np.array(val_losses)
420
-
421
- # Check for NaN values
422
- if np.any(np.isnan(train_np)) or np.any(np.isnan(val_np)):
423
- print("WARNING: Learning curves contain NaN values, replacing with zeros")
424
- train_np = np.nan_to_num(train_np)
425
- val_np = np.nan_to_num(val_np)
426
-
427
- # Create figure
428
- fig = plt.figure(figsize=(12, 6))
429
-
430
- # Add improved styling
431
- plt.rcParams['font.size'] = 12
432
-
433
- # Check if train and val lengths match
434
- if len(train_np) != len(val_np):
435
- print(f"Training and validation loss lengths don't match: {len(train_np)} vs {len(val_np)}")
436
- if len(train_np) > len(val_np):
437
- # Validation might be evaluated less frequently
438
- # Create epoch indices for each
439
- train_epochs = np.arange(len(train_np))
440
- val_factor = len(train_np) / len(val_np)
441
- val_epochs = np.arange(0, len(train_np), val_factor)[:len(val_np)]
442
-
443
- plt.plot(train_epochs, train_np, 'b-', linewidth=2, label='Training Loss')
444
- plt.plot(val_epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
445
- else:
446
- # This is unusual, but handle it anyway
447
- plt.plot(train_np, 'b-', linewidth=2, label='Training Loss')
448
- plt.plot(val_np[:len(train_np)], 'r-', linewidth=2, label='Validation Loss')
449
- else:
450
- # Standard case - equal length arrays
451
- epochs = np.arange(len(train_np))
452
- plt.plot(epochs, train_np, 'b-', linewidth=2, label='Training Loss')
453
- plt.plot(epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
454
-
455
- # Add shaded confidence region
456
- if len(train_np) > 5: # Only if we have enough points
457
- # Calculate moving average for smoother trend lines
458
- window_size = min(5, len(train_np) // 5)
459
- if window_size > 1:
460
- avg_train = np.convolve(train_np, np.ones(window_size)/window_size, mode='valid')
461
- avg_val = np.convolve(val_np, np.ones(window_size)/window_size, mode='valid')
462
- avg_epochs = epochs[window_size-1:]
463
- plt.plot(avg_epochs, avg_train, 'b--', linewidth=1, alpha=0.6)
464
- plt.plot(avg_epochs, avg_val, 'r--', linewidth=1, alpha=0.6)
465
-
466
- # Calculate improvement from start to end
467
- if len(train_np) > 1:
468
- train_improvement = ((train_np[0] - train_np[-1]) / train_np[0]) * 100
469
- if len(val_np) > 1:
470
- val_improvement = ((val_np[0] - val_np[-1]) / val_np[0]) * 100
471
- plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement, Validation: {val_improvement:.1f}% improvement')
472
- else:
473
- plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement')
474
- else:
475
- plt.title('VAE Learning Curves')
476
-
477
- # Add min/max annotations
478
- if len(train_np) > 0:
479
- min_train = np.min(train_np)
480
- min_train_epoch = np.argmin(train_np)
481
- plt.annotate(f'Min: {min_train:.4f}', xy=(min_train_epoch, min_train),
482
- xytext=(min_train_epoch+5, min_train+0.05),
483
- arrowprops=dict(facecolor='blue', shrink=0.05, alpha=0.5),
484
- color='blue', fontsize=10)
485
-
486
- if len(val_np) > 0:
487
- min_val = np.min(val_np)
488
- min_val_epoch = np.argmin(val_np)
489
- plt.annotate(f'Min: {min_val:.4f}', xy=(min_val_epoch, min_val),
490
- xytext=(min_val_epoch+5, min_val+0.05),
491
- arrowprops=dict(facecolor='red', shrink=0.05, alpha=0.5),
492
- color='red', fontsize=10)
493
-
494
- # Styling
495
- plt.xlabel('Epoch')
496
- plt.ylabel('Loss')
497
- plt.legend(loc='upper right')
498
- plt.grid(True, alpha=0.3)
499
-
500
- # Set reasonable y-axis limits
501
- all_losses = np.concatenate([train_np, val_np])
502
- y_min = max(0, np.min(all_losses) * 0.9) # Don't go below zero
503
- y_max = np.percentile(all_losses, 95) * 1.1 # Exclude outliers
504
- plt.ylim(y_min, y_max)
505
-
506
- plt.tight_layout()
507
- return fig
508
-
509
- except Exception as e:
510
- import traceback
511
- print(f"Error in plot_learning_curves: {e}")
512
- print(f"Traceback: {traceback.format_exc()}")
513
-
514
- # Create a simple error figure
515
- fig = plt.figure(figsize=(10, 6))
516
- plt.text(0.5, 0.5, f"Learning curves error: {str(e)}",
517
- ha='center', va='center', transform=plt.gca().transAxes,
518
- fontsize=12, color='red')
519
- plt.axis('off')
520
- plt.tight_layout()
521
- return fig
 
 
 
 
1
  import matplotlib.pyplot as plt
2
  import numpy as np
3
+ from utils import fc_matrix_from_triu
4
 
5
def visualize_fc_analysis(original_triu, reconstructed_triu, generated_triu, analysis_results=None):
    """Plot original/reconstructed/generated FC matrices side by side.

    The three upper-triangular vectors are expanded to full symmetric
    matrices and shown on a shared [-1, 1] diverging color scale. When
    `analysis_results` is given (mapping demographic name -> dict with
    'p_values' and 'correlations' per latent dimension), a bottom panel
    plots each demographic's correlation profile, with the count of
    dimensions significant at p < 0.05 in the legend.
    """
    fig = plt.figure(figsize=(15, 10))
    grid = plt.GridSpec(2, 3)

    # Top row: one heatmap per FC matrix, same colormap and range.
    panels = [
        ('Original FC', fc_matrix_from_triu(original_triu)),
        ('Reconstructed FC', fc_matrix_from_triu(reconstructed_triu)),
        ('Generated FC', fc_matrix_from_triu(generated_triu)),
    ]
    for col, (title, matrix) in enumerate(panels):
        axis = fig.add_subplot(grid[0, col])
        image = axis.imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
        axis.set_title(title)
        plt.colorbar(image, ax=axis)

    # Bottom row (optional): latent-dimension correlation profiles.
    if analysis_results is not None:
        corr_axis = fig.add_subplot(grid[1, :])
        for demo_name, results in analysis_results.items():
            significant_dims = np.where(np.array(results['p_values']) < 0.05)[0]
            correlations = np.array(results['correlations'])
            corr_axis.plot(correlations, label=f'{demo_name} (sig. dims: {len(significant_dims)})')

        corr_axis.set_xlabel('Latent Dimension')
        corr_axis.set_ylabel('Correlation Strength')
        corr_axis.set_title('Demographic Correlations with Latent Dimensions')
        corr_axis.legend()

    plt.tight_layout()
    return fig
44