SreekarB committed
Commit dbe81c1 · verified · 1 Parent(s): b32645b

Upload 13 files

Files changed (10)
  1. README.md +5 -67
  2. app.py +721 -77
  3. config.py +9 -0
  4. data_preprocessing.py +78 -542
  5. main.py +126 -252
  6. rcf_prediction.py +99 -164
  7. requirements.txt +2 -0
  8. utils.py +28 -60
  9. vae_model.py +0 -19
  10. visualization.py +59 -37
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🧠
colorFrom: blue
colorTo: pink
sdk: gradio
-sdk_version: 3.36.1
+sdk_version: 5.20.1
app_file: app.py
pinned: false
---
@@ -21,27 +21,9 @@ This application implements a VAE model that:
3. Conditions the generation process on demographic variables (age, sex, time post-stroke, WAB scores)
4. Allows analysis of relationships between brain connectivity patterns and demographic variables

-## Project Structure
-
-- **Core Modules**:
-  - `vae_model.py`: Main VAE implementation for FC matrices with demographic integration
-  - `data_preprocessing.py`: Preprocessing pipeline for fMRI to FC conversion
-  - `rcf_prediction.py`: Random forest-based prediction of treatment outcomes
-  - `visualization.py`: Consolidated visualization utilities
-  - `main.py`: Main entry point with unified analysis pipeline
-  - `app.py`: Gradio-based web interface for the system
-
-- **Configuration**:
-  - `config.py`: Central configuration for model parameters
-
-- **Visualization Utilities**:
-  - `fc_visualization.py`: Object-oriented FC visualization framework
-  - `direct_fc_visualization.py`: Generate and visualize synthetic FC
-  - `demo_fc_visualization.py`: Visualize FC from nilearn datasets
-
## Dataset

-This demo uses the [SreekarB/OSFData1](https://huggingface.co/datasets/SreekarB/OSFData1) dataset from HuggingFace, which contains:
+This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:

- NIfTI files in P01_rs.nii format containing fMRI data
- Demographic information directly in the dataset:
@@ -55,12 +37,10 @@ This demo uses the [SreekarB/OSFData1](https://huggingface.co/datasets/SreekarB/

The application processes the NIfTI files using the Power 264 atlas to create functional connectivity matrices that are then analyzed by the VAE model.

-## Running the System
-
-### Via Web Interface
+## How to Use

1. **Configure Parameters**:
-   - **Data Source**: By default, it uses the SreekarB/OSFData1 HuggingFace dataset
+   - **Data Source**: By default, it uses the SreekarB/OSFData HuggingFace dataset
   - **Latent Dimensions**: Controls the size of the latent space (default: 32)
   - **Number of Epochs**: Training iterations (default: 200 for demo)
   - **Batch Size**: Training batch size (default: 16)
@@ -74,32 +54,6 @@ The application processes the NIfTI files using the Power 264 atlas to create fu
- Results will show correlations between demographic variables and latent brain patterns
- The visualization shows original FC, reconstructed FC, and a new FC matrix generated from specific demographic values

-### Via Command Line
-
-```bash
-# Run complete analysis with real data
-python main.py --data_dir data/ --demographic_file demographics.csv --treatment_file treatment_outcomes.csv
-
-# Run FC analysis only (no treatment prediction)
-python main.py --fc_only --data_dir data/ --demographic_file demographics.csv
-
-# Run with HuggingFace dataset (if available)
-python main.py --data_dir SreekarB/OSFData1 --fc_only
-
-# Run web interface
-python app.py
-```
-
-### Data Generation Utilities
-
-```bash
-# Generate synthetic FC matrices for testing
-python direct_fc_visualization.py
-
-# Visualize FC from nilearn dataset
-python demo_fc_visualization.py
-```
-
## Outputs

The application produces visualizations showing:
@@ -108,25 +62,9 @@ The application produces visualizations showing:
- Generated FC matrix (based on specific demographic inputs)
- Correlation plots between latent variables and demographic features

-## Code Structure Notes
-
-- The code uses a standardized implementation of matrix conversion and visualization
-- The `vector_to_matrix` function in `visualization.py` is the canonical implementation for converting FC vectors to matrices
-- Latent representations are saved in `results/latents.npy` for reproducibility
-- The VAE incorporates demographic information in the decoder
-
## Technical Details

- Framework: PyTorch
- Interface: Gradio
- Dataset: HuggingFace Datasets API
-- Analysis: Custom implementation of conditional VAE with demographic conditioning
-
-## Dependencies
-
-- PyTorch
-- nilearn
-- scikit-learn
-- numpy/pandas
-- matplotlib
-- gradio (for web interface)
+- Analysis: Custom implementation of conditional VAE with demographic conditioning
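The README above describes a conditional VAE whose generation step is conditioned on demographics, but `vae_model.py` itself is not part of this diff. The snippet below is only a minimal sketch of that conditioning idea; the class name, layer sizes, and the demographic encoding are illustrative assumptions, not the repository's actual implementation.

```python
# Sketch only: a decoder conditioned on demographics, in the spirit of the README.
# Names, sizes, and the demographic encoding are assumptions, not the repo's API.
import torch
import torch.nn as nn

class ConditionalDecoderSketch(nn.Module):
    def __init__(self, latent_dim=32, demo_dim=4, fc_dim=264 * 263 // 2):
        super().__init__()
        # Concatenating the latent code with demographics lets generation be
        # steered by age, sex, months post-onset, and WAB score.
        self.net = nn.Sequential(
            nn.Linear(latent_dim + demo_dim, 512),
            nn.ReLU(),
            nn.Linear(512, fc_dim),
        )

    def forward(self, z, demographics):
        return self.net(torch.cat([z, demographics], dim=1))

decoder = ConditionalDecoderSketch()
z = torch.randn(1, 32)
demo = torch.tensor([[55.0, 1.0, 12.0, 60.0]])  # age, sex code, mpo, wab_aq (hypothetical values)
fc_vector = decoder(z, demo)                    # flattened upper-triangular FC, shape (1, 34716)
```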
app.py CHANGED
@@ -1,100 +1,744 @@
import gradio as gr
-from main import run_fc_analysis
import os

-def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
-    """Run the full VAE analysis pipeline"""
-    fig = run_fc_analysis(
-        data_dir=data_source,
-        demographic_file=None,  # We're now getting demographics directly from the dataset
-        latent_dim=latent_dim,
-        nepochs=nepochs,
-        bsize=bsize,
-        save_model=True,
-        use_hf_dataset=use_hf_dataset
-    )
-    return fig, "Analysis complete! VAE model has been trained and demographic relationships analyzed."

def create_interface():
-    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
-        gr.Markdown("""
-        # Aphasia fMRI to FC Analysis using VAE
-
-        This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
-
-        ## Dataset Information
-        By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
-        - ID: Subject identifier
-        - wab_aq: Aphasia severity score
-        - age: Age of the subject
-        - mpo: Months post onset
-        - education: Years of education
-        - gender: Subject gender
-        - handedness: Subject handedness (ignored in the analysis)
-        """)

-        with gr.Row():
-            with gr.Column(scale=1):
-                # Configuration parameters
-                data_source = gr.Textbox(
-                    label="Data Source (HF Dataset ID or Local Directory)",
-                    value="SreekarB/OSFData"
-                )
-                latent_dim = gr.Slider(
-                    minimum=8, maximum=64, step=8,
-                    label="Latent Dimensions", value=32
-                )
-                nepochs = gr.Slider(
-                    minimum=100, maximum=5000, step=100,
-                    label="Number of Epochs", value=200  # Reduced for faster demos
-                )
-                bsize = gr.Slider(
-                    minimum=8, maximum=64, step=8,
-                    label="Batch Size", value=16
-                )
-                use_hf_dataset = gr.Checkbox(
-                    label="Use HuggingFace Dataset", value=True
-                )

-                # Training button
-                train_button = gr.Button("Start Training", variant="primary")
-                status_text = gr.Textbox(label="Status", value="Ready to start training")

-            with gr.Column(scale=2):
-                # Output plot
-                output_plot = gr.Plot(label="Analysis Results")
-
-        # Link the training button to the analysis function
-        train_button.click(
-            fn=gradio_fc_analysis,
-            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
-            outputs=[output_plot, status_text]
        )

        # Add examples
        gr.Examples(
            examples=[
-                ["SreekarB/OSFData", 32, 200, 16, True],  # Fewer epochs for faster demo
            ],
-            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
        )

-        # Add explanation of the workflow
        gr.Markdown("""
-        ## How this works

-        1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
-        2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
-        3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
-        4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
-        5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations

-        Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
        """)

-    return iface

if __name__ == "__main__":
-    iface = create_interface()
-    iface.launch(share=True)
-

1
  import gradio as gr
2
+ from main import run_analysis
3
+ from rcf_prediction import AphasiaTreatmentPredictor
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
7
+ from visualization import plot_fc_matrices, plot_learning_curves
8
  import os
9
+ from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score
10
+ import json
11
+ import pickle
12
+ import pandas as pd
13
+ import seaborn as sns
14
+ import logging
15
+ from config import MODEL_CONFIG, PREDICTION_CONFIG
16
 
17
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class AphasiaPredictionApp:
21
+ def __init__(self):
22
+ self.vae = None
23
+ self.predictor = None
24
+ self.trained = False
25
+ self.latent_dim = MODEL_CONFIG['latent_dim']
26
+
27
+ def train_models(self, data_dir, latent_dim, nepochs, bsize):
28
+ """
29
+ Train VAE and Random Forest models
30
+ """
31
+ # Train VAE and Random Forest
32
+ logger.info(f"Training models with data from {data_dir}")
33
+ logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
34
+
35
+ # Default prediction parameters from our config
36
+ prediction_type = PREDICTION_CONFIG.get('prediction_type', 'regression')
37
+ outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq')
38
+ logger.info(f"Prediction: type={prediction_type}, outcome={outcome_variable}")
39
+
40
+ figures = {}
41
+
42
+ try:
43
+ # Run the full analysis pipeline
44
+ results = run_analysis(
45
+ data_dir=data_dir,
46
+ demographic_file="demographics.csv",
47
+ treatment_file="treatment_outcomes.csv",
48
+ latent_dim=latent_dim,
49
+ nepochs=nepochs,
50
+ bsize=bsize,
51
+ save_model=True
52
+ )
53
+
54
+ # Get the VAE figure from results
55
+ vae_fig = results.get('figures', {}).get('vae')
56
+
57
+ figures['vae'] = vae_fig
58
+
59
+ if results:
60
+ self.vae = results.get('vae')
61
+ self.predictor = results.get('predictor')
62
+ latents = results.get('latents')
63
+ demographics = results.get('demographics')
64
+ predictor_cv_results = results.get('predictor_cv_results')
65
+
66
+ # Store the latent dimension
67
+ self.latent_dim = latent_dim
68
+
69
+ # Mark models as trained
70
+ self.trained = True
71
+
72
+ # Prepare prediction visualization if available
73
+ if self.predictor and predictor_cv_results:
74
+ # Get the outcome variable data
75
+ if outcome_variable == 'wab_aq':
76
+ outcomes = demographics['wab_aq']
77
+ elif outcome_variable == 'age':
78
+ outcomes = demographics['age']
79
+ elif outcome_variable == 'months_post_onset':
80
+ outcomes = demographics['months_post_onset']
81
+ else:
82
+ # Try to find the outcome in demographics data
83
+ outcomes = None
84
+ for key in demographics:
85
+ if outcome_variable.lower() in key.lower():
86
+ outcomes = demographics[key]
87
+ break
88
+
89
+ # Create plots
90
+ if 'prediction_stds' in predictor_cv_results and 'predictions' in predictor_cv_results:
91
+ # Create prediction plots
92
+ prediction_fig = self.create_prediction_plots(
93
+ latents,
94
+ demographics,
95
+ outcomes,
96
+ predictor_cv_results['predictions'],
97
+ predictor_cv_results['prediction_stds']
98
+ )
99
+ figures['prediction'] = prediction_fig
100
+
101
+ # Create feature importance plot if available
102
+ try:
103
+ feature_importance = self.predictor.get_feature_importance()
104
+ if feature_importance is not None:
105
+ importance_fig = self.create_importance_plot(feature_importance)
106
+ figures['importance'] = importance_fig
107
+ except Exception as e:
108
+ logger.warning(f"Could not create feature importance plot: {e}")
109
+
110
+ logger.info("Training completed successfully")
111
+
112
+ # Create learning curve plots if available
113
+ if 'fold_metrics' in predictor_cv_results:
114
+ learning_fig = self.create_learning_curve_plot(
115
+ predictor_cv_results['fold_metrics']
116
+ )
117
+ figures['learning'] = learning_fig
118
+
119
+ except Exception as e:
120
+ logger.error(f"Error in training: {str(e)}")
121
+ error_fig = plt.figure(figsize=(10, 6))
122
+ plt.text(0.5, 0.5, f"Error: {str(e)}",
123
+ horizontalalignment='center', verticalalignment='center',
124
+ fontsize=12, color='red')
125
+ plt.axis('off')
126
+ figures['error'] = error_fig
127
+
128
+ return figures
129
+
130
+ def predict_treatment(self, fmri_file=None, age=50, sex="M",
131
+ months_post_stroke=12, wab_score=50, fc_matrix=None):
132
+ """
133
+ Predict treatment outcome for a patient
134
+
135
+ Args:
136
+ fmri_file: Path to patient's fMRI file
137
+ age: Patient's age at stroke
138
+ sex: Patient's sex (M/F)
139
+ months_post_stroke: Months since stroke
140
+ wab_score: Current WAB score
141
+ fc_matrix: Pre-processed FC matrix (if fMRI file not provided)
142
+
143
+ Returns:
144
+ Prediction results and visualization
145
+ """
146
+ if not self.trained:
147
+ return "Please train the models first!", None
148
+
149
+ try:
150
+ # Process fMRI to FC matrix if provided
151
+ if fmri_file and not fc_matrix:
152
+ logger.info(f"Processing fMRI file: {fmri_file}")
153
+ # Use the single fMRI processing function
154
+ fc_matrix = process_single_fmri(fmri_file)
155
+
156
+ if fc_matrix is None:
157
+ return "Please provide either an fMRI file or an FC matrix", None
158
+
159
+ # Ensure FC matrix is properly shaped
160
+ if isinstance(fc_matrix, list):
161
+ fc_matrix = np.array(fc_matrix)
162
+
163
+ # Get latent representation
164
+ logger.info("Extracting latent representation from FC matrix")
165
+ if len(fc_matrix.shape) == 2: # If matrix is 2D (e.g., 264x264)
166
+ # Convert to flattened upper triangular form
167
+ n = fc_matrix.shape[0]
168
+ indices = np.triu_indices(n, k=1)
169
+ fc_flattened = fc_matrix[indices]
170
+ fc_flattened = fc_flattened.reshape(1, -1)
171
+ latent = self.vae.get_latents(fc_flattened)
172
+ else:
173
+ # Assume already flattened
174
+ latent = self.vae.get_latents(fc_matrix.reshape(1, -1))
175
+
176
+ # Prepare demographics
177
+ demographics = {
178
+ 'age': np.array([float(age)]),
179
+ 'gender': np.array([sex]),
180
+ 'months_post_onset': np.array([float(months_post_stroke)]),
181
+ 'wab_aq': np.array([float(wab_score)])
182
+ }
183
+
184
+ logger.info("Making prediction")
185
+ # Make prediction
186
+ if self.predictor is None:
187
+ return "Predictor model not trained", None
188
+
189
+ # Make prediction using the model's predict method
190
+ prediction, prediction_std = self.predictor.predict(latent, demographics)
191
+
192
+ # Create visualization
193
+ fig = self.plot_treatment_trajectory(
194
+ current_score=wab_score,
195
+ predicted_score=prediction[0],
196
+ months_post_stroke=months_post_stroke,
197
+ prediction_std=prediction_std[0]
198
+ )
199
+
200
+ result_text = f"Predicted treatment outcome: {prediction[0]:.2f} ± {2*prediction_std[0]:.2f}"
201
+ logger.info(result_text)
202
+
203
+ return result_text, fig
204
+
205
+ except Exception as e:
206
+ error_msg = f"Error in prediction: {str(e)}"
207
+ logger.error(error_msg)
208
+ error_fig = plt.figure(figsize=(10, 6))
209
+ plt.text(0.5, 0.5, error_msg,
210
+ horizontalalignment='center', verticalalignment='center',
211
+ fontsize=12, color='red')
212
+ plt.axis('off')
213
+ return error_msg, error_fig
214
+
215
+ def plot_treatment_trajectory(self, current_score, predicted_score,
216
+ months_post_stroke, prediction_std,
217
+ treatment_duration=6):
218
+ """
219
+ Create a visualization of predicted treatment trajectory
220
+
221
+ Args:
222
+ current_score: Current WAB score
223
+ predicted_score: Predicted WAB score after treatment
224
+ months_post_stroke: Current months post stroke
225
+ prediction_std: Standard deviation of prediction
226
+ treatment_duration: Duration of treatment in months
227
+
228
+ Returns:
229
+ matplotlib figure
230
+ """
231
+ fig = plt.figure(figsize=(10, 6))
232
+
233
+ # X-axis: months
234
+ x = np.array([months_post_stroke, months_post_stroke + treatment_duration])
235
+
236
+ # Y-axis: WAB scores
237
+ y = np.array([current_score, predicted_score])
238
+
239
+ # Plot the trajectory
240
+ plt.plot(x, y, 'bo-', linewidth=2, label='Predicted Trajectory')
241
+
242
+ # Add confidence interval
243
+ plt.fill_between(
244
+ x,
245
+ [y[0], y[1] - 2*prediction_std],
246
+ [y[0], y[1] + 2*prediction_std],
247
+ alpha=0.2, color='blue', label='95% Confidence Interval'
248
+ )
249
+
250
+ # Add reference lines
251
+ if current_score < predicted_score:
252
+ improvement = predicted_score - current_score
253
+ plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5,
254
+ label=f'Current WAB = {current_score:.1f}')
255
+ plt.axhline(y=predicted_score, color='g', linestyle='--', alpha=0.5,
256
+ label=f'Predicted WAB = {predicted_score:.1f} (+{improvement:.1f})')
257
+ else:
258
+ decline = current_score - predicted_score
259
+ plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5,
260
+ label=f'Current WAB = {current_score:.1f}')
261
+ plt.axhline(y=predicted_score, color='orange', linestyle='--', alpha=0.5,
262
+ label=f'Predicted WAB = {predicted_score:.1f} (-{decline:.1f})')
263
+
264
+ # Add labels and title
265
+ plt.xlabel('Months Post Stroke')
266
+ plt.ylabel('WAB Score')
267
+ plt.title('Predicted Treatment Trajectory')
268
+ plt.legend(loc='best')
269
+
270
+ # Set y-axis limits
271
+ plt.ylim([0, 100])
272
+
273
+ plt.tight_layout()
274
+ return fig
275
+
276
+ def create_prediction_plots(self, latents, demographics, y_true, y_pred, y_std):
277
+ """Create prediction performance plots"""
278
+ fig = plt.figure(figsize=(12, 8))
279
+
280
+ # Create a 2x2 grid for plots
281
+ gs = plt.GridSpec(2, 2, figure=fig)
282
+
283
+ # Plot predicted vs actual values
284
+ ax1 = fig.add_subplot(gs[0, 0])
285
+
286
+ if self.predictor.prediction_type == 'regression':
287
+ # Regression: scatter plot
288
+ ax1.scatter(y_true, y_pred, alpha=0.7)
289
+
290
+ # Add perfect prediction line
291
+ min_val = min(np.min(y_true), np.min(y_pred))
292
+ max_val = max(np.max(y_true), np.max(y_pred))
293
+ ax1.plot([min_val, max_val], [min_val, max_val], 'r--')
294
+
295
+ ax1.set_xlabel('Actual Values')
296
+ ax1.set_ylabel('Predicted Values')
297
+ ax1.set_title('Predicted vs. Actual Values')
298
+
299
+ # Add R² to the plot
300
+ r2 = r2_score(y_true, y_pred)
301
+ ax1.text(0.05, 0.95, f'R² = {r2:.4f}', transform=ax1.transAxes,
302
+ bbox=dict(facecolor='white', alpha=0.5))
303
+
304
+ # Plot residuals
305
+ ax2 = fig.add_subplot(gs[0, 1])
306
+ residuals = y_true - y_pred
307
+ ax2.scatter(y_pred, residuals, alpha=0.7)
308
+ ax2.axhline(y=0, color='r', linestyle='--')
309
+ ax2.set_xlabel('Predicted Values')
310
+ ax2.set_ylabel('Residuals')
311
+ ax2.set_title('Residual Plot')
312
+
313
+ # Plot prediction errors
314
+ ax3 = fig.add_subplot(gs[1, 0])
315
+ ax3.errorbar(range(len(y_pred)), y_pred, yerr=2*y_std, fmt='o', alpha=0.7,
316
+ label='Predicted ± 2σ')
317
+ ax3.plot(range(len(y_true)), y_true, 'rx', alpha=0.7, label='Actual')
318
+ ax3.set_xlabel('Sample Index')
319
+ ax3.set_ylabel('Value')
320
+ ax3.set_title('Prediction with Error Bars')
321
+ ax3.legend()
322
+
323
+ # Plot error distribution
324
+ ax4 = fig.add_subplot(gs[1, 1])
325
+ ax4.hist(residuals, bins=20, alpha=0.7)
326
+ ax4.axvline(x=0, color='r', linestyle='--')
327
+ ax4.set_xlabel('Prediction Error')
328
+ ax4.set_ylabel('Frequency')
329
+ ax4.set_title('Error Distribution')
330
+
331
+ else: # classification
332
+ # Convert to integer classes if they're strings
333
+ if isinstance(y_true[0], str) or isinstance(y_pred[0], str):
334
+ # Create mapping of class labels to integers
335
+ classes = sorted(list(set(list(y_true) + list(y_pred))))
336
+ class_to_int = {c: i for i, c in enumerate(classes)}
337
+
338
+ y_true_int = np.array([class_to_int[c] for c in y_true])
339
+ y_pred_int = np.array([class_to_int[c] for c in y_pred])
340
+ else:
341
+ y_true_int = y_true
342
+ y_pred_int = y_pred
343
+ classes = sorted(list(set(list(y_true_int) + list(y_pred_int))))
344
+
345
+ # Confusion matrix
346
+ from sklearn.metrics import confusion_matrix
347
+ cm = confusion_matrix(y_true_int, y_pred_int)
348
+
349
+ # Plot confusion matrix
350
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes,
351
+ yticklabels=classes, ax=ax1)
352
+ ax1.set_xlabel('Predicted')
353
+ ax1.set_ylabel('True')
354
+ ax1.set_title('Confusion Matrix')
355
+
356
+ # Class distribution
357
+ ax2 = fig.add_subplot(gs[0, 1])
358
+ unique_classes, true_counts = np.unique(y_true_int, return_counts=True)
359
+ unique_classes, pred_counts = np.unique(y_pred_int, return_counts=True)
360
+
361
+ # Create class distribution DataFrame
362
+ class_dist = pd.DataFrame({
363
+ 'Class': classes,
364
+ 'True': 0,
365
+ 'Predicted': 0
366
+ })
367
+
368
+ for c, count in zip(unique_classes, true_counts):
369
+ class_dist.loc[class_dist['Class'] == classes[c], 'True'] = count
370
+
371
+ for c, count in zip(unique_classes, pred_counts):
372
+ class_dist.loc[class_dist['Class'] == classes[c], 'Predicted'] = count
373
+
374
+ # Plot class distribution
375
+ ax2.bar(class_dist['Class'].astype(str), class_dist['True'], label='True', alpha=0.7)
376
+ ax2.bar(class_dist['Class'].astype(str), class_dist['Predicted'], label='Predicted', alpha=0.5)
377
+ ax2.set_xlabel('Class')
378
+ ax2.set_ylabel('Count')
379
+ ax2.set_title('Class Distribution')
380
+ ax2.legend()
381
+
382
+ # Performance metrics
383
+ ax3 = fig.add_subplot(gs[1, 0])
384
+ ax3.axis('off')
385
+
386
+ # Calculate metrics
387
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
388
+ acc = accuracy_score(y_true_int, y_pred_int)
389
+ prec = precision_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
390
+ rec = recall_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
391
+ f1 = f1_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
392
+
393
+ metrics_text = (
394
+ f"Classification Metrics:\n\n"
395
+ f"Accuracy: {acc:.4f}\n"
396
+ f"Precision: {prec:.4f}\n"
397
+ f"Recall: {rec:.4f}\n"
398
+ f"F1 Score: {f1:.4f}"
399
+ )
400
+
401
+ ax3.text(0.5, 0.5, metrics_text, ha='center', va='center', fontsize=12)
402
+
403
+ # Confidence distribution
404
+ ax4 = fig.add_subplot(gs[1, 1])
405
+ ax4.hist(1 - y_std, bins=20, alpha=0.7)
406
+ ax4.set_xlabel('Prediction Confidence')
407
+ ax4.set_ylabel('Frequency')
408
+ ax4.set_title('Confidence Distribution')
409
+
410
+ plt.tight_layout()
411
+ return fig
412
+
413
+ def create_importance_plot(self, feature_importance, top_n=15):
414
+ """Create feature importance plot"""
415
+ # If feature_importance is a DataFrame, use it directly
416
+ if isinstance(feature_importance, pd.DataFrame):
417
+ importance_df = feature_importance
418
+ else:
419
+ # Create DataFrame
420
+ importance_df = pd.DataFrame({
421
+ 'feature': [f'Feature {i}' for i in range(len(feature_importance))],
422
+ 'importance': feature_importance
423
+ })
424
+
425
+ # Get top N features
426
+ top_features = importance_df.sort_values('importance', ascending=False).head(top_n)
427
+
428
+ # Create plot
429
+ fig = plt.figure(figsize=(10, 6))
430
+ plt.barh(range(len(top_features)), top_features['importance'], align='center')
431
+ plt.yticks(range(len(top_features)), top_features['feature'])
432
+ plt.xlabel('Importance')
433
+ plt.ylabel('Features')
434
+ plt.title(f'Top {top_n} Features by Importance')
435
+ plt.tight_layout()
436
+
437
+ return fig
438
+
439
+ def create_learning_curve_plot(self, fold_metrics):
440
+ """Create learning curve plots from cross-validation results"""
441
+ fig = plt.figure(figsize=(12, 6))
442
+
443
+ # Create a grid for plots
444
+ if self.predictor.prediction_type == 'regression':
445
+ # For regression, show R² and RMSE
446
+ ax1 = plt.subplot(1, 2, 1)
447
+ ax2 = plt.subplot(1, 2, 2)
448
+
449
+ # Plot R² for each fold
450
+ for i, metrics in enumerate(fold_metrics):
451
+ ax1.plot(i+1, metrics['r2'], 'bo')
452
+
453
+ # Plot average R²
454
+ avg_r2 = np.mean([m['r2'] for m in fold_metrics])
455
+ ax1.axhline(y=avg_r2, color='r', linestyle='--',
456
+ label=f'Average R² = {avg_r2:.4f}')
457
+
458
+ ax1.set_xlabel('Fold')
459
+ ax1.set_ylabel('R²')
460
+ ax1.set_title('R² by Fold')
461
+ ax1.set_xticks(range(1, len(fold_metrics)+1))
462
+ ax1.legend()
463
+
464
+ # Plot RMSE for each fold
465
+ for i, metrics in enumerate(fold_metrics):
466
+ ax2.plot(i+1, metrics['rmse'], 'go')
467
+
468
+ # Plot average RMSE
469
+ avg_rmse = np.mean([m['rmse'] for m in fold_metrics])
470
+ ax2.axhline(y=avg_rmse, color='r', linestyle='--',
471
+ label=f'Average RMSE = {avg_rmse:.4f}')
472
+
473
+ ax2.set_xlabel('Fold')
474
+ ax2.set_ylabel('RMSE')
475
+ ax2.set_title('RMSE by Fold')
476
+ ax2.set_xticks(range(1, len(fold_metrics)+1))
477
+ ax2.legend()
478
+
479
+ else: # classification
480
+ # For classification, show accuracy and F1
481
+ ax1 = plt.subplot(1, 2, 1)
482
+ ax2 = plt.subplot(1, 2, 2)
483
+
484
+ # Plot accuracy for each fold
485
+ for i, metrics in enumerate(fold_metrics):
486
+ ax1.plot(i+1, metrics['accuracy'], 'bo')
487
+
488
+ # Plot average accuracy
489
+ avg_acc = np.mean([m['accuracy'] for m in fold_metrics])
490
+ ax1.axhline(y=avg_acc, color='r', linestyle='--',
491
+ label=f'Average Accuracy = {avg_acc:.4f}')
492
+
493
+ ax1.set_xlabel('Fold')
494
+ ax1.set_ylabel('Accuracy')
495
+ ax1.set_title('Accuracy by Fold')
496
+ ax1.set_xticks(range(1, len(fold_metrics)+1))
497
+ ax1.legend()
498
+
499
+ # Plot F1 for each fold
500
+ for i, metrics in enumerate(fold_metrics):
501
+ ax2.plot(i+1, metrics['f1'], 'go')
502
+
503
+ # Plot average F1
504
+ avg_f1 = np.mean([m['f1'] for m in fold_metrics])
505
+ ax2.axhline(y=avg_f1, color='r', linestyle='--',
506
+ label=f'Average F1 = {avg_f1:.4f}')
507
+
508
+ ax2.set_xlabel('Fold')
509
+ ax2.set_ylabel('F1 Score')
510
+ ax2.set_title('F1 Score by Fold')
511
+ ax2.set_xticks(range(1, len(fold_metrics)+1))
512
+ ax2.legend()
513
+
514
+ plt.tight_layout()
515
+ return fig
516
+
517
+ def calculate_fc_accuracy(original_fc, reconstructed_fc):
518
+ """
519
+ Calculate accuracy metrics between original and reconstructed FC matrices
520
+ """
521
+ # Mean Squared Error (lower is better)
522
+ mse = mean_squared_error(original_fc, reconstructed_fc)
523
+
524
+ # Root Mean Squared Error (lower is better)
525
+ rmse = np.sqrt(mse)
526
+
527
+ # R² Score (higher is better, 1 is perfect)
528
+ r2 = r2_score(original_fc, reconstructed_fc)
529
+
530
+ # Correlation between matrices (higher is better)
531
+ corr = np.corrcoef(original_fc.flatten(), reconstructed_fc.flatten())[0, 1]
532
+
533
+ # Custom similarity score based on normalized dot product (higher is better)
534
+ norm_dot = np.dot(original_fc.flatten(), reconstructed_fc.flatten()) / (
535
+ np.linalg.norm(original_fc.flatten()) * np.linalg.norm(reconstructed_fc.flatten()))
536
+
537
+ return {
538
+ "MSE": float(mse),
539
+ "RMSE": float(rmse),
540
+ "R²": float(r2),
541
+ "Correlation": float(corr),
542
+ "Cosine Similarity": float(norm_dot)
543
+ }
544
+
545
+ def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
546
+ """
547
+ Save latent representations and associated demographics to file
548
+ """
549
+ os.makedirs('results', exist_ok=True)
550
+
551
+ # Create a dictionary with latents and demographics
552
+ data = {
553
+ 'latents': latents,
554
+ 'demographics': demographics
555
+ }
556
+
557
+ if subjects is not None:
558
+ data['subjects'] = subjects
559
+
560
+ # Save as pickle for easy loading in Python
561
+ with open(os.path.join('results', file_path), 'wb') as f:
562
+ pickle.dump(data, f)
563
+
564
+ # Also save as JSON for more universal access
565
+ json_data = {
566
+ 'latents': latents.tolist() if isinstance(latents, np.ndarray) else latents,
567
+ 'demographics': {k: v.tolist() if isinstance(v, np.ndarray) else v
568
+ for k, v in demographics.items()}
569
+ }
570
+
571
+ if subjects is not None:
572
+ json_data['subjects'] = subjects
573
+
574
+ with open(os.path.join('results', file_path.replace('.pkl', '.json')), 'w') as f:
575
+ json.dump(json_data, f)
576
+
577
+ return os.path.join('results', file_path)
578
+
579
+ # Make sure directory exists for saving results
580
+ os.makedirs('results', exist_ok=True)
581
 
582
  def create_interface():
583
+ """Create the Gradio interface"""
584
+ app = AphasiaPredictionApp()
585
+
586
+ with gr.Blocks(title="Aphasia Treatment Trajectory Prediction") as interface:
587
+ gr.Markdown("# Aphasia Treatment Trajectory Prediction")
588
 
589
+ with gr.Tabs():
590
+ # Training Tab
591
+ with gr.Tab("Train Models"):
592
+ with gr.Row():
593
+ with gr.Column(scale=1):
594
+ data_dir = gr.Textbox(
595
+ label="Data Directory",
596
+ value="SreekarB/OSFData"
597
+ )
598
+ latent_dim = gr.Slider(
599
+ minimum=8, maximum=64, step=8,
600
+ label="Latent Dimensions", value=32
601
+ )
602
+ nepochs = gr.Slider(
603
+ minimum=100, maximum=5000, step=100,
604
+ label="Number of Epochs", value=200 # Reduced for faster demos
605
+ )
606
+
607
+ with gr.Column(scale=1):
608
+ bsize = gr.Slider(
609
+ minimum=8, maximum=64, step=8,
610
+ label="Batch Size", value=16
611
+ )
612
+ use_hf_dataset = gr.Checkbox(
613
+ label="Use HuggingFace Dataset", value=True
614
+ )
615
+ with gr.Group("Prediction Options"):
616
+ prediction_type = gr.Radio(
617
+ label="Prediction Type",
618
+ choices=["regression", "classification"],
619
+ value="regression"
620
+ )
621
+ outcome_variable = gr.Dropdown(
622
+ label="Outcome Variable",
623
+ choices=["wab_aq", "age", "months_post_onset"],
624
+ value="wab_aq"
625
+ )
626
+
627
+ train_btn = gr.Button("Train Models", variant="primary")
628
+
629
+ with gr.Row():
630
+ fc_plot = gr.Plot(label="FC Analysis")
631
+
632
+ with gr.Row():
633
+ with gr.Column(scale=1):
634
+ importance_plot = gr.Plot(label="Feature Importance")
635
+ with gr.Column(scale=1):
636
+ prediction_plot = gr.Plot(label="Prediction Performance")
637
+
638
+ with gr.Row():
639
+ learning_plot = gr.Plot(label="Cross-validation Results")
640
+
641
+ # Prediction Tab
642
+ with gr.Tab("Predict Treatment"):
643
+ with gr.Row():
644
+ with gr.Column(scale=1):
645
+ fmri_file = gr.File(label="Patient fMRI Data")
646
+ with gr.Column(scale=1):
647
+ with gr.Group("Patient Demographics"):
648
+ age = gr.Number(label="Age at Stroke", value=60)
649
+ sex = gr.Dropdown(choices=["M", "F"], label="Sex", value="M")
650
+ months = gr.Number(label="Months Post Stroke", value=12)
651
+ wab = gr.Number(label="Current WAB Score", value=50)
652
 
653
+ predict_btn = gr.Button("Predict Treatment Outcome", variant="primary")
 
 
654
 
655
+ with gr.Row():
656
+ prediction_text = gr.Textbox(label="Prediction Result")
657
+
658
+ with gr.Row():
659
+ trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory")
660
+
661
+ # Connect components
662
+ train_outputs = {
663
+ 'vae': fc_plot,
664
+ 'importance': importance_plot,
665
+ 'prediction': prediction_plot,
666
+ 'learning': learning_plot
667
+ }
668
+
669
+ # Handle train button click
670
+ def handle_train(data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
671
+ prediction_type, outcome_variable):
672
+ # Ensure we have the necessary files before training
673
+ # This is a placeholder - in a real app you'd validate these files exist
674
+ demographic_file = os.path.join(data_dir, "demographics.csv")
675
+ treatment_file = os.path.join(data_dir, "treatment_outcomes.csv")
676
+
677
+ results = app.train_models(
678
+ data_dir=data_dir,
679
+ latent_dim=latent_dim,
680
+ nepochs=nepochs,
681
+ bsize=bsize
682
+ )
683
+
684
+ # Return plots in the expected order
685
+ return [
686
+ results.get('vae', None),
687
+ results.get('importance', None),
688
+ results.get('prediction', None),
689
+ results.get('learning', None)
690
+ ]
691
+
692
+ train_btn.click(
693
+ fn=handle_train,
694
+ inputs=[data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
695
+ prediction_type, outcome_variable],
696
+ outputs=[fc_plot, importance_plot, prediction_plot, learning_plot]
697
+ )
698
+
699
+ predict_btn.click(
700
+ fn=app.predict_treatment,
701
+ inputs=[fmri_file, age, sex, months, wab],
702
+ outputs=[prediction_text, trajectory_plot]
703
  )
704
 
705
  # Add examples
706
  gr.Examples(
707
  examples=[
708
+ ["SreekarB/OSFData", 32, 200, 16, True, "regression", "wab_aq"], # Standard training
709
+ ["SreekarB/OSFData", 16, 100, 8, True, "classification", "wab_aq"] # Faster training with classification
710
  ],
711
+ inputs=[data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
712
+ prediction_type, outcome_variable],
713
  )
714
 
715
+ # Add explanation
716
  gr.Markdown("""
717
+ ## How to use this tool
718
 
719
+ 1. **Train Models Tab**: First train the VAE and Random Forest models using your dataset
720
+ - Use the default SreekarB/OSFData dataset or specify your own data source
721
+ - Adjust parameters like latent dimensions and training epochs
722
+ - Choose regression or classification prediction type
723
+ - Select which variable to predict (WAB score by default)
724
+
725
+ 2. **Predict Treatment Tab**: Use the trained models to predict treatment outcomes
726
+ - Upload a patient's fMRI scan or use synthetic data
727
+ - Enter the patient's demographic information
728
+ - Click "Predict Treatment Outcome" to see the projected treatment trajectory
729
+ - The visualization shows the predicted outcome with confidence intervals
730
+
731
+ ## Interpreting Results
732
 
733
+ - The **Feature Importance** plot shows which latent dimensions and demographic variables most strongly predict treatment outcomes
734
+ - The **Prediction Performance** plot shows how well the model predicts known outcomes
735
+ - The **Treatment Trajectory** shows the projected change in WAB score over the course of treatment
736
+
737
+ Note: For optimal results, train with at least 500 epochs and latent dimension of 32 or higher.
738
  """)
739
 
740
+ return interface
741
 
742
  if __name__ == "__main__":
743
+ interface = create_interface()
744
+ interface.launch(share=True)
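The new `AphasiaPredictionApp` class above can also be driven without the Gradio UI. A minimal sketch, assuming the dataset, `demographics.csv`, and `treatment_outcomes.csv` described in this commit are available; the fMRI path used here is hypothetical:

```python
# Minimal sketch: using the class from app.py directly, outside Gradio.
from app import AphasiaPredictionApp

app = AphasiaPredictionApp()

# Same call the "Train Models" tab makes; returns a dict of matplotlib figures.
figures = app.train_models(
    data_dir="SreekarB/OSFData",
    latent_dim=32,
    nepochs=200,
    bsize=16,
)

# Predict a treatment outcome for one patient from an fMRI file plus demographics.
result_text, trajectory_fig = app.predict_treatment(
    fmri_file="P01_rs.nii",   # hypothetical local path
    age=60,
    sex="M",
    months_post_stroke=12,
    wab_score=50,
)
print(result_text)
```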
 
config.py CHANGED
@@ -22,3 +22,12 @@ DATASET_CONFIG = {
    'split': 'train'
}

+# Prediction configuration
+PREDICTION_CONFIG = {
+    'n_estimators': 100,
+    'max_depth': None,
+    'cv_folds': 5,
+    'prediction_type': 'regression',
+    'default_outcome': 'wab_aq',
+    'save_path': 'results/treatment_predictor.joblib'
+}
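`rcf_prediction.py` is not shown in this view, so how it consumes the new `PREDICTION_CONFIG` is not visible here. The following is only a plausible sketch of wiring those keys into scikit-learn; the helper name and the `random_state` value are assumptions:

```python
# Sketch only: one plausible way to build the random forest from PREDICTION_CONFIG.
from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
from config import PREDICTION_CONFIG

def build_estimator(config=PREDICTION_CONFIG):
    """Return a random forest matching the configured prediction type."""
    common = dict(
        n_estimators=config['n_estimators'],
        max_depth=config['max_depth'],
        random_state=42,  # assumed; not specified in the config shown above
    )
    if config['prediction_type'] == 'regression':
        return RandomForestRegressor(**common)
    return RandomForestClassifier(**common)
```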
data_preprocessing.py CHANGED
@@ -1,557 +1,93 @@
1
  import numpy as np
2
  import pandas as pd
3
- from datasets import load_dataset
4
  from nilearn import input_data, connectome
5
  from nilearn.image import load_img
6
  import nibabel as nib
7
- import os
 
8
 
9
- def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
10
  """
11
- Process fMRI data to generate functional connectivity matrices
12
 
13
- Parameters:
14
- - dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
15
- - demo_data: Optional demographic data, required if providing NIfTI files
16
- - demo_types: Optional demographic data types, required if providing NIfTI files
17
 
18
- Returns:
19
- - X: Array of FC matrices
20
- - demo_data: Demographic data
21
- - demo_types: Demographic data types
22
- """
23
- print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
24
 
25
- # For SreekarB/OSFData dataset, the data will be loaded from dataset features
26
- if isinstance(dataset_or_niifiles, str):
27
- dataset_name = dataset_or_niifiles
28
- print(f"Loading data from dataset: {dataset_name}")
29
- try:
30
- # Try multiple approaches to load the dataset
31
- approaches = [
32
- lambda: load_dataset(dataset_name, split="train"),
33
- lambda: load_dataset(dataset_name), # Try without split
34
- lambda: load_dataset(dataset_name, split="train", trust_remote_code=True), # Try with trust_remote_code
35
- lambda: load_dataset(dataset_name.split("/")[-1], split="train") if "/" in dataset_name else None
36
- ]
37
-
38
- dataset = None
39
- last_error = None
40
-
41
- for i, approach in enumerate(approaches):
42
- if approach is None:
43
- continue
44
-
45
- try:
46
- print(f"Attempt {i+1} to load dataset...")
47
- dataset = approach()
48
- print(f"Successfully loaded dataset with approach {i+1}!")
49
- break
50
- except Exception as e:
51
- print(f"Attempt {i+1} failed: {e}")
52
- last_error = e
53
-
54
- if dataset is None:
55
- print(f"All attempts to load dataset failed. Last error: {last_error}")
56
- raise ValueError(f"Could not load dataset {dataset_name}")
57
- except Exception as e:
58
- print(f"Error during dataset loading: {e}")
59
- raise
60
-
61
- # Prepare demographics data from the dataset
62
- if demo_data is None:
63
- # Create demo_data from the dataset
64
- demo_df = pd.DataFrame({
65
- 'age': dataset['age'],
66
- 'gender': dataset['gender'],
67
- 'mpo': dataset['mpo'],
68
- 'wab_aq': dataset['wab_aq']
69
- })
70
-
71
- demo_data = [
72
- demo_df['age'].values,
73
- demo_df['gender'].values,
74
- demo_df['mpo'].values,
75
- demo_df['wab_aq'].values
76
- ]
77
-
78
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
79
-
80
- # Look for NIfTI files in P01_rs.nii format
81
- print("Searching for NIfTI files in dataset columns...")
82
- nii_files = []
83
-
84
- # Create a temp directory for downloads
85
- import tempfile
86
- from huggingface_hub import hf_hub_download
87
- import shutil
88
-
89
- temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
90
- print(f"Created temporary directory for NIfTI files: {temp_dir}")
91
-
92
- try:
93
- # First approach: Check if there are any columns containing file paths
94
- nii_columns = []
95
- for col in dataset.column_names:
96
- # Check if column name suggests NIfTI files
97
- if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
98
- nii_columns.append(col)
99
- # Or check if column contains file paths
100
- elif len(dataset) > 0:
101
- first_val = dataset[0][col]
102
- if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
103
- nii_columns.append(col)
104
-
105
- if nii_columns:
106
- print(f"Found columns that may contain NIfTI files: {nii_columns}")
107
-
108
- for col in nii_columns:
109
- print(f"Processing column '{col}'...")
110
-
111
- for i, item in enumerate(dataset[col]):
112
- if not isinstance(item, str):
113
- print(f"Item {i} in column {col} is not a string but {type(item)}")
114
- continue
115
-
116
- if not (item.endswith('.nii') or item.endswith('.nii.gz')):
117
- print(f"Item {i} in column {col} is not a NIfTI file: {item}")
118
- continue
119
-
120
- print(f"Downloading {item} from dataset {dataset_name}...")
121
-
122
- try:
123
- # Attempt to download with explicit filename
124
- file_path = hf_hub_download(
125
- repo_id=dataset_name,
126
- filename=item,
127
- repo_type="dataset",
128
- cache_dir=temp_dir
129
- )
130
- nii_files.append(file_path)
131
- print(f"✓ Successfully downloaded {item}")
132
- except Exception as e1:
133
- print(f"Error downloading with explicit filename: {e1}")
134
-
135
- # Second attempt: try with the item's basename
136
- try:
137
- basename = os.path.basename(item)
138
- print(f"Trying with basename: {basename}")
139
- file_path = hf_hub_download(
140
- repo_id=dataset_name,
141
- filename=basename,
142
- repo_type="dataset",
143
- cache_dir=temp_dir
144
- )
145
- nii_files.append(file_path)
146
- print(f"✓ Successfully downloaded {basename}")
147
- except Exception as e2:
148
- print(f"Error downloading with basename: {e2}")
149
-
150
- # Third attempt: check if it's a binary blob in the dataset
151
- try:
152
- if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
153
- print("Found binary data in dataset, saving to temporary file...")
154
- binary_data = dataset[i]['bytes']
155
- temp_file = os.path.join(temp_dir, basename)
156
- with open(temp_file, 'wb') as f:
157
- f.write(binary_data)
158
- nii_files.append(temp_file)
159
- print(f"✓ Saved binary data to {temp_file}")
160
- except Exception as e3:
161
- print(f"Error handling binary data: {e3}")
162
-
163
- # Last resort: look for the file locally
164
- local_path = os.path.join(os.getcwd(), item)
165
- if os.path.exists(local_path):
166
- nii_files.append(local_path)
167
- print(f"✓ Found {item} locally")
168
- else:
169
- print(f"❌ Warning: Could not find {item} anywhere")
170
-
171
- # Second approach: Try to find NIfTI files in dataset repository directly
172
- if not nii_files:
173
- print("No NIfTI files found in dataset columns. Trying direct repository search...")
174
-
175
- try:
176
- from huggingface_hub import list_repo_files, hf_hub_download
177
-
178
- # Try to list all files in the repository
179
- try:
180
- print("Listing all repository files...")
181
- all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
182
- print(f"Found {len(all_repo_files)} files in repository")
183
-
184
- # First prioritize P*_rs.nii files
185
- p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
186
-
187
- # Then include all other NIfTI files
188
- other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
189
-
190
- # Combine, with P*_rs.nii files first
191
- nii_repo_files = p_rs_files + other_nii_files
192
-
193
- if nii_repo_files:
194
- print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
195
-
196
- # Download each file
197
- for nii_file in nii_repo_files:
198
- try:
199
- file_path = hf_hub_download(
200
- repo_id=dataset_name,
201
- filename=nii_file,
202
- repo_type="dataset",
203
- cache_dir=temp_dir
204
- )
205
- nii_files.append(file_path)
206
- print(f"✓ Downloaded {nii_file}")
207
- except Exception as e:
208
- print(f"Error downloading {nii_file}: {e}")
209
- except Exception as e:
210
- print(f"Error listing repository files: {e}")
211
- print("Will try alternative approaches...")
212
-
213
- # If repo listing fails, try with common NIfTI file patterns directly
214
- if not nii_files:
215
- print("Trying common NIfTI file patterns...")
216
-
217
- # Focus specifically on P*_rs.nii pattern
218
- patterns = []
219
-
220
- # Generate P01_rs.nii through P30_rs.nii
221
- for i in range(1, 31): # Try subjects 1-30
222
- patterns.append(f"P{i:02d}_rs.nii")
223
-
224
- # Also try with .nii.gz extension
225
- for i in range(1, 31):
226
- patterns.append(f"P{i:02d}_rs.nii.gz")
227
-
228
- # Include a few other common patterns as fallbacks
229
- patterns.extend([
230
- "sub-01_task-rest_bold.nii.gz", # BIDS format
231
- "fmri.nii.gz", "bold.nii.gz",
232
- "rest.nii.gz"
233
- ])
234
-
235
- for pattern in patterns:
236
- try:
237
- print(f"Trying to download {pattern}...")
238
- file_path = hf_hub_download(
239
- repo_id=dataset_name,
240
- filename=pattern,
241
- repo_type="dataset",
242
- cache_dir=temp_dir
243
- )
244
- nii_files.append(file_path)
245
- print(f"✓ Successfully downloaded {pattern}")
246
- except Exception as e:
247
- print(f"× Failed to download {pattern}")
248
-
249
- # If we still couldn't find any files, check if data files are nested
250
- if not nii_files:
251
- print("Checking for nested data files...")
252
- nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
253
-
254
- for path in nested_paths:
255
- for pattern in patterns:
256
- nested_file = f"{path}{pattern}"
257
- try:
258
- print(f"Trying to download {nested_file}...")
259
- file_path = hf_hub_download(
260
- repo_id=dataset_name,
261
- filename=nested_file,
262
- repo_type="dataset",
263
- cache_dir=temp_dir
264
- )
265
- nii_files.append(file_path)
266
- print(f"✓ Successfully downloaded {nested_file}")
267
- # If we found one file in this directory, try to find all files in it
268
- try:
269
- all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
270
- nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
271
- print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
272
-
273
- for nii_file in nii_files_in_dir:
274
- if nii_file != nested_file: # Skip the one we already downloaded
275
- try:
276
- file_path = hf_hub_download(
277
- repo_id=dataset_name,
278
- filename=nii_file,
279
- repo_type="dataset",
280
- cache_dir=temp_dir
281
- )
282
- nii_files.append(file_path)
283
- print(f"✓ Downloaded {nii_file}")
284
- except Exception as e:
285
- print(f"Error downloading {nii_file}: {e}")
286
- except Exception as e:
287
- print(f"Error finding additional files in {path}: {e}")
288
- except Exception as e:
289
- pass
290
-
291
- except Exception as e:
292
- print(f"Error during repository exploration: {e}")
293
-
294
- # If we still don't have any files, try to search for P*_rs.nii pattern specifically
295
- if not nii_files:
296
- print("Trying to find files matching P*_rs.nii pattern specifically...")
297
-
298
- try:
299
- # List all files in the repository (if we haven't already)
300
- if not 'all_repo_files' in locals():
301
- from huggingface_hub import list_repo_files
302
- try:
303
- all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
304
- except Exception as e:
305
- print(f"Error listing repo files: {e}")
306
- all_repo_files = []
307
-
308
- # Look for files matching the pattern exactly (P*_rs.nii)
309
- pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
310
-
311
- # If we don't find any exact matches, try a more relaxed pattern
312
- if not pattern_files:
313
- pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
314
-
315
- if pattern_files:
316
- print(f"Found {len(pattern_files)} files matching rs.nii pattern")
317
-
318
- # Download each file
319
- for pattern_file in pattern_files:
320
- try:
321
- file_path = hf_hub_download(
322
- repo_id=dataset_name,
323
- filename=pattern_file,
324
- repo_type="dataset",
325
- cache_dir=temp_dir
326
- )
327
- nii_files.append(file_path)
328
- print(f"✓ Downloaded {pattern_file}")
329
- except Exception as e:
330
- print(f"Error downloading {pattern_file}: {e}")
331
- except Exception as e:
332
- print(f"Error searching for pattern files: {e}")
333
-
334
- print(f"Found total of {len(nii_files)} NIfTI files")
335
- except Exception as e:
336
- print(f"Unexpected error during NIfTI file search: {e}")
337
- import traceback
338
- traceback.print_exc()
339
-
340
- # If we found NIfTI files, process them to FC matrices
341
- if nii_files:
342
- print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
343
-
344
- # Load Power 264 atlas
345
- from nilearn import datasets
346
- power = datasets.fetch_coords_power_2011()
347
- coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
348
-
349
- masker = input_data.NiftiSpheresMasker(
350
- coords, radius=5,
351
- standardize=True,
352
- memory='nilearn_cache', memory_level=1,
353
- verbose=0,
354
- detrend=True,
355
- low_pass=0.1,
356
- high_pass=0.01,
357
- t_r=2.0 # Adjust TR according to your data
358
- )
359
-
360
- # Process fMRI data and compute FC matrices
361
- fc_matrices = []
362
- valid_files = 0
363
- total_files = len(nii_files)
364
-
365
- for nii_file in nii_files:
366
- try:
367
- print(f"Processing {nii_file}...")
368
- fmri_img = load_img(nii_file)
369
-
370
- # Check image dimensions
371
- if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
372
- print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
373
- continue
374
-
375
- time_series = masker.fit_transform(fmri_img)
376
-
377
- # Validate time series data
378
- if np.isnan(time_series).any() or np.isinf(time_series).any():
379
- print(f"Warning: {nii_file} contains NaN or Inf values after masking")
380
- # Replace NaNs with zeros for this file
381
- time_series = np.nan_to_num(time_series)
382
-
383
- correlation_measure = connectome.ConnectivityMeasure(
384
- kind='correlation',
385
- vectorize=False,
386
- discard_diagonal=False
387
- )
388
-
389
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
390
-
391
- # Check for invalid correlation values
392
- if np.isnan(fc_matrix).any():
393
- print(f"Warning: {nii_file} produced NaN correlation values")
394
- continue
395
-
396
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
397
- fc_triu = fc_matrix[triu_indices]
398
-
399
- # Fisher z-transform with proper bounds check
400
- # Clip correlation values to valid range for arctanh
401
- fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
402
- fc_triu = np.arctanh(fc_triu_clipped)
403
-
404
- fc_matrices.append(fc_triu)
405
- valid_files += 1
406
- print(f"Successfully processed {nii_file} to FC matrix")
407
-
408
- except Exception as e:
409
- print(f"Error processing {nii_file}: {e}")
410
-
411
- if fc_matrices:
412
- print(f"Successfully processed {valid_files} out of {total_files} files")
413
-
414
- # Ensure all matrices have the same dimensions
415
- dims = [m.shape[0] for m in fc_matrices]
416
- if len(set(dims)) > 1:
417
- print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
418
- # Use the most common dimension
419
- from collections import Counter
420
- most_common_dim = Counter(dims).most_common(1)[0][0]
421
- print(f"Using most common dimension: {most_common_dim}")
422
- fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
423
-
424
- X = np.array(fc_matrices)
425
-
426
- # Normalize the FC data
427
- mean_x = np.mean(X, axis=0)
428
- std_x = np.std(X, axis=0)
429
-
430
- # Handle zero standard deviation
431
- std_x[std_x == 0] = 1.0
432
-
433
- X = (X - mean_x) / std_x
434
- print(f"Created FC matrices with shape {X.shape}")
435
-
436
- # Make sure demo_data matches the number of FC matrices
437
- if len(demo_data[0]) != X.shape[0]:
438
- print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
439
- f"doesn't match number of FC matrices ({X.shape[0]})")
440
- # Adjust demo_data to match FC matrices
441
- indices = list(range(min(len(demo_data[0]), X.shape[0])))
442
- X = X[indices]
443
- demo_data = [d[indices] for d in demo_data]
444
-
445
- return X, demo_data, demo_types
446
-
447
- print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
448
- # Return a placeholder with the right demographics but empty FC
449
- n_subjects = len(dataset)
450
- n_rois = 264
451
- fc_dim = (n_rois * (n_rois - 1)) // 2
452
- X = np.zeros((n_subjects, fc_dim))
453
- print(f"Created placeholder FC matrices with shape {X.shape}")
454
- return X, demo_data, demo_types
455
-
456
- elif isinstance(dataset_or_niifiles, str):
457
- # Handle real dataset with actual fMRI data
458
- dataset = load_dataset(dataset_or_niifiles, split="train")
459
-
460
- # Load Power 264 atlas
461
- from nilearn import datasets
462
- power = datasets.fetch_coords_power_2011()
463
- coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
464
-
465
- masker = input_data.NiftiSpheresMasker(
466
- coords, radius=5,
467
- standardize=True,
468
- memory='nilearn_cache', memory_level=1,
469
- verbose=0,
470
- detrend=True,
471
- low_pass=0.1,
472
- high_pass=0.01,
473
- t_r=2.0 # Adjust TR according to your data
474
- )
475
 
476
- # Load demographic data if needed
477
- if demo_data is None:
478
- if 'demographics' in dataset.features:
479
- demo_df = pd.DataFrame(dataset['demographics'])
480
-
481
- demo_data = [
482
- demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
483
- demo_df['sex'].values if 'sex' in demo_df.columns else [],
484
- demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
485
- demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
486
- ]
487
-
488
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
489
-
490
- # Process fMRI data and compute FC matrices
491
- fc_matrices = []
492
- for nii_file in dataset['nii_files']:
493
- fmri_img = load_img(nii_file)
494
- time_series = masker.fit_transform(fmri_img)
495
-
496
- correlation_measure = connectome.ConnectivityMeasure(
497
- kind='correlation', vectorize=False, discard_diagonal=False
498
- )
499
-
500
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
501
-
502
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
503
- fc_triu = fc_matrix[triu_indices]
504
-
505
- fc_triu = np.arctanh(fc_triu) # Fisher z-transform
506
-
507
- fc_matrices.append(fc_triu)
508
-
509
- X = np.array(fc_matrices)
510
-
511
- elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
512
- # Handle a list of NIfTI files
513
- # Similar processing as above but with local files
514
- print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
515
-
516
- # Load Power 264 atlas
517
- from nilearn import datasets
518
- power = datasets.fetch_coords_power_2011()
519
- coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
520
-
521
- masker = input_data.NiftiSpheresMasker(
522
- coords, radius=5,
523
- standardize=True,
524
- memory='nilearn_cache', memory_level=1,
525
- verbose=0,
526
- detrend=True,
527
- low_pass=0.1,
528
- high_pass=0.01,
529
- t_r=2.0
530
- )
531
-
532
- fc_matrices = []
533
- for nii_file in dataset_or_niifiles:
534
- fmri_img = load_img(nii_file)
535
- time_series = masker.fit_transform(fmri_img)
536
-
537
- correlation_measure = connectome.ConnectivityMeasure(
538
- kind='correlation', vectorize=False, discard_diagonal=False
539
- )
540
-
541
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
542
-
543
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
544
- fc_triu = fc_matrix[triu_indices]
545
-
546
- fc_triu = np.arctanh(fc_triu) # Fisher z-transform
547
-
548
- fc_matrices.append(fc_triu)
549
-
550
- X = np.array(fc_matrices)
551
- else:
552
- raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
553
 
554
  # Normalize the FC data
555
  X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
556
 
557
- return X, demo_data, demo_types
 
 
 
 import numpy as np
 import pandas as pd

 from nilearn import input_data, connectome
 from nilearn.image import load_img
 import nibabel as nib
+from pathlib import Path
+from config import PREPROCESS_CONFIG

+def process_single_fmri(fmri_file):
     """
+    Process a single fMRI file to FC matrix
+    """
+    # Use Power 264 atlas
+    from nilearn import datasets
+    power = datasets.fetch_coords_power_2011()
+    coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T

+    # Create masker
+    masker = input_data.NiftiSpheresMasker(
+        coords,
+        radius=PREPROCESS_CONFIG['radius'],
+        standardize=True,
+        memory='nilearn_cache',
+        memory_level=1,
+        verbose=0,
+        detrend=True,
+        low_pass=PREPROCESS_CONFIG['low_pass'],
+        high_pass=PREPROCESS_CONFIG['high_pass'],
+        t_r=PREPROCESS_CONFIG['t_r']
+    )

+    # Load and process fMRI
+    fmri_img = load_img(fmri_file)
+    time_series = masker.fit_transform(fmri_img)

+    # Compute FC matrix
+    correlation_measure = connectome.ConnectivityMeasure(
+        kind='correlation',
+        vectorize=False,
+        discard_diagonal=False
+    )
+
+    fc_matrix = correlation_measure.fit_transform([time_series])[0]
+
+    # Get upper triangular part
+    triu_indices = np.triu_indices_from(fc_matrix, k=1)
+    fc_triu = fc_matrix[triu_indices]
+
+    # Fisher z-transform
+    fc_triu = np.arctanh(fc_triu)
+
+    return fc_triu
 
 
+def preprocess_fmri_to_fc(nii_files, demo_data, demo_types):
+    """
+    Convert multiple fMRI files to FC matrices
+    """
+    fc_matrices = []
+
+    for nii_file in nii_files:
+        fc_triu = process_single_fmri(nii_file)
+        fc_matrices.append(fc_triu)
+
+    X = np.array(fc_matrices)
 
 
     # Normalize the FC data
     X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

+    return X, demo_data, demo_types
+
+def load_and_preprocess_data(data_dir, demographic_file):
+    """
+    Load and preprocess both fMRI data and demographics
+    """
+    # Load demographics
+    demo_df = pd.read_csv(demographic_file)
+
+    demo_data = [
+        demo_df['age_at_stroke'].values,
+        demo_df['sex'].values,
+        demo_df['months_post_stroke'].values,
+        demo_df['wab_score'].values
+    ]
+
+    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
+
+    # Load fMRI files
+    nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
+
+    # Process fMRI files to FC matrices
+    X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
+
+    return X, demo_data, demo_types
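A minimal usage sketch of the loader above, assuming a local `data/` folder of `*.nii.gz` files whose sorted order matches the rows of a `demographics.csv` containing the four columns read here (both paths are illustrative). Note that the per-edge normalization divides by the across-subject standard deviation; the earlier version guarded against that value being zero, so with very few subjects some edges can still produce NaNs.

```python
# Hypothetical paths, for illustration only.
X, demo_data, demo_types = load_and_preprocess_data("data", "demographics.csv")

print(X.shape)        # (n_subjects, 34716) with the 264-node Power atlas
print(demo_types)     # ['continuous', 'categorical', 'continuous', 'continuous']
```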
main.py CHANGED
@@ -1,272 +1,150 @@
1
  import os
2
- import sys
3
- # Add the src directory to the path so we can import from demovae
4
- sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
5
-
6
  import numpy as np
7
  import torch
8
  from pathlib import Path
9
- import nibabel as nib
10
- from data_preprocessing import preprocess_fmri_to_fc
11
- from src.demovae.sklearn import DemoVAE
12
- from analysis import analyze_fc_patterns
13
- from visualization import visualize_fc_analysis
14
- from config import MODEL_CONFIG, DATASET_CONFIG
15
  import pandas as pd
16
- import io
17
- from typing import List, Dict, Union, Tuple, Any
 
 
 
 
18
 
19
- def train_fc_vae(X, demo_data, demo_types, model_config):
 
 
 
 
 
 
20
  """
21
- Train a VAE model on functional connectivity matrices
22
  """
23
- n_rois = 264
24
- input_dim = (n_rois * (n_rois - 1)) // 2
 
 
 
 
25
 
26
- print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
 
 
27
 
28
- # Ensure X is a numpy array with correct data type
29
- if not isinstance(X, np.ndarray):
30
- print(f"Converting X from {type(X)} to numpy array")
31
- X = np.array(X, dtype=np.float32)
32
 
33
- # Ensure demo_data contains numpy arrays
34
- for i, d in enumerate(demo_data):
35
- if not isinstance(d, np.ndarray):
36
- print(f"Converting demographic {i} from {type(d)} to numpy array")
37
- demo_data[i] = np.array(d)
38
 
39
- # Check for NaN or Inf values
40
- if np.isnan(X).any() or np.isinf(X).any():
41
- print("Warning: X contains NaN or Inf values. Replacing with zeros.")
42
- X = np.nan_to_num(X)
43
 
44
- # Create the VAE model
45
- vae = DemoVAE(
46
- latent_dim=model_config['latent_dim'],
47
- nepochs=model_config['nepochs'],
48
- bsize=model_config['bsize'],
49
- loss_rec_mult=model_config.get('loss_rec_mult', 100),
50
- loss_decor_mult=model_config.get('loss_decor_mult', 10),
51
- lr=model_config.get('lr', 1e-4),
52
- use_cuda=torch.cuda.is_available()
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  )
54
 
55
- print("Fitting VAE model...")
56
- vae.fit(X, demo_data, demo_types)
57
 
58
- return vae, X, demo_data, demo_types
59
-
60
- def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
61
- """
62
- Load fMRI data and demographics from HuggingFace dataset or local files
63
- """
64
- if use_hf_dataset:
65
- # Load from HuggingFace Datasets
66
- from datasets import load_dataset
67
-
68
- print(f"Loading dataset from HuggingFace: {data_dir}")
69
- dataset = load_dataset(data_dir)
70
-
71
- print(f"Dataset columns: {dataset['train'].column_names}")
72
-
73
- # Get demographics directly from the dataset
74
- # Create a DataFrame from the dataset features
75
- demo_df = pd.DataFrame({
76
- 'ID': dataset['train']['ID'],
77
- 'wab_aq': dataset['train']['wab_aq'],
78
- 'age': dataset['train']['age'],
79
- 'mpo': dataset['train']['mpo'],
80
- 'education': dataset['train']['education'],
81
- 'gender': dataset['train']['gender'],
82
- 'handedness': dataset['train']['handedness']
83
- })
84
-
85
- print(f"Loaded demographic data with {len(demo_df)} subjects")
86
-
87
- # Extract demographic data matching our expected format
88
- # Map the dataset columns to our expected format
89
- demo_data = [
90
- demo_df['age'].values, # age at stroke -> age
91
- demo_df['gender'].values, # sex -> gender
92
- demo_df['mpo'].values, # months post stroke -> mpo
93
- demo_df['wab_aq'].values # wab score -> wab_aq
94
- ]
95
-
96
- # Check for FC matrices in the dataset
97
- fc_columns = []
98
- for col in dataset['train'].column_names:
99
- if col.startswith("fc_") or "_fc" in col:
100
- fc_columns.append(col)
101
-
102
- if fc_columns:
103
- print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
104
- # Extract FC matrices
105
- fc_matrices = []
106
- for fc_col in fc_columns:
107
- fc_matrices.append(dataset['train'][fc_col])
108
-
109
- # If we have FC matrices, return them directly
110
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
111
- return fc_matrices, demo_data, demo_types
112
-
113
- # If no FC matrices, look for .nii files
114
- nii_files = []
115
- for col in dataset['train'].column_names:
116
- if col.endswith(".nii.gz") or col.endswith(".nii"):
117
- nii_files.append(dataset['train'][col])
118
-
119
- if nii_files:
120
- print(f"Found {len(nii_files)} .nii files")
121
- else:
122
- print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
123
- # If no structured data is found, we can try to download raw files later
124
-
125
- else:
126
- # Original local file loading
127
- # Load demographics
128
- demo_df = pd.read_csv(demographic_file)
129
-
130
- demo_data = [
131
- demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
132
- demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
133
- demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
134
- demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
135
- ]
136
-
137
- # Load fMRI files
138
- nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
139
 
140
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
141
- return nii_files, demo_data, demo_types
142
-
143
- def run_fc_analysis(data_dir="SreekarB/OSFData",
144
- demographic_file=None,
145
- latent_dim=32,
146
- nepochs=1000,
147
- bsize=16,
148
- save_model=True,
149
- use_hf_dataset=True):
150
 
151
- # Update MODEL_CONFIG with user-specified parameters
152
- MODEL_CONFIG.update({
153
- 'latent_dim': latent_dim,
154
- 'nepochs': nepochs,
155
- 'bsize': bsize
156
- })
157
 
158
- try:
159
- # Load data
160
- print("Loading data...")
161
- nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
162
-
163
- # For SreekarB/OSFData, directly generate synthetic FC matrices
164
- if data_dir == "SreekarB/OSFData" and use_hf_dataset:
165
- print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
166
- X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
167
- # Check if we got FC matrices directly
168
- elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
169
- print("Using pre-computed FC matrices...")
170
- # Convert list of FC matrices to numpy array
171
- X = np.stack([np.array(fc) for fc in nii_files])
172
- else:
173
- # Prepare data by converting fMRI to FC matrices
174
- print("Converting fMRI data to FC matrices...")
175
- X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
176
-
177
- # Print shapes and data types
178
- print(f"X shape: {X.shape}, type: {type(X)}")
179
- for i, d in enumerate(demo_data):
180
- print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
181
-
182
- # Train VAE and get data
183
- print("Training VAE...")
184
- try:
185
- # Use the proper DemoVAE implementation from src/demovae/sklearn.py
186
- vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
187
-
188
- if save_model:
189
- print("Saving model...")
190
- os.makedirs('models', exist_ok=True)
191
- # Use the save method from DemoVAE
192
- vae.save('models/vae_model.pth')
193
- print("Model saved successfully.")
194
- except Exception as e:
195
- print(f"Error during VAE training: {e}")
196
- raise
197
-
198
- # Get latent representations
199
- print("Getting latent representations...")
200
- latents = vae.get_latents(X)
201
-
202
- # Analyze results
203
- print("Analyzing demographic relationships...")
204
- demographics = {
205
- 'age': demo_data[0],
206
- 'months_post_onset': demo_data[2],
207
- 'wab_aq': demo_data[3]
208
  }
209
- analysis_results = analyze_fc_patterns(latents, demographics)
210
-
211
- # Generate new FC matrix
212
- print("Generating new FC matrices...")
213
-
214
- # Get data types from original demographic data for proper conversion
215
- demo_dtypes = [type(d[0]) if len(d) > 0 else float for d in demo_data]
216
-
217
- # Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
218
- new_demographics = [
219
- np.array([60.0], dtype=np.float64), # age
220
- np.array(['M'], dtype=np.str_), # gender
221
- np.array([12.0], dtype=np.float64), # months post onset
222
- np.array([80.0], dtype=np.float64) # wab score
223
- ]
224
-
225
- # Verify the demographic data arrays match the expected types
226
- print("Demographic data types:")
227
- for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
228
- print(f" {name}: shape={data.shape}, dtype={data.dtype}")
229
-
230
- print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
231
- try:
232
- generated_fc = vae.transform(1, new_demographics, demo_types)
233
- except Exception as e:
234
- print(f"Error generating new FC matrix: {e}")
235
- # Try with a fallback approach
236
- print("Trying alternative generation approach...")
237
- # If specific gender is causing issues, try the first gender from training data
238
- new_demographics[1] = np.array([demo_data[1][0]])
239
- generated_fc = vae.transform(1, new_demographics, demo_types)
240
- reconstructed_fc = vae.transform(X, demo_data, demo_types)
241
-
242
- # Visualize results
243
- print("Creating visualizations...")
244
- fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
245
-
246
- return fig
247
-
248
- except Exception as e:
249
- import traceback
250
- print(f"Error in run_fc_analysis: {str(e)}")
251
- print(traceback.format_exc())
252
-
253
- # Create a dummy figure with error message
254
- import matplotlib.pyplot as plt
255
- fig = plt.figure(figsize=(10, 6))
256
- plt.text(0.5, 0.5, f"Error: {str(e)}",
257
- horizontalalignment='center', verticalalignment='center',
258
- fontsize=12, color='red')
259
- plt.axis('off')
260
- return fig
261
 
262
  if __name__ == "__main__":
263
  import argparse
264
 
265
- parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
266
- parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
267
- help='HuggingFace dataset ID or directory containing fMRI data')
268
- parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
269
  help='Path to demographic data CSV file')
 
 
270
  parser.add_argument('--latent_dim', type=int, default=32,
271
  help='Dimension of latent space')
272
  parser.add_argument('--nepochs', type=int, default=1000,
@@ -274,20 +152,16 @@ if __name__ == "__main__":
274
  parser.add_argument('--bsize', type=int, default=16,
275
  help='Batch size for training')
276
  parser.add_argument('--no_save', action='store_false',
277
- help='Do not save the model')
278
- parser.add_argument('--use_local', action='store_true',
279
- help='Use local data instead of HuggingFace dataset')
280
 
281
  args = parser.parse_args()
282
 
283
- fig = run_fc_analysis(
284
  data_dir=args.data_dir,
285
  demographic_file=args.demographic_file,
 
286
  latent_dim=args.latent_dim,
287
  nepochs=args.nepochs,
288
  bsize=args.bsize,
289
- save_model=args.no_save,
290
- use_hf_dataset=not args.use_local
291
  )
292
- fig.show()
293
-
 
1
  import os
 
 
 
 
2
  import numpy as np
3
  import torch
4
  from pathlib import Path
 
 
 
 
 
 
5
  import pandas as pd
6
+ from data_preprocessing import load_and_preprocess_data
7
+ from vae_model import DemoVAE
8
+ from rcf_prediction import AphasiaTreatmentPredictor
9
+ from visualization import plot_fc_matrices, plot_learning_curves
10
+ from config import MODEL_CONFIG
11
+ import matplotlib.pyplot as plt
12
 
13
+ def run_analysis(data_dir="data",
14
+ demographic_file="demographics.csv",
15
+ treatment_file="treatment_outcomes.csv",
16
+ latent_dim=32,
17
+ nepochs=1000,
18
+ bsize=16,
19
+ save_model=True):
20
  """
21
+ Run the complete analysis pipeline
22
  """
23
+ # Update MODEL_CONFIG with user-specified parameters
24
+ MODEL_CONFIG.update({
25
+ 'latent_dim': latent_dim,
26
+ 'nepochs': nepochs,
27
+ 'bsize': bsize
28
+ })
29
 
30
+ # Create output directories
31
+ os.makedirs('models', exist_ok=True)
32
+ os.makedirs('results', exist_ok=True)
33
 
34
+ # Load and preprocess data
35
+ print("Loading and preprocessing data...")
36
+ X, demo_data, demo_types = load_and_preprocess_data(data_dir, demographic_file)
 
37
 
38
+ # Load treatment outcomes
39
+ treatment_df = pd.read_csv(treatment_file)
40
+ treatment_outcomes = treatment_df['outcome_score'].values
 
 
41
 
42
+ # Initialize and train VAE
43
+ print("Training VAE...")
44
+ vae = DemoVAE(**MODEL_CONFIG)
45
+ train_losses, val_losses = vae.fit(X, demo_data, demo_types)
46
 
47
+ # Get latent representations
48
+ print("Extracting latent representations...")
49
+ latents = vae.get_latents(X)
50
+
51
+ # Initialize and train treatment predictor
52
+ print("Training treatment predictor...")
53
+ predictor = AphasiaTreatmentPredictor(n_estimators=100)
54
+
55
+ # Prepare demographics for predictor
56
+ demographics = {
57
+ 'age_at_stroke': demo_data[0],
58
+ 'sex': demo_data[1],
59
+ 'months_post_stroke': demo_data[2],
60
+ 'wab_score': demo_data[3]
61
+ }
62
+
63
+ # Cross-validate the predictor
64
+ print("Performing cross-validation...")
65
+ cv_mean, cv_std, predictions, prediction_stds = predictor.cross_validate(
66
+ latents=latents,
67
+ demographics=demographics,
68
+ treatment_outcomes=treatment_outcomes
69
  )
70
 
71
+ # Fit final predictor model
72
+ predictor.fit(latents, demographics, treatment_outcomes)
73
 
74
+ # Save models if requested
75
+ if save_model:
76
+ print("Saving models...")
77
+ vae.save('models/vae_model.pt')
78
+ torch.save({
79
+ 'predictor_state': predictor.model,  # AphasiaTreatmentPredictor stores its fitted forest as .model
80
+ 'feature_importance': predictor.feature_importance
81
+ }, 'models/predictor_model.pt')
 
 
 
 
 
82
 
83
+ # Generate visualizations
84
+ print("Generating visualizations...")
 
 
 
 
 
 
 
 
85
 
86
+ # FC matrix visualization
+ # Note: X, reconstructed and generated hold vectorized upper-triangle FC values, while
+ # plot_fc_matrices draws its inputs with imshow; rebuild them into square 264x264
+ # matrices (e.g. with a triu-to-matrix helper) before plotting.
87
+ reconstructed = vae.transform(X, demo_data, demo_types)
88
+ generated = vae.transform(1,
89
+ [d[:1] for d in demo_data],
90
+ demo_types)
91
+ fc_fig = plot_fc_matrices(X[0], reconstructed[0], generated[0])
92
 
93
+ # Learning curves
94
+ learning_fig = plot_learning_curves(train_losses, val_losses)
95
+
96
+ # Feature importance
97
+ importance_fig = predictor.plot_feature_importance()
98
+
99
+ # Prediction performance
100
+ performance_fig = plt.figure(figsize=(8, 6))
101
+ plt.scatter(treatment_outcomes, predictions)
102
+ plt.plot([min(treatment_outcomes), max(treatment_outcomes)],
103
+ [min(treatment_outcomes), max(treatment_outcomes)],
104
+ 'r--')
105
+ plt.fill_between(treatment_outcomes,
106
+ predictions - 2*prediction_stds,
107
+ predictions + 2*prediction_stds,
108
+ alpha=0.2, color='gray')
109
+ plt.xlabel('Actual Outcome')
110
+ plt.ylabel('Predicted Outcome')
111
+ plt.title(f'Treatment Outcome Prediction\nR² = {cv_mean:.3f} ± {cv_std:.3f}')
112
+ plt.tight_layout()
113
+
114
+ # Save results
115
+ print("Saving results...")
116
+ np.save('results/latents.npy', latents)
117
+ np.save('results/predictions.npy', predictions)
118
+ np.save('results/prediction_stds.npy', prediction_stds)
119
+
120
+ results = {
121
+ 'vae': vae,
122
+ 'predictor': predictor,
123
+ 'latents': latents,
124
+ 'cv_scores': (cv_mean, cv_std),
125
+ 'predictions': predictions,
126
+ 'prediction_stds': prediction_stds,
127
+ 'figures': {
128
+ 'fc_analysis': fc_fig,
129
+ 'learning_curves': learning_fig,
130
+ 'importance': importance_fig,
131
+ 'performance': performance_fig
 
 
 
 
 
 
 
 
 
 
 
132
  }
133
+ }
134
+
135
+ print("Analysis complete!")
136
+ return results
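A short sketch of consuming the dictionary returned above; the file paths are hypothetical, and the figure keys are the ones defined in the `figures` entry (the `results/` directory is created earlier in `run_analysis`).

```python
results = run_analysis(data_dir="data",
                       demographic_file="demographics.csv",
                       treatment_file="treatment_outcomes.csv",
                       nepochs=200)

# Persist the four figures and report the cross-validated fit.
for name, fig in results['figures'].items():
    fig.savefig(f"results/{name}.png", dpi=150)
print("CV R^2: %.3f +/- %.3f" % results['cv_scores'])
```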
 
 
 
 
137
 
138
  if __name__ == "__main__":
139
  import argparse
140
 
141
+ parser = argparse.ArgumentParser(description='Run Aphasia Treatment Analysis')
142
+ parser.add_argument('--data_dir', type=str, default='data',
143
+ help='Directory containing fMRI data')
144
+ parser.add_argument('--demographic_file', type=str, default='demographics.csv',
145
  help='Path to demographic data CSV file')
146
+ parser.add_argument('--treatment_file', type=str, default='treatment_outcomes.csv',
147
+ help='Path to treatment outcomes CSV file')
148
  parser.add_argument('--latent_dim', type=int, default=32,
149
  help='Dimension of latent space')
150
  parser.add_argument('--nepochs', type=int, default=1000,
 
152
  parser.add_argument('--bsize', type=int, default=16,
153
  help='Batch size for training')
154
  parser.add_argument('--no_save', action='store_false',
155
+ help='Do not save the models')
 
 
156
 
157
  args = parser.parse_args()
158
 
159
+ results = run_analysis(
160
  data_dir=args.data_dir,
161
  demographic_file=args.demographic_file,
162
+ treatment_file=args.treatment_file,
163
  latent_dim=args.latent_dim,
164
  nepochs=args.nepochs,
165
  bsize=args.bsize,
166
+ save_model=args.no_save
 
167
  )
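One detail worth noting in `run_analysis`: `fill_between` receives `treatment_outcomes` in dataset order, so with unsorted outcomes the shaded ±2 SD band can fold back on itself. A hedged alternative, not part of the uploaded code, that sorts by the actual outcome before shading:

```python
import numpy as np
import matplotlib.pyplot as plt

def plot_prediction_band(y_true, y_pred, y_std):
    """Actual vs. predicted outcomes with a +/-2 SD band drawn over sorted x values."""
    order = np.argsort(y_true)
    fig = plt.figure(figsize=(8, 6))
    plt.scatter(y_true, y_pred)
    lo, hi = y_true.min(), y_true.max()
    plt.plot([lo, hi], [lo, hi], 'r--')          # identity line
    plt.fill_between(y_true[order],
                     (y_pred - 2 * y_std)[order],
                     (y_pred + 2 * y_std)[order],
                     alpha=0.2, color='gray')
    plt.xlabel('Actual Outcome')
    plt.ylabel('Predicted Outcome')
    plt.tight_layout()
    return fig
```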
 
 
rcf_prediction.py CHANGED
@@ -1,11 +1,8 @@
1
  import numpy as np
2
- from sklearn.ensemble import RandomForestRegressor
3
  from sklearn.model_selection import cross_val_score, KFold
4
  import pandas as pd
5
- from sklearn.metrics import mean_squared_error, r2_score
6
- # Configure matplotlib for headless environment
7
- import matplotlib
8
- matplotlib.use('Agg') # Use non-interactive backend
9
  import matplotlib.pyplot as plt
10
  import os
11
  import joblib
@@ -15,27 +12,35 @@ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(
15
  logger = logging.getLogger(__name__)
16
 
17
  class AphasiaTreatmentPredictor:
18
- def __init__(self, n_estimators=100, max_depth=None, random_state=42):
19
  """
20
- Initialize the Treatment Predictor with Random Forest Regressor
21
 
22
  Args:
 
23
  n_estimators (int): Number of trees in the forest
24
  max_depth (int): Maximum depth of trees (None for unlimited)
25
  random_state (int): Random seed for reproducibility
26
  """
27
- self.prediction_type = "regression"
28
  self.n_estimators = n_estimators
29
  self.max_depth = max_depth
30
  self.random_state = random_state
31
  self.feature_importance = None
32
  self.feature_names = None
33
 
34
- self.model = RandomForestRegressor(
35
- n_estimators=n_estimators,
36
- max_depth=max_depth,
37
- random_state=random_state
38
- )
 
 
 
 
 
 
 
39
 
40
  def prepare_features(self, latents, demographics):
41
  """
@@ -49,54 +54,10 @@ class AphasiaTreatmentPredictor:
49
  tuple: Combined features array and feature names
50
  """
51
  if isinstance(demographics, dict):
52
- # For dictionary input, ensure all arrays are same length as latents
53
- n_samples = latents.shape[0]
54
- aligned_demos = {}
55
-
56
- for key, values in demographics.items():
57
- if len(values) != n_samples:
58
- print(f"WARNING: Demographics '{key}' length ({len(values)}) doesn't match latents ({n_samples})")
59
- # Truncate or pad to match latent samples
60
- if len(values) > n_samples:
61
- aligned_demos[key] = values[:n_samples] # Truncate
62
- print(f" Truncated '{key}' to {n_samples} samples")
63
- else:
64
- # Pad with repeated values or zeros depending on type
65
- if len(values) > 0:
66
- # Use mean for numerical, mode for categorical
67
- if isinstance(values[0], (int, float, np.number)):
68
- filler = np.mean(values)
69
- else:
70
- # Use most common value
71
- from collections import Counter
72
- filler = Counter(values).most_common(1)[0][0]
73
-
74
- padding = [filler] * (n_samples - len(values))
75
- aligned_demos[key] = list(values) + padding
76
- print(f" Padded '{key}' with {filler} to {n_samples} samples")
77
- else:
78
- # Empty array, fill with zeros
79
- aligned_demos[key] = [0] * n_samples
80
- print(f" Filled empty '{key}' with zeros to {n_samples} samples")
81
- else:
82
- aligned_demos[key] = values
83
-
84
- demo_df = pd.DataFrame(aligned_demos)
85
  else:
86
  demo_df = demographics.copy()
87
 
88
- # Ensure DataFrame has same number of rows as latents
89
- if len(demo_df) != latents.shape[0]:
90
- print(f"WARNING: Demographics DataFrame size ({len(demo_df)}) doesn't match latents ({latents.shape[0]})")
91
- if len(demo_df) > latents.shape[0]:
92
- demo_df = demo_df.iloc[:latents.shape[0]] # Truncate
93
- print(f" Truncated demographics to {latents.shape[0]} samples")
94
- else:
95
- # Cannot easily pad DataFrame, use last row or means
96
- print(f" ERROR: Cannot pad demographics DataFrame - using latents only")
97
- # Create a DataFrame with the same columns but zeros
98
- demo_df = pd.DataFrame(0, index=range(latents.shape[0]), columns=demo_df.columns)
99
-
100
  # Get categorical columns
101
  cat_columns = demo_df.select_dtypes(include=['object']).columns.tolist()
102
 
@@ -110,16 +71,7 @@ class AphasiaTreatmentPredictor:
110
  feature_names = latent_names + demo_names
111
 
112
  # Combine latents with demographics
113
- try:
114
- features = np.hstack([latents, demo_df.values])
115
- except ValueError as e:
116
- print(f"ERROR combining features: {e}")
117
- print(f"Latents shape: {latents.shape}, Demographics shape: {demo_df.values.shape}")
118
- # Fall back to using just latents
119
- print("Falling back to using only latent features")
120
- features = latents
121
- feature_names = latent_names
122
-
123
  return features, feature_names
124
 
125
  def fit(self, latents, demographics, treatment_outcomes):
@@ -138,11 +90,6 @@ class AphasiaTreatmentPredictor:
138
  self.feature_names = feature_names
139
 
140
  logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
141
- print(f"Random Forest: Building {self.n_estimators} trees...")
142
-
143
- # Track progress during fit with verbose
144
- # Set verbose to 2 for detailed per-tree progress
145
- self.model.verbose = 1
146
  self.model.fit(X, treatment_outcomes)
147
 
148
  # Calculate feature importance
@@ -151,7 +98,6 @@ class AphasiaTreatmentPredictor:
151
  'importance': self.model.feature_importances_
152
  }).sort_values('importance', ascending=False)
153
 
154
- print(f"Random Forest: Training complete. Top features: {', '.join(self.feature_importance['feature'].head(3).tolist())}")
155
  return self
156
 
157
  def predict(self, latents, demographics):
@@ -169,11 +115,34 @@ class AphasiaTreatmentPredictor:
169
  predictions = self.model.predict(X)
170
 
171
  # Get prediction intervals using tree variance
172
- tree_predictions = np.array([tree.predict(X)
173
- for tree in self.model.estimators_])
174
- prediction_std = np.std(tree_predictions, axis=0)
 
 
 
 
 
 
175
 
176
  return predictions, prediction_std
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
179
  """
@@ -191,46 +160,32 @@ class AphasiaTreatmentPredictor:
191
  X, feature_names = self.prepare_features(latents, demographics)
192
  self.feature_names = feature_names
193
 
194
- # Adjust n_splits if we have too few samples
195
- sample_count = len(treatment_outcomes)
196
- if sample_count < n_splits * 2: # Need at least 2 samples per fold
197
- adjusted_n_splits = max(2, sample_count // 2) # At least 2 folds, each with multiple samples
198
- logger.warning(f"Too few samples ({sample_count}) for {n_splits} folds. Adjusting to {adjusted_n_splits} folds.")
199
- n_splits = adjusted_n_splits
200
-
201
- logger.info(f"Running {n_splits}-fold cross-validation on {sample_count} samples")
202
- print(f"Random Forest: Starting {n_splits}-fold cross-validation with {sample_count} samples")
203
-
204
- # Use stratified KFold for regression to ensure balanced folds
205
- # or LeaveOneOut for very small datasets
206
- if sample_count <= 5:
207
- from sklearn.model_selection import LeaveOneOut
208
- logger.warning(f"Using Leave-One-Out CV for small dataset with {sample_count} samples")
209
- print(f"Random Forest: Using Leave-One-Out cross-validation due to small sample size ({sample_count})")
210
- kf = LeaveOneOut()
211
- cv_iterator = kf.split(X)
212
- else:
213
- kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
214
- cv_iterator = kf.split(X)
215
 
216
  cv_scores = []
217
  predictions = np.zeros_like(treatment_outcomes)
218
  prediction_stds = np.zeros_like(treatment_outcomes)
219
  fold_metrics = []
220
 
221
- for fold, (train_idx, test_idx) in enumerate(cv_iterator):
222
  X_train, X_test = X[train_idx], X[test_idx]
223
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
224
 
225
- print(f"Random Forest: Training fold {fold+1}/{n_splits} - {len(X_train)} training samples, {len(X_test)} test samples")
226
-
227
  # Clone the model for this fold
228
- fold_model = RandomForestRegressor(
229
- n_estimators=self.n_estimators,
230
- max_depth=self.max_depth,
231
- random_state=self.random_state,
232
- verbose=1 # Add verbosity
233
- )
 
 
 
 
 
 
234
 
235
  # Train the model
236
  fold_model.fit(X_train, y_train)
@@ -242,72 +197,50 @@ class AphasiaTreatmentPredictor:
242
  predictions[test_idx] = pred
243
 
244
  # Calculate metrics
245
- rmse = np.sqrt(mean_squared_error(y_test, pred))
246
-
247
- # R-squared requires at least 2 samples and some variance in the target
248
- if len(y_test) >= 2 and np.var(y_test) > 1e-10:
249
  r2 = r2_score(y_test, pred)
250
- else:
251
- r2 = np.nan
252
- logger.warning(f"Fold {fold+1}: R² not calculated (insufficient samples or variance)")
253
- print(f"Random Forest: Fold {fold+1} - R² not calculated (insufficient samples or variance)")
254
-
255
- # MSE can always be calculated
256
- mse = rmse**2
257
-
258
- # Add other useful metrics if there are enough samples
259
- metrics = {
260
- "r2": r2,
261
- "rmse": rmse,
262
- "mse": mse
263
- }
264
-
265
- # Add explained variance if possible
266
- if len(y_test) >= 2 and np.var(y_test) > 1e-10:
267
- from sklearn.metrics import explained_variance_score
268
- try:
269
- ev = explained_variance_score(y_test, pred)
270
- metrics["explained_variance"] = ev
271
- except:
272
- # Skip if it can't be calculated
273
- pass
274
-
275
- # Get prediction intervals using tree variance
276
- tree_predictions = np.array([tree.predict(X_test)
277
- for tree in fold_model.estimators_])
278
- pred_std = np.std(tree_predictions, axis=0)
279
- prediction_stds[test_idx] = pred_std
280
 
281
  fold_metrics.append(metrics)
282
  logger.info(f"Fold {fold+1} metrics: {metrics}")
283
 
284
- # Print a more user-friendly version of the fold results
285
- r2_val = metrics.get('r2', np.nan)
286
- rmse_val = metrics.get('rmse', np.nan)
287
- r2_text = f"R² = {r2_val:.4f}" if not np.isnan(r2_val) else "R² = N/A"
288
- print(f"Random Forest: Fold {fold+1} results - {r2_text}, RMSE = {rmse_val:.4f}")
289
-
290
  # Calculate average metrics
291
  avg_metrics = {}
292
  for key in fold_metrics[0].keys():
293
- # Filter out nan values when calculating means
294
- values = [fold[key] for fold in fold_metrics if key in fold and not (isinstance(fold[key], float) and np.isnan(fold[key]))]
295
- if values: # Only calculate mean if we have valid values
296
- avg_metrics[key] = np.mean(values)
297
- else:
298
- avg_metrics[key] = np.nan
299
 
300
  logger.info(f"Average CV metrics: {avg_metrics}")
301
 
302
- # Print a summary of cross-validation performance
303
- r2_avg = avg_metrics.get('r2', np.nan)
304
- rmse_avg = avg_metrics.get('rmse', np.nan)
305
- r2_text = f"R² = {r2_avg:.4f}" if not np.isnan(r2_avg) else "R² = N/A"
306
- print(f"Random Forest: Cross-validation complete - Average {r2_text}, RMSE = {rmse_avg:.4f}")
307
-
308
  # Train final model on all data
309
- print(f"Random Forest: Training final model on all {len(X)} samples...")
310
- self.model.verbose = 1
311
  self.model.fit(X, treatment_outcomes)
312
 
313
  # Calculate feature importance
@@ -402,6 +335,7 @@ class AphasiaTreatmentPredictor:
402
 
403
  # Create new instance
404
  predictor = cls(
 
405
  n_estimators=data['n_estimators'],
406
  max_depth=data['max_depth'],
407
  random_state=data['random_state']
@@ -416,7 +350,7 @@ class AphasiaTreatmentPredictor:
416
  return predictor
417
 
418
 
419
- def train_predictor_from_latents(latents, outcomes, demographics=None, cv=5, **kwargs):
420
  """
421
  Train a treatment outcome predictor from VAE latent representations
422
 
@@ -424,16 +358,17 @@ def train_predictor_from_latents(latents, outcomes, demographics=None, cv=5, **k
424
  latents (np.ndarray): Latent representations from VAE
425
  outcomes (np.ndarray): Treatment outcome values
426
  demographics (dict or pd.DataFrame, optional): Demographic information to include as features
 
427
  cv (int): Number of folds for cross-validation
428
  **kwargs: Additional parameters for the AphasiaTreatmentPredictor
429
 
430
  Returns:
431
  dict: Training results and trained model
432
  """
433
- logger.info(f"Training regression model for treatment prediction")
434
 
435
  # Create predictor
436
- predictor = AphasiaTreatmentPredictor(**kwargs)
437
 
438
  # Run cross-validation
439
  cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
 
1
  import numpy as np
2
+ from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
3
  from sklearn.model_selection import cross_val_score, KFold
4
  import pandas as pd
5
+ from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score
 
 
 
6
  import matplotlib.pyplot as plt
7
  import os
8
  import joblib
 
12
  logger = logging.getLogger(__name__)
13
 
14
  class AphasiaTreatmentPredictor:
15
+ def __init__(self, prediction_type="regression", n_estimators=100, max_depth=None, random_state=42):
16
  """
17
+ Initialize the Treatment Predictor with Random Forest
18
 
19
  Args:
20
+ prediction_type (str): "classification" or "regression" depending on outcome variable type
21
  n_estimators (int): Number of trees in the forest
22
  max_depth (int): Maximum depth of trees (None for unlimited)
23
  random_state (int): Random seed for reproducibility
24
  """
25
+ self.prediction_type = prediction_type
26
  self.n_estimators = n_estimators
27
  self.max_depth = max_depth
28
  self.random_state = random_state
29
  self.feature_importance = None
30
  self.feature_names = None
31
 
32
+ if prediction_type == "classification":
33
+ self.model = RandomForestClassifier(
34
+ n_estimators=n_estimators,
35
+ max_depth=max_depth,
36
+ random_state=random_state
37
+ )
38
+ else: # regression
39
+ self.model = RandomForestRegressor(
40
+ n_estimators=n_estimators,
41
+ max_depth=max_depth,
42
+ random_state=random_state
43
+ )
44
 
45
  def prepare_features(self, latents, demographics):
46
  """
 
54
  tuple: Combined features array and feature names
55
  """
56
  if isinstance(demographics, dict):
57
+ demo_df = pd.DataFrame(demographics)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  else:
59
  demo_df = demographics.copy()
60
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  # Get categorical columns
62
  cat_columns = demo_df.select_dtypes(include=['object']).columns.tolist()
63
 
 
71
  feature_names = latent_names + demo_names
72
 
73
  # Combine latents with demographics
74
+ features = np.hstack([latents, demo_df.values])
 
 
 
 
 
 
 
 
 
75
  return features, feature_names
76
 
77
  def fit(self, latents, demographics, treatment_outcomes):
 
90
  self.feature_names = feature_names
91
 
92
  logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
 
 
 
 
 
93
  self.model.fit(X, treatment_outcomes)
94
 
95
  # Calculate feature importance
 
98
  'importance': self.model.feature_importances_
99
  }).sort_values('importance', ascending=False)
100
 
 
101
  return self
102
 
103
  def predict(self, latents, demographics):
 
115
  predictions = self.model.predict(X)
116
 
117
  # Get prediction intervals using tree variance
118
+ if self.prediction_type == "regression":
119
+ tree_predictions = np.array([tree.predict(X)
120
+ for tree in self.model.estimators_])
121
+ prediction_std = np.std(tree_predictions, axis=0)
122
+ else: # classification
123
+ # For classification, use probability as a measure of confidence
124
+ proba = self.model.predict_proba(X)
125
+ # Use max probability as confidence measure
126
+ prediction_std = 1 - np.max(proba, axis=1)
127
 
128
  return predictions, prediction_std
129
+
130
+ def predict_proba(self, latents, demographics):
131
+ """
132
+ Get probability estimates for classification
133
+
134
+ Args:
135
+ latents (np.ndarray): Latent representations from VAE
136
+ demographics (dict or pd.DataFrame): Demographic information
137
+
138
+ Returns:
139
+ np.ndarray: Probability estimates for each class
140
+ """
141
+ if self.prediction_type != "classification":
142
+ raise ValueError("Probability prediction only available for classification")
143
+
144
+ X, _ = self.prepare_features(latents, demographics)
145
+ return self.model.predict_proba(X)
146
 
147
  def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
148
  """
 
160
  X, feature_names = self.prepare_features(latents, demographics)
161
  self.feature_names = feature_names
162
 
163
+ logger.info(f"Running {n_splits}-fold cross-validation")
164
+
165
+ kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
166
 
167
  cv_scores = []
168
  predictions = np.zeros_like(treatment_outcomes)
169
  prediction_stds = np.zeros_like(treatment_outcomes)
170
  fold_metrics = []
171
 
172
+ for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
173
  X_train, X_test = X[train_idx], X[test_idx]
174
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
175
 
 
 
176
  # Clone the model for this fold
177
+ if self.prediction_type == "classification":
178
+ fold_model = RandomForestClassifier(
179
+ n_estimators=self.n_estimators,
180
+ max_depth=self.max_depth,
181
+ random_state=self.random_state
182
+ )
183
+ else:
184
+ fold_model = RandomForestRegressor(
185
+ n_estimators=self.n_estimators,
186
+ max_depth=self.max_depth,
187
+ random_state=self.random_state
188
+ )
189
 
190
  # Train the model
191
  fold_model.fit(X_train, y_train)
 
197
  predictions[test_idx] = pred
198
 
199
  # Calculate metrics
200
+ if self.prediction_type == "regression":
201
+ rmse = np.sqrt(mean_squared_error(y_test, pred))
 
 
202
  r2 = r2_score(y_test, pred)
203
+ metrics = {
204
+ "r2": r2,
205
+ "rmse": rmse,
206
+ "mse": rmse**2
207
+ }
208
+
209
+ # Get prediction intervals using tree variance
210
+ tree_predictions = np.array([tree.predict(X_test)
211
+ for tree in fold_model.estimators_])
212
+ pred_std = np.std(tree_predictions, axis=0)
213
+ prediction_stds[test_idx] = pred_std
214
+
215
+ else: # classification
216
+ acc = accuracy_score(y_test, pred)
217
+ prec = precision_score(y_test, pred, average='weighted', zero_division=0)
218
+ rec = recall_score(y_test, pred, average='weighted', zero_division=0)
219
+ f1 = f1_score(y_test, pred, average='weighted', zero_division=0)
220
+ metrics = {
221
+ "accuracy": acc,
222
+ "precision": prec,
223
+ "recall": rec,
224
+ "f1": f1
225
+ }
226
+
227
+ # Use probability as a measure of confidence
228
+ proba = fold_model.predict_proba(X_test)
229
+ # Use max probability as confidence measure
230
+ pred_std = 1 - np.max(proba, axis=1)
231
+ prediction_stds[test_idx] = pred_std
 
232
 
233
  fold_metrics.append(metrics)
234
  logger.info(f"Fold {fold+1} metrics: {metrics}")
235
 
 
 
 
 
 
 
236
  # Calculate average metrics
237
  avg_metrics = {}
238
  for key in fold_metrics[0].keys():
239
+ avg_metrics[key] = np.mean([fold[key] for fold in fold_metrics])
 
 
 
 
 
240
 
241
  logger.info(f"Average CV metrics: {avg_metrics}")
242
 
 
 
 
 
 
 
243
  # Train final model on all data
 
 
244
  self.model.fit(X, treatment_outcomes)
245
 
246
  # Calculate feature importance
 
335
 
336
  # Create new instance
337
  predictor = cls(
338
+ prediction_type=data['prediction_type'],
339
  n_estimators=data['n_estimators'],
340
  max_depth=data['max_depth'],
341
  random_state=data['random_state']
 
350
  return predictor
351
 
352
 
353
+ def train_predictor_from_latents(latents, outcomes, demographics=None, prediction_type="regression", cv=5, **kwargs):
354
  """
355
  Train a treatment outcome predictor from VAE latent representations
356
 
 
358
  latents (np.ndarray): Latent representations from VAE
359
  outcomes (np.ndarray): Treatment outcome values
360
  demographics (dict or pd.DataFrame, optional): Demographic information to include as features
361
+ prediction_type (str): "classification" or "regression"
362
  cv (int): Number of folds for cross-validation
363
  **kwargs: Additional parameters for the AphasiaTreatmentPredictor
364
 
365
  Returns:
366
  dict: Training results and trained model
367
  """
368
+ logger.info(f"Training {prediction_type} model for treatment prediction")
369
 
370
  # Create predictor
371
+ predictor = AphasiaTreatmentPredictor(prediction_type=prediction_type, **kwargs)
372
 
373
  # Run cross-validation
374
  cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
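A self-contained sketch of the predictor on synthetic inputs (every value below is made up). It sticks to numeric demographics so it does not depend on how `prepare_features` encodes string columns; the classification constructor at the end is shown only to illustrate the `prediction_type` switch added in this commit.

```python
import numpy as np

rng = np.random.default_rng(0)
latents = rng.normal(size=(40, 32))              # stand-in for VAE latents
demographics = {
    'age_at_stroke': rng.uniform(40, 80, 40),
    'months_post_stroke': rng.uniform(3, 60, 40),
    'wab_score': rng.uniform(20, 95, 40),
}
outcomes = rng.uniform(0, 20, 40)                # synthetic treatment outcomes

predictor = AphasiaTreatmentPredictor(prediction_type="regression", n_estimators=100)
predictor.fit(latents, demographics, outcomes)
pred, pred_std = predictor.predict(latents, demographics)
print(predictor.feature_importance.head())

# Classification variant for categorical outcomes (e.g., responder / non-responder)
clf = AphasiaTreatmentPredictor(prediction_type="classification")
```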
requirements.txt CHANGED
@@ -9,4 +9,6 @@ gradio>=2.0.0
 datasets>=1.11.0
 huggingface_hub>=0.15.0
 transformers>=4.15.0
+seaborn>=0.11.2
+joblib>=1.0.1
 
utils.py CHANGED
@@ -11,15 +11,6 @@ def to_cuda(x, use_cuda):
11
  def to_numpy(x):
12
  return x.detach().cpu().numpy()
13
 
14
- def fc_matrix_from_triu(triu_values, n_rois=264):
15
- fc_matrix = np.zeros((n_rois, n_rois))
16
- triu_indices = np.triu_indices(n_rois, k=1)
17
- triu_values = np.tanh(triu_values)
18
- fc_matrix[triu_indices] = triu_values
19
- fc_matrix = fc_matrix + fc_matrix.T
20
- np.fill_diagonal(fc_matrix, 1)
21
- return fc_matrix
22
-
23
  def rmse(a, b, mean=torch.mean):
24
  return mean((a-b)**2)**0.5
25
 
@@ -47,6 +38,7 @@ def decor_loss(z, demo, use_cuda=True):
47
  ps.append(p)
48
  losses = torch.stack(losses)
49
  return losses, ps
 
50
  def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
51
  demo_t = []
52
  demo_idx = 0
@@ -70,10 +62,13 @@ def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
70
  def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
71
  loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
72
  loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
 
73
  # Get linear predictors for demographics
74
  pred_w = []
75
  pred_i = []
76
  pred_stats = []
 
 
77
 
78
  for i, d, t in zip(range(len(demo)), demo, demo_types):
79
  print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
@@ -115,6 +110,9 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
115
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
116
 
117
  for e in range(nepochs):
 
 
 
118
  for bs in range(0, len(x), bsize):
119
  xb = x[bs:(bs+bsize)]
120
  db = demo_t[bs:(bs+bsize)]
@@ -128,59 +126,29 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
128
  loss_decor = sum(loss_decor)
129
  loss_rec = rmse(xb, y)
130
 
131
- # Sample demographics
132
- demo_gen = []
133
- for s, t in zip(pred_stats, demo_types):
134
- if t == 'continuous':
135
- mu, std = s
136
- dd = torch.randn(100).float()
137
- dd = dd*std+mu
138
- dd = to_cuda(dd, vae.use_cuda)
139
- demo_gen.append(dd)
140
- elif t == 'categorical':
141
- idx = np.random.randint(0, len(s))
142
- for i in range(len(s)):
143
- dd = torch.ones(100).float() if idx == i else torch.zeros(100).float()
144
- dd = to_cuda(dd, vae.use_cuda)
145
- demo_gen.append(dd)
146
-
147
- demo_gen = torch.stack(demo_gen).permute(1,0)
148
-
149
- # Generate
150
- z = vae.gen(100)
151
- y = vae.dec(z, demo_gen)
152
-
153
- # Regressor/classifier guidance loss
154
- losses_pred = []
155
- idcs = []
156
- dg_idx = 0
157
-
158
- for s, t in zip(pred_stats, demo_types):
159
- if t == 'continuous':
160
- yy = y@pred_w[dg_idx]+pred_i[dg_idx]
161
- loss = rmse(demo_gen[:,dg_idx], yy)
162
- losses_pred.append(loss)
163
- idcs.append(float(demo_gen[0,dg_idx]))
164
- dg_idx += 1
165
- elif t == 'categorical':
166
- loss = 0
167
- for i in range(len(s)):
168
- yy = y@pred_w[dg_idx]+pred_i[dg_idx]
169
- loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
170
- idcs.append(int(demo_gen[0,dg_idx]))
171
- dg_idx += 1
172
- losses_pred.append(loss)
173
-
174
  total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
175
- loss_rec_mult*loss_rec + loss_decor_mult*loss_decor +
176
- loss_pred_mult*sum(losses_pred))
177
 
178
  total_loss.backward()
179
  optim.step()
180
 
181
- if e%pperiod == 0 or e == nepochs-1:
182
- print(f'Epoch {e} ReconLoss {loss_rec:.4f} CovarianceLoss {loss_C:.4f} '
183
- f'MeanLoss {loss_mu:.4f} DecorLoss {loss_decor:.4f}')
184
- print(f'GuidanceTargets {idcs}')
185
- print(f'GuidanceLosses {[f"{loss:.4f}" for loss in losses_pred]}')
186
-
 
 
 
 
 
 
11
  def to_numpy(x):
12
  return x.detach().cpu().numpy()
13
 
 
 
 
 
 
 
 
 
 
14
  def rmse(a, b, mean=torch.mean):
15
  return mean((a-b)**2)**0.5
16
 
 
38
  ps.append(p)
39
  losses = torch.stack(losses)
40
  return losses, ps
41
+
42
  def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
43
  demo_t = []
44
  demo_idx = 0
 
62
  def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
63
  loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
64
  loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
65
+
66
  # Get linear predictors for demographics
67
  pred_w = []
68
  pred_i = []
69
  pred_stats = []
70
+ train_losses = []
71
+ val_losses = []
72
 
73
  for i, d, t in zip(range(len(demo)), demo, demo_types):
74
  print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
 
110
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
111
 
112
  for e in range(nepochs):
113
+ epoch_losses = []
114
+ vae.train()
115
+
116
  for bs in range(0, len(x), bsize):
117
  xb = x[bs:(bs+bsize)]
118
  db = demo_t[bs:(bs+bsize)]
 
126
  loss_decor = sum(loss_decor)
127
  loss_rec = rmse(xb, y)
128
 
129
+ # Calculate total loss
 
 
 
 
130
  total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
131
+ loss_rec_mult*loss_rec + loss_decor_mult*loss_decor)
 
132
 
133
  total_loss.backward()
134
  optim.step()
135
 
136
+ epoch_losses.append(total_loss.item())
137
+
138
+ # Record training loss
139
+ train_losses.append(np.mean(epoch_losses))
140
+
141
+ # Validation step
142
+ if e % pperiod == 0:
143
+ vae.eval()
144
+ with torch.no_grad():
145
+ z = vae.enc(x)
146
+ y = vae.dec(z, demo_t)
147
+ val_loss = rmse(x, y).item()
148
+ val_losses.append(val_loss)
149
+
150
+ print(f'Epoch {e}/{nepochs} - '
151
+ f'Train Loss: {train_losses[-1]:.4f} - '
152
+ f'Val Loss: {val_loss:.4f}')
153
+
154
+ return train_losses, val_losses
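Two things to keep in mind about the loss tracking added here: `val_losses` only gets an entry when `e % pperiod == 0`, so it is much shorter than `train_losses`, and the "validation" reconstruction is computed on the same matrix `x` the model trains on rather than a held-out split. `plot_learning_curves` in `visualization.py` plots both lists against their indices, which horizontally compresses the validation curve; below is a small sketch of plotting them on a shared epoch axis.

```python
import matplotlib.pyplot as plt

def plot_losses_aligned(train_losses, val_losses, pperiod):
    """Per-epoch training loss with validation loss placed at the epochs it was recorded."""
    fig = plt.figure(figsize=(10, 6))
    plt.plot(range(len(train_losses)), train_losses, label='Training Loss')
    val_epochs = [e * pperiod for e in range(len(val_losses))]
    plt.plot(val_epochs, val_losses, 'o-', label='Validation Loss')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True)
    return fig
```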
vae_model.py CHANGED
@@ -129,22 +129,3 @@ class DemoVAE(BaseEstimator):
         self.demo_dim = checkpoint['demo_dim']
         self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
         self.vae.load_state_dict(checkpoint['model_state_dict'])
-
-def train_fc_vae(X, demo_data, demo_types, model_config):
-    n_rois = 264
-    input_dim = (n_rois * (n_rois - 1)) // 2
-
-    vae = DemoVAE(
-        latent_dim=model_config['latent_dim'],
-        nepochs=model_config['nepochs'],
-        bsize=model_config['bsize'],
-        loss_rec_mult=model_config['loss_rec_mult'],
-        loss_decor_mult=model_config['loss_decor_mult'],
-        lr=model_config['lr'],
-        use_cuda=torch.cuda.is_available()
-    )
-
-    vae.fit(X, demo_data, demo_types)
-
-    return vae, X, demo_data, demo_types
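With the `train_fc_vae` wrapper removed, callers now construct the estimator directly, as `main.py` does. A minimal sketch, assuming `X`, `demo_data` and `demo_types` come from `load_and_preprocess_data` and that `MODEL_CONFIG` carries the constructor keyword arguments used above:

```python
import torch
from config import MODEL_CONFIG
from vae_model import DemoVAE

vae = DemoVAE(**MODEL_CONFIG)                    # latent_dim, nepochs, bsize, ...
train_losses, val_losses = vae.fit(X, demo_data, demo_types)

latents = vae.get_latents(X)                     # per-subject latent codes
vae.save('models/vae_model.pt')                  # checkpoint consumed by the loader above
```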
 
 
 
visualization.py CHANGED
@@ -1,44 +1,66 @@
1
  import matplotlib.pyplot as plt
2
  import numpy as np
3
- from utils import fc_matrix_from_triu
4
 
5
- def visualize_fc_analysis(original_triu, reconstructed_triu, generated_triu, analysis_results=None):
6
- fig = plt.figure(figsize=(15, 10))
7
- gs = plt.GridSpec(2, 3)
8
-
9
- ax1 = fig.add_subplot(gs[0, 0])
10
- ax2 = fig.add_subplot(gs[0, 1])
11
- ax3 = fig.add_subplot(gs[0, 2])
12
-
13
- original = fc_matrix_from_triu(original_triu)
14
- reconstructed = fc_matrix_from_triu(reconstructed_triu)
15
- generated = fc_matrix_from_triu(generated_triu)
16
-
17
- im1 = ax1.imshow(original, cmap='RdBu_r', vmin=-1, vmax=1)
18
- ax1.set_title('Original FC')
19
-
20
- im2 = ax2.imshow(reconstructed, cmap='RdBu_r', vmin=-1, vmax=1)
21
- ax2.set_title('Reconstructed FC')
22
-
23
- im3 = ax3.imshow(generated, cmap='RdBu_r', vmin=-1, vmax=1)
24
- ax3.set_title('Generated FC')
25
-
26
- plt.colorbar(im1, ax=ax1)
27
- plt.colorbar(im2, ax=ax2)
28
- plt.colorbar(im3, ax=ax3)
29
-
30
- if analysis_results is not None:
31
- ax4 = fig.add_subplot(gs[1, :])
32
- for demo_name, results in analysis_results.items():
33
- significant_dims = np.where(np.array(results['p_values']) < 0.05)[0]
34
- correlations = np.array(results['correlations'])
35
- ax4.plot(correlations, label=f'{demo_name} (sig. dims: {len(significant_dims)})')
36
-
37
- ax4.set_xlabel('Latent Dimension')
38
- ax4.set_ylabel('Correlation Strength')
39
- ax4.set_title('Demographic Correlations with Latent Dimensions')
40
- ax4.legend()
41
 
42
  plt.tight_layout()
43
  return fig
44
 
 
 
 
 
 
 
1
  import matplotlib.pyplot as plt
2
  import numpy as np
 
3
 
4
+ def plot_fc_matrices(original, reconstructed, generated):
5
+ """Plot FC matrices comparison"""
6
+ fig, axes = plt.subplots(1, 3, figsize=(15, 5))
7
+
8
+ vmin, vmax = -1, 1
9
+
10
+ im1 = axes[0].imshow(original, cmap='RdBu_r', vmin=vmin, vmax=vmax)
11
+ axes[0].set_title('Original FC')
12
+
13
+ im2 = axes[1].imshow(reconstructed, cmap='RdBu_r', vmin=vmin, vmax=vmax)
14
+ axes[1].set_title('Reconstructed FC')
15
+
16
+ im3 = axes[2].imshow(generated, cmap='RdBu_r', vmin=vmin, vmax=vmax)
17
+ axes[2].set_title('Generated FC')
18
+
19
+ for ax, im in zip(axes, [im1, im2, im3]):
20
+ plt.colorbar(im, ax=ax)
 
 
 
 
21
 
22
  plt.tight_layout()
23
  return fig
24
 
25
+ def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke, prediction_std=None):
26
+ """Plot predicted treatment trajectory"""
27
+ fig = plt.figure(figsize=(10, 6))
28
+
29
+ # Plot current and predicted points
30
+ plt.scatter([0], [current_score], label='Current Status', color='blue', s=100)
31
+ plt.scatter([months_post_stroke], [predicted_score],
32
+ label='Predicted Outcome', color='red', s=100)
33
+
34
+ # Plot trajectory
35
+ plt.plot([0, months_post_stroke], [current_score, predicted_score],
36
+ 'g--', label='Predicted Trajectory')
37
+
38
+ # Add prediction interval if available
39
+ if prediction_std is not None:
40
+ plt.fill_between([months_post_stroke],
41
+ [predicted_score - 2*prediction_std],
42
+ [predicted_score + 2*prediction_std],
43
+ color='red', alpha=0.2,
44
+ label='95% Prediction Interval')
45
+
46
+ plt.xlabel('Months Post Treatment')
47
+ plt.ylabel('WAB Score')
48
+ plt.title('Predicted Treatment Trajectory')
49
+ plt.legend()
50
+ plt.grid(True)
51
+
52
+ return fig
53
+
54
+ def plot_learning_curves(train_losses, val_losses):
55
+ """Plot VAE learning curves"""
56
+ fig = plt.figure(figsize=(10, 6))
57
+
58
+ plt.plot(train_losses, label='Training Loss')
59
+ plt.plot(val_losses, label='Validation Loss')
60
+ plt.xlabel('Epoch')
61
+ plt.ylabel('Loss')
62
+ plt.title('VAE Learning Curves')
63
+ plt.legend()
64
+ plt.grid(True)
65
+
66
+ return fig
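A quick smoke test of the two plotting helpers with synthetic values (all numbers below are made up). One caveat: `plot_treatment_trajectory` shades its prediction interval with `fill_between` over a single x position, which produces a zero-width band, so the interval may not be visible; an errorbar at the predicted point would be one alternative.

```python
import numpy as np

rng = np.random.default_rng(0)

def random_fc(n=264):
    """Symmetric matrix in [-1, 1] with unit diagonal, standing in for a real FC matrix."""
    a = rng.uniform(-1, 1, size=(n, n))
    fc = (a + a.T) / 2
    np.fill_diagonal(fc, 1.0)
    return fc

fc_fig = plot_fc_matrices(random_fc(), random_fc(), random_fc())
traj_fig = plot_treatment_trajectory(current_score=55.0, predicted_score=68.0,
                                     months_post_stroke=6, prediction_std=3.0)
fc_fig.savefig('fc_demo.png')
traj_fig.savefig('trajectory_demo.png')
```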