SreekarB commited on
Commit
67303f6
·
verified ·
1 Parent(s): f0cfbe9

Upload 14 files

Browse files
Files changed (10) hide show
  1. app.py +655 -141
  2. config.py +9 -0
  3. data_preprocessing.py +78 -578
  4. main.py +126 -271
  5. rcf_prediction.py +383 -0
  6. requirements.txt +2 -0
  7. src/.DS_Store +0 -0
  8. utils.py +28 -60
  9. vae_model.py +0 -19
  10. visualization.py +59 -37
app.py CHANGED
@@ -1,10 +1,518 @@
1
  import gradio as gr
2
- from main import run_fc_analysis
3
- import os
4
  import numpy as np
5
- from sklearn.metrics import mean_squared_error, r2_score
 
 
 
 
6
  import json
7
  import pickle
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  def calculate_fc_accuracy(original_fc, reconstructed_fc):
10
  """
@@ -68,163 +576,169 @@ def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
68
 
69
  return os.path.join('results', file_path)
70
 
71
- def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
72
- """Run the full VAE analysis pipeline with accuracy metrics"""
73
- # Run the original analysis
74
- fig, results = run_fc_analysis(
75
- data_dir=data_source,
76
- demographic_file=None, # We're now getting demographics directly from the dataset
77
- latent_dim=latent_dim,
78
- nepochs=nepochs,
79
- bsize=bsize,
80
- save_model=True,
81
- use_hf_dataset=use_hf_dataset,
82
- return_data=True # New parameter to return data, will need to update main.py
83
- )
84
-
85
- if results:
86
- vae = results.get('vae')
87
- X = results.get('X')
88
- latents = results.get('latents')
89
- demographics = results.get('demographics')
90
- reconstructed_fc = results.get('reconstructed_fc')
91
- generated_fc = results.get('generated_fc')
92
-
93
- # Calculate accuracy metrics
94
- accuracy_metrics = {}
95
- if X is not None and reconstructed_fc is not None:
96
- for i in range(min(5, len(X))): # Calculate for up to 5 samples
97
- metrics = calculate_fc_accuracy(X[i], reconstructed_fc[i])
98
- accuracy_metrics[f"Subject_{i+1}"] = metrics
99
-
100
- # Average metrics across subjects
101
- avg_metrics = {}
102
- for metric in ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]:
103
- avg_metrics[metric] = np.mean([subject_metrics[metric]
104
- for subject_metrics in accuracy_metrics.values()])
105
- accuracy_metrics["Average"] = avg_metrics
106
-
107
- # Save latent representations if available
108
- if latents is not None and demographics is not None:
109
- latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
110
- print(f"Saved latents to {latents_path}")
111
-
112
- # Prepare status message with accuracy metrics
113
- if accuracy_metrics:
114
- avg = accuracy_metrics["Average"]
115
- status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
116
- f"Reconstruction Accuracy Metrics (Average):\n"
117
- f"• MSE: {avg['MSE']:.6f}\n"
118
- f"• RMSE: {avg['RMSE']:.6f}\n"
119
- f"• R²: {avg['R²']:.6f}\n"
120
- f"• Correlation: {avg['Correlation']:.6f}\n"
121
- f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n"
122
- f"Latent representations saved to results/latents_dim{latent_dim}.pkl")
123
- else:
124
- status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."
125
- else:
126
- status = "Analysis complete, but no results were returned for accuracy calculation."
127
-
128
- return fig, status
129
 
130
  def create_interface():
131
- with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
132
- gr.Markdown("""
133
- # Aphasia fMRI to FC Analysis using VAE
134
-
135
- This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
136
-
137
- ## Dataset Information
138
- By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
139
- - ID: Subject identifier
140
- - wab_aq: Aphasia severity score
141
- - age: Age of the subject
142
- - mpo: Months post onset
143
- - education: Years of education
144
- - gender: Subject gender
145
- - handedness: Subject handedness (ignored in the analysis)
146
- """)
147
 
148
- with gr.Row():
149
- with gr.Column(scale=1):
150
- # Configuration parameters
151
- data_source = gr.Textbox(
152
- label="Data Source (HF Dataset ID or Local Directory)",
153
- value="SreekarB/OSFData"
154
- )
155
- latent_dim = gr.Slider(
156
- minimum=8, maximum=64, step=8,
157
- label="Latent Dimensions", value=32
158
- )
159
- nepochs = gr.Slider(
160
- minimum=100, maximum=5000, step=100,
161
- label="Number of Epochs", value=200 # Reduced for faster demos
162
- )
163
- bsize = gr.Slider(
164
- minimum=8, maximum=64, step=8,
165
- label="Batch Size", value=16
166
- )
167
- use_hf_dataset = gr.Checkbox(
168
- label="Use HuggingFace Dataset", value=True
169
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
170
 
171
- # Training button
172
- train_button = gr.Button("Start Training", variant="primary")
173
- status_text = gr.Textbox(label="Status", value="Ready to start training")
 
 
174
 
175
- with gr.Column(scale=2):
176
- # Output plot
177
- output_plot = gr.Plot(label="Analysis Results")
178
- accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")
179
-
180
- # Link the training button to the analysis function
181
- train_button.click(
182
- fn=gradio_fc_analysis,
183
- inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
184
- outputs=[output_plot, status_text]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  )
186
 
187
- # Custom function to update the accuracy box
188
- def update_accuracy_display(status_text):
189
- if "Accuracy Metrics" in status_text:
190
- # Extract the accuracy metrics section
191
- parts = status_text.split("Reconstruction Accuracy Metrics (Average):")
192
- if len(parts) > 1:
193
- metrics_text = parts[1].split("\n\n")[0]
194
- return f"### Reconstruction Accuracy Metrics\n{metrics_text}"
195
- return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."
196
-
197
- # Update accuracy box when status changes
198
- status_text.change(
199
- fn=update_accuracy_display,
200
- inputs=[status_text],
201
- outputs=[accuracy_box]
202
  )
203
 
204
  # Add examples
205
  gr.Examples(
206
  examples=[
207
- ["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
 
208
  ],
209
- inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
 
210
  )
211
 
212
- # Add explanation of the workflow
213
  gr.Markdown("""
214
- ## How this works
 
 
 
 
 
 
 
 
 
 
 
 
 
 
215
 
216
- 1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
217
- 2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
218
- 3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
219
- 4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
220
- 5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
221
 
222
- Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
223
  """)
224
 
225
- return iface
226
 
227
  if __name__ == "__main__":
228
- iface = create_interface()
229
- iface.launch(share=True)
230
-
 
1
  import gradio as gr
2
+ from main import run_analysis
3
+ from rcf_prediction import AphasiaTreatmentPredictor
4
  import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
7
+ from visualization import plot_fc_matrices, plot_learning_curves
8
+ import os
9
+ from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score
10
  import json
11
  import pickle
12
+ import pandas as pd
13
+ import seaborn as sns
14
+ import logging
15
+ from config import MODEL_CONFIG, PREDICTION_CONFIG
16
+
17
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
+ logger = logging.getLogger(__name__)
19
+
20
+ class AphasiaPredictionApp:
21
+ def __init__(self):
22
+ self.vae = None
23
+ self.predictor = None
24
+ self.trained = False
25
+ self.latent_dim = MODEL_CONFIG['latent_dim']
26
+
27
+ def train_models(self, data_dir, latent_dim, nepochs, bsize):
28
+ """
29
+ Train VAE and Random Forest models
30
+ """
31
+ # Train VAE and Random Forest
32
+ logger.info(f"Training models with data from {data_dir}")
33
+ logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
34
+
35
+ # Default prediction parameters from our config
36
+ prediction_type = PREDICTION_CONFIG.get('prediction_type', 'regression')
37
+ outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq')
38
+ logger.info(f"Prediction: type={prediction_type}, outcome={outcome_variable}")
39
+
40
+ figures = {}
41
+
42
+ try:
43
+ # Run the full analysis pipeline
44
+ results = run_analysis(
45
+ data_dir=data_dir,
46
+ demographic_file="demographics.csv",
47
+ treatment_file="treatment_outcomes.csv",
48
+ latent_dim=latent_dim,
49
+ nepochs=nepochs,
50
+ bsize=bsize,
51
+ save_model=True
52
+ )
53
+
54
+ # Get the VAE figure from results
55
+ vae_fig = results.get('figures', {}).get('vae')
56
+
57
+ figures['vae'] = vae_fig
58
+
59
+ if results:
60
+ self.vae = results.get('vae')
61
+ self.predictor = results.get('predictor')
62
+ latents = results.get('latents')
63
+ demographics = results.get('demographics')
64
+ predictor_cv_results = results.get('predictor_cv_results')
65
+
66
+ # Store the latent dimension
67
+ self.latent_dim = latent_dim
68
+
69
+ # Mark models as trained
70
+ self.trained = True
71
+
72
+ # Prepare prediction visualization if available
73
+ if self.predictor and predictor_cv_results:
74
+ # Get the outcome variable data
75
+ if outcome_variable == 'wab_aq':
76
+ outcomes = demographics['wab_aq']
77
+ elif outcome_variable == 'age':
78
+ outcomes = demographics['age']
79
+ elif outcome_variable == 'months_post_onset':
80
+ outcomes = demographics['months_post_onset']
81
+ else:
82
+ # Try to find the outcome in demographics data
83
+ outcomes = None
84
+ for key in demographics:
85
+ if outcome_variable.lower() in key.lower():
86
+ outcomes = demographics[key]
87
+ break
88
+
89
+ # Create plots
90
+ if 'prediction_stds' in predictor_cv_results and 'predictions' in predictor_cv_results:
91
+ # Create prediction plots
92
+ prediction_fig = self.create_prediction_plots(
93
+ latents,
94
+ demographics,
95
+ outcomes,
96
+ predictor_cv_results['predictions'],
97
+ predictor_cv_results['prediction_stds']
98
+ )
99
+ figures['prediction'] = prediction_fig
100
+
101
+ # Create feature importance plot if available
102
+ try:
103
+ feature_importance = self.predictor.get_feature_importance()
104
+ if feature_importance is not None:
105
+ importance_fig = self.create_importance_plot(feature_importance)
106
+ figures['importance'] = importance_fig
107
+ except Exception as e:
108
+ logger.warning(f"Could not create feature importance plot: {e}")
109
+
110
+ logger.info("Training completed successfully")
111
+
112
+ # Create learning curve plots if available
113
+ if 'fold_metrics' in predictor_cv_results:
114
+ learning_fig = self.create_learning_curve_plot(
115
+ predictor_cv_results['fold_metrics']
116
+ )
117
+ figures['learning'] = learning_fig
118
+
119
+ except Exception as e:
120
+ logger.error(f"Error in training: {str(e)}")
121
+ error_fig = plt.figure(figsize=(10, 6))
122
+ plt.text(0.5, 0.5, f"Error: {str(e)}",
123
+ horizontalalignment='center', verticalalignment='center',
124
+ fontsize=12, color='red')
125
+ plt.axis('off')
126
+ figures['error'] = error_fig
127
+
128
+ return figures
129
+
130
+ def predict_treatment(self, fmri_file=None, age=50, sex="M",
131
+ months_post_stroke=12, wab_score=50, fc_matrix=None):
132
+ """
133
+ Predict treatment outcome for a patient
134
+
135
+ Args:
136
+ fmri_file: Path to patient's fMRI file
137
+ age: Patient's age at stroke
138
+ sex: Patient's sex (M/F)
139
+ months_post_stroke: Months since stroke
140
+ wab_score: Current WAB score
141
+ fc_matrix: Pre-processed FC matrix (if fMRI file not provided)
142
+
143
+ Returns:
144
+ Prediction results and visualization
145
+ """
146
+ if not self.trained:
147
+ return "Please train the models first!", None
148
+
149
+ try:
150
+ # Process fMRI to FC matrix if provided
151
+ if fmri_file and not fc_matrix:
152
+ logger.info(f"Processing fMRI file: {fmri_file}")
153
+ # Use the single fMRI processing function
154
+ fc_matrix = process_single_fmri(fmri_file)
155
+
156
+ if fc_matrix is None:
157
+ return "Please provide either an fMRI file or an FC matrix", None
158
+
159
+ # Ensure FC matrix is properly shaped
160
+ if isinstance(fc_matrix, list):
161
+ fc_matrix = np.array(fc_matrix)
162
+
163
+ # Get latent representation
164
+ logger.info("Extracting latent representation from FC matrix")
165
+ if len(fc_matrix.shape) == 2: # If matrix is 2D (e.g., 264x264)
166
+ # Convert to flattened upper triangular form
167
+ n = fc_matrix.shape[0]
168
+ indices = np.triu_indices(n, k=1)
169
+ fc_flattened = fc_matrix[indices]
170
+ fc_flattened = fc_flattened.reshape(1, -1)
171
+ latent = self.vae.get_latents(fc_flattened)
172
+ else:
173
+ # Assume already flattened
174
+ latent = self.vae.get_latents(fc_matrix.reshape(1, -1))
175
+
176
+ # Prepare demographics
177
+ demographics = {
178
+ 'age': np.array([float(age)]),
179
+ 'gender': np.array([sex]),
180
+ 'months_post_onset': np.array([float(months_post_stroke)]),
181
+ 'wab_aq': np.array([float(wab_score)])
182
+ }
183
+
184
+ logger.info("Making prediction")
185
+ # Make prediction
186
+ if self.predictor is None:
187
+ return "Predictor model not trained", None
188
+
189
+ # Make prediction using the model's predict method
190
+ prediction, prediction_std = self.predictor.predict(latent, demographics)
191
+
192
+ # Create visualization
193
+ fig = self.plot_treatment_trajectory(
194
+ current_score=wab_score,
195
+ predicted_score=prediction[0],
196
+ months_post_stroke=months_post_stroke,
197
+ prediction_std=prediction_std[0]
198
+ )
199
+
200
+ result_text = f"Predicted treatment outcome: {prediction[0]:.2f} ± {2*prediction_std[0]:.2f}"
201
+ logger.info(result_text)
202
+
203
+ return result_text, fig
204
+
205
+ except Exception as e:
206
+ error_msg = f"Error in prediction: {str(e)}"
207
+ logger.error(error_msg)
208
+ error_fig = plt.figure(figsize=(10, 6))
209
+ plt.text(0.5, 0.5, error_msg,
210
+ horizontalalignment='center', verticalalignment='center',
211
+ fontsize=12, color='red')
212
+ plt.axis('off')
213
+ return error_msg, error_fig
214
+
215
+ def plot_treatment_trajectory(self, current_score, predicted_score,
216
+ months_post_stroke, prediction_std,
217
+ treatment_duration=6):
218
+ """
219
+ Create a visualization of predicted treatment trajectory
220
+
221
+ Args:
222
+ current_score: Current WAB score
223
+ predicted_score: Predicted WAB score after treatment
224
+ months_post_stroke: Current months post stroke
225
+ prediction_std: Standard deviation of prediction
226
+ treatment_duration: Duration of treatment in months
227
+
228
+ Returns:
229
+ matplotlib figure
230
+ """
231
+ fig = plt.figure(figsize=(10, 6))
232
+
233
+ # X-axis: months
234
+ x = np.array([months_post_stroke, months_post_stroke + treatment_duration])
235
+
236
+ # Y-axis: WAB scores
237
+ y = np.array([current_score, predicted_score])
238
+
239
+ # Plot the trajectory
240
+ plt.plot(x, y, 'bo-', linewidth=2, label='Predicted Trajectory')
241
+
242
+ # Add confidence interval
243
+ plt.fill_between(
244
+ x,
245
+ [y[0], y[1] - 2*prediction_std],
246
+ [y[0], y[1] + 2*prediction_std],
247
+ alpha=0.2, color='blue', label='95% Confidence Interval'
248
+ )
249
+
250
+ # Add reference lines
251
+ if current_score < predicted_score:
252
+ improvement = predicted_score - current_score
253
+ plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5,
254
+ label=f'Current WAB = {current_score:.1f}')
255
+ plt.axhline(y=predicted_score, color='g', linestyle='--', alpha=0.5,
256
+ label=f'Predicted WAB = {predicted_score:.1f} (+{improvement:.1f})')
257
+ else:
258
+ decline = current_score - predicted_score
259
+ plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5,
260
+ label=f'Current WAB = {current_score:.1f}')
261
+ plt.axhline(y=predicted_score, color='orange', linestyle='--', alpha=0.5,
262
+ label=f'Predicted WAB = {predicted_score:.1f} (-{decline:.1f})')
263
+
264
+ # Add labels and title
265
+ plt.xlabel('Months Post Stroke')
266
+ plt.ylabel('WAB Score')
267
+ plt.title('Predicted Treatment Trajectory')
268
+ plt.legend(loc='best')
269
+
270
+ # Set y-axis limits
271
+ plt.ylim([0, 100])
272
+
273
+ plt.tight_layout()
274
+ return fig
275
+
276
+ def create_prediction_plots(self, latents, demographics, y_true, y_pred, y_std):
277
+ """Create prediction performance plots"""
278
+ fig = plt.figure(figsize=(12, 8))
279
+
280
+ # Create a 2x2 grid for plots
281
+ gs = plt.GridSpec(2, 2, figure=fig)
282
+
283
+ # Plot predicted vs actual values
284
+ ax1 = fig.add_subplot(gs[0, 0])
285
+
286
+ if self.predictor.prediction_type == 'regression':
287
+ # Regression: scatter plot
288
+ ax1.scatter(y_true, y_pred, alpha=0.7)
289
+
290
+ # Add perfect prediction line
291
+ min_val = min(np.min(y_true), np.min(y_pred))
292
+ max_val = max(np.max(y_true), np.max(y_pred))
293
+ ax1.plot([min_val, max_val], [min_val, max_val], 'r--')
294
+
295
+ ax1.set_xlabel('Actual Values')
296
+ ax1.set_ylabel('Predicted Values')
297
+ ax1.set_title('Predicted vs. Actual Values')
298
+
299
+ # Add R² to the plot
300
+ r2 = r2_score(y_true, y_pred)
301
+ ax1.text(0.05, 0.95, f'R² = {r2:.4f}', transform=ax1.transAxes,
302
+ bbox=dict(facecolor='white', alpha=0.5))
303
+
304
+ # Plot residuals
305
+ ax2 = fig.add_subplot(gs[0, 1])
306
+ residuals = y_true - y_pred
307
+ ax2.scatter(y_pred, residuals, alpha=0.7)
308
+ ax2.axhline(y=0, color='r', linestyle='--')
309
+ ax2.set_xlabel('Predicted Values')
310
+ ax2.set_ylabel('Residuals')
311
+ ax2.set_title('Residual Plot')
312
+
313
+ # Plot prediction errors
314
+ ax3 = fig.add_subplot(gs[1, 0])
315
+ ax3.errorbar(range(len(y_pred)), y_pred, yerr=2*y_std, fmt='o', alpha=0.7,
316
+ label='Predicted ± 2σ')
317
+ ax3.plot(range(len(y_true)), y_true, 'rx', alpha=0.7, label='Actual')
318
+ ax3.set_xlabel('Sample Index')
319
+ ax3.set_ylabel('Value')
320
+ ax3.set_title('Prediction with Error Bars')
321
+ ax3.legend()
322
+
323
+ # Plot error distribution
324
+ ax4 = fig.add_subplot(gs[1, 1])
325
+ ax4.hist(residuals, bins=20, alpha=0.7)
326
+ ax4.axvline(x=0, color='r', linestyle='--')
327
+ ax4.set_xlabel('Prediction Error')
328
+ ax4.set_ylabel('Frequency')
329
+ ax4.set_title('Error Distribution')
330
+
331
+ else: # classification
332
+ # Convert to integer classes if they're strings
333
+ if isinstance(y_true[0], str) or isinstance(y_pred[0], str):
334
+ # Create mapping of class labels to integers
335
+ classes = sorted(list(set(list(y_true) + list(y_pred))))
336
+ class_to_int = {c: i for i, c in enumerate(classes)}
337
+
338
+ y_true_int = np.array([class_to_int[c] for c in y_true])
339
+ y_pred_int = np.array([class_to_int[c] for c in y_pred])
340
+ else:
341
+ y_true_int = y_true
342
+ y_pred_int = y_pred
343
+ classes = sorted(list(set(list(y_true_int) + list(y_pred_int))))
344
+
345
+ # Confusion matrix
346
+ from sklearn.metrics import confusion_matrix
347
+ cm = confusion_matrix(y_true_int, y_pred_int)
348
+
349
+ # Plot confusion matrix
350
+ sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes,
351
+ yticklabels=classes, ax=ax1)
352
+ ax1.set_xlabel('Predicted')
353
+ ax1.set_ylabel('True')
354
+ ax1.set_title('Confusion Matrix')
355
+
356
+ # Class distribution
357
+ ax2 = fig.add_subplot(gs[0, 1])
358
+ unique_classes, true_counts = np.unique(y_true_int, return_counts=True)
359
+ unique_classes, pred_counts = np.unique(y_pred_int, return_counts=True)
360
+
361
+ # Create class distribution DataFrame
362
+ class_dist = pd.DataFrame({
363
+ 'Class': classes,
364
+ 'True': 0,
365
+ 'Predicted': 0
366
+ })
367
+
368
+ for c, count in zip(unique_classes, true_counts):
369
+ class_dist.loc[class_dist['Class'] == classes[c], 'True'] = count
370
+
371
+ for c, count in zip(unique_classes, pred_counts):
372
+ class_dist.loc[class_dist['Class'] == classes[c], 'Predicted'] = count
373
+
374
+ # Plot class distribution
375
+ ax2.bar(class_dist['Class'].astype(str), class_dist['True'], label='True', alpha=0.7)
376
+ ax2.bar(class_dist['Class'].astype(str), class_dist['Predicted'], label='Predicted', alpha=0.5)
377
+ ax2.set_xlabel('Class')
378
+ ax2.set_ylabel('Count')
379
+ ax2.set_title('Class Distribution')
380
+ ax2.legend()
381
+
382
+ # Performance metrics
383
+ ax3 = fig.add_subplot(gs[1, 0])
384
+ ax3.axis('off')
385
+
386
+ # Calculate metrics
387
+ from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
388
+ acc = accuracy_score(y_true_int, y_pred_int)
389
+ prec = precision_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
390
+ rec = recall_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
391
+ f1 = f1_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
392
+
393
+ metrics_text = (
394
+ f"Classification Metrics:\n\n"
395
+ f"Accuracy: {acc:.4f}\n"
396
+ f"Precision: {prec:.4f}\n"
397
+ f"Recall: {rec:.4f}\n"
398
+ f"F1 Score: {f1:.4f}"
399
+ )
400
+
401
+ ax3.text(0.5, 0.5, metrics_text, ha='center', va='center', fontsize=12)
402
+
403
+ # Confidence distribution
404
+ ax4 = fig.add_subplot(gs[1, 1])
405
+ ax4.hist(1 - y_std, bins=20, alpha=0.7)
406
+ ax4.set_xlabel('Prediction Confidence')
407
+ ax4.set_ylabel('Frequency')
408
+ ax4.set_title('Confidence Distribution')
409
+
410
+ plt.tight_layout()
411
+ return fig
412
+
413
+ def create_importance_plot(self, feature_importance, top_n=15):
414
+ """Create feature importance plot"""
415
+ # If feature_importance is a DataFrame, use it directly
416
+ if isinstance(feature_importance, pd.DataFrame):
417
+ importance_df = feature_importance
418
+ else:
419
+ # Create DataFrame
420
+ importance_df = pd.DataFrame({
421
+ 'feature': [f'Feature {i}' for i in range(len(feature_importance))],
422
+ 'importance': feature_importance
423
+ })
424
+
425
+ # Get top N features
426
+ top_features = importance_df.sort_values('importance', ascending=False).head(top_n)
427
+
428
+ # Create plot
429
+ fig = plt.figure(figsize=(10, 6))
430
+ plt.barh(range(len(top_features)), top_features['importance'], align='center')
431
+ plt.yticks(range(len(top_features)), top_features['feature'])
432
+ plt.xlabel('Importance')
433
+ plt.ylabel('Features')
434
+ plt.title(f'Top {top_n} Features by Importance')
435
+ plt.tight_layout()
436
+
437
+ return fig
438
+
439
+ def create_learning_curve_plot(self, fold_metrics):
440
+ """Create learning curve plots from cross-validation results"""
441
+ fig = plt.figure(figsize=(12, 6))
442
+
443
+ # Create a grid for plots
444
+ if self.predictor.prediction_type == 'regression':
445
+ # For regression, show R² and RMSE
446
+ ax1 = plt.subplot(1, 2, 1)
447
+ ax2 = plt.subplot(1, 2, 2)
448
+
449
+ # Plot R² for each fold
450
+ for i, metrics in enumerate(fold_metrics):
451
+ ax1.plot(i+1, metrics['r2'], 'bo')
452
+
453
+ # Plot average R²
454
+ avg_r2 = np.mean([m['r2'] for m in fold_metrics])
455
+ ax1.axhline(y=avg_r2, color='r', linestyle='--',
456
+ label=f'Average R² = {avg_r2:.4f}')
457
+
458
+ ax1.set_xlabel('Fold')
459
+ ax1.set_ylabel('R²')
460
+ ax1.set_title('R² by Fold')
461
+ ax1.set_xticks(range(1, len(fold_metrics)+1))
462
+ ax1.legend()
463
+
464
+ # Plot RMSE for each fold
465
+ for i, metrics in enumerate(fold_metrics):
466
+ ax2.plot(i+1, metrics['rmse'], 'go')
467
+
468
+ # Plot average RMSE
469
+ avg_rmse = np.mean([m['rmse'] for m in fold_metrics])
470
+ ax2.axhline(y=avg_rmse, color='r', linestyle='--',
471
+ label=f'Average RMSE = {avg_rmse:.4f}')
472
+
473
+ ax2.set_xlabel('Fold')
474
+ ax2.set_ylabel('RMSE')
475
+ ax2.set_title('RMSE by Fold')
476
+ ax2.set_xticks(range(1, len(fold_metrics)+1))
477
+ ax2.legend()
478
+
479
+ else: # classification
480
+ # For classification, show accuracy and F1
481
+ ax1 = plt.subplot(1, 2, 1)
482
+ ax2 = plt.subplot(1, 2, 2)
483
+
484
+ # Plot accuracy for each fold
485
+ for i, metrics in enumerate(fold_metrics):
486
+ ax1.plot(i+1, metrics['accuracy'], 'bo')
487
+
488
+ # Plot average accuracy
489
+ avg_acc = np.mean([m['accuracy'] for m in fold_metrics])
490
+ ax1.axhline(y=avg_acc, color='r', linestyle='--',
491
+ label=f'Average Accuracy = {avg_acc:.4f}')
492
+
493
+ ax1.set_xlabel('Fold')
494
+ ax1.set_ylabel('Accuracy')
495
+ ax1.set_title('Accuracy by Fold')
496
+ ax1.set_xticks(range(1, len(fold_metrics)+1))
497
+ ax1.legend()
498
+
499
+ # Plot F1 for each fold
500
+ for i, metrics in enumerate(fold_metrics):
501
+ ax2.plot(i+1, metrics['f1'], 'go')
502
+
503
+ # Plot average F1
504
+ avg_f1 = np.mean([m['f1'] for m in fold_metrics])
505
+ ax2.axhline(y=avg_f1, color='r', linestyle='--',
506
+ label=f'Average F1 = {avg_f1:.4f}')
507
+
508
+ ax2.set_xlabel('Fold')
509
+ ax2.set_ylabel('F1 Score')
510
+ ax2.set_title('F1 Score by Fold')
511
+ ax2.set_xticks(range(1, len(fold_metrics)+1))
512
+ ax2.legend()
513
+
514
+ plt.tight_layout()
515
+ return fig
516
 
517
  def calculate_fc_accuracy(original_fc, reconstructed_fc):
518
  """
 
576
 
577
  return os.path.join('results', file_path)
578
 
579
+ # Make sure directory exists for saving results
580
+ os.makedirs('results', exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
581
 
582
  def create_interface():
583
+ """Create the Gradio interface"""
584
+ app = AphasiaPredictionApp()
585
+
586
+ with gr.Blocks(title="Aphasia Treatment Trajectory Prediction") as interface:
587
+ gr.Markdown("# Aphasia Treatment Trajectory Prediction")
 
 
 
 
 
 
 
 
 
 
 
588
 
589
+ with gr.Tabs():
590
+ # Training Tab
591
+ with gr.Tab("Train Models"):
592
+ with gr.Row():
593
+ with gr.Column(scale=1):
594
+ data_dir = gr.Textbox(
595
+ label="Data Directory",
596
+ value="SreekarB/OSFData"
597
+ )
598
+ latent_dim = gr.Slider(
599
+ minimum=8, maximum=64, step=8,
600
+ label="Latent Dimensions", value=32
601
+ )
602
+ nepochs = gr.Slider(
603
+ minimum=100, maximum=5000, step=100,
604
+ label="Number of Epochs", value=200 # Reduced for faster demos
605
+ )
606
+
607
+ with gr.Column(scale=1):
608
+ bsize = gr.Slider(
609
+ minimum=8, maximum=64, step=8,
610
+ label="Batch Size", value=16
611
+ )
612
+ use_hf_dataset = gr.Checkbox(
613
+ label="Use HuggingFace Dataset", value=True
614
+ )
615
+ with gr.Group("Prediction Options"):
616
+ prediction_type = gr.Radio(
617
+ label="Prediction Type",
618
+ choices=["regression", "classification"],
619
+ value="regression"
620
+ )
621
+ outcome_variable = gr.Dropdown(
622
+ label="Outcome Variable",
623
+ choices=["wab_aq", "age", "months_post_onset"],
624
+ value="wab_aq"
625
+ )
626
+
627
+ train_btn = gr.Button("Train Models", variant="primary")
628
+
629
+ with gr.Row():
630
+ fc_plot = gr.Plot(label="FC Analysis")
631
 
632
+ with gr.Row():
633
+ with gr.Column(scale=1):
634
+ importance_plot = gr.Plot(label="Feature Importance")
635
+ with gr.Column(scale=1):
636
+ prediction_plot = gr.Plot(label="Prediction Performance")
637
 
638
+ with gr.Row():
639
+ learning_plot = gr.Plot(label="Cross-validation Results")
640
+
641
+ # Prediction Tab
642
+ with gr.Tab("Predict Treatment"):
643
+ with gr.Row():
644
+ with gr.Column(scale=1):
645
+ fmri_file = gr.File(label="Patient fMRI Data")
646
+ with gr.Column(scale=1):
647
+ with gr.Group("Patient Demographics"):
648
+ age = gr.Number(label="Age at Stroke", value=60)
649
+ sex = gr.Dropdown(choices=["M", "F"], label="Sex", value="M")
650
+ months = gr.Number(label="Months Post Stroke", value=12)
651
+ wab = gr.Number(label="Current WAB Score", value=50)
652
+
653
+ predict_btn = gr.Button("Predict Treatment Outcome", variant="primary")
654
+
655
+ with gr.Row():
656
+ prediction_text = gr.Textbox(label="Prediction Result")
657
+
658
+ with gr.Row():
659
+ trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory")
660
+
661
+ # Connect components
662
+ train_outputs = {
663
+ 'vae': fc_plot,
664
+ 'importance': importance_plot,
665
+ 'prediction': prediction_plot,
666
+ 'learning': learning_plot
667
+ }
668
+
669
+ # Handle train button click
670
+ def handle_train(data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
671
+ prediction_type, outcome_variable):
672
+ # Ensure we have the necessary files before training
673
+ # This is a placeholder - in a real app you'd validate these files exist
674
+ demographic_file = os.path.join(data_dir, "demographics.csv")
675
+ treatment_file = os.path.join(data_dir, "treatment_outcomes.csv")
676
+
677
+ results = app.train_models(
678
+ data_dir=data_dir,
679
+ latent_dim=latent_dim,
680
+ nepochs=nepochs,
681
+ bsize=bsize
682
+ )
683
+
684
+ # Return plots in the expected order
685
+ return [
686
+ results.get('vae', None),
687
+ results.get('importance', None),
688
+ results.get('prediction', None),
689
+ results.get('learning', None)
690
+ ]
691
+
692
+ train_btn.click(
693
+ fn=handle_train,
694
+ inputs=[data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
695
+ prediction_type, outcome_variable],
696
+ outputs=[fc_plot, importance_plot, prediction_plot, learning_plot]
697
  )
698
 
699
+ predict_btn.click(
700
+ fn=app.predict_treatment,
701
+ inputs=[fmri_file, age, sex, months, wab],
702
+ outputs=[prediction_text, trajectory_plot]
 
 
 
 
 
 
 
 
 
 
 
703
  )
704
 
705
  # Add examples
706
  gr.Examples(
707
  examples=[
708
+ ["SreekarB/OSFData", 32, 200, 16, True, "regression", "wab_aq"], # Standard training
709
+ ["SreekarB/OSFData", 16, 100, 8, True, "classification", "wab_aq"] # Faster training with classification
710
  ],
711
+ inputs=[data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
712
+ prediction_type, outcome_variable],
713
  )
714
 
715
+ # Add explanation
716
  gr.Markdown("""
717
+ ## How to use this tool
718
+
719
+ 1. **Train Models Tab**: First train the VAE and Random Forest models using your dataset
720
+ - Use the default SreekarB/OSFData dataset or specify your own data source
721
+ - Adjust parameters like latent dimensions and training epochs
722
+ - Choose regression or classification prediction type
723
+ - Select which variable to predict (WAB score by default)
724
+
725
+ 2. **Predict Treatment Tab**: Use the trained models to predict treatment outcomes
726
+ - Upload a patient's fMRI scan or use synthetic data
727
+ - Enter the patient's demographic information
728
+ - Click "Predict Treatment Outcome" to see the projected treatment trajectory
729
+ - The visualization shows the predicted outcome with confidence intervals
730
+
731
+ ## Interpreting Results
732
 
733
+ - The **Feature Importance** plot shows which latent dimensions and demographic variables most strongly predict treatment outcomes
734
+ - The **Prediction Performance** plot shows how well the model predicts known outcomes
735
+ - The **Treatment Trajectory** shows the projected change in WAB score over the course of treatment
 
 
736
 
737
+ Note: For optimal results, train with at least 500 epochs and latent dimension of 32 or higher.
738
  """)
739
 
740
+ return interface
741
 
742
  if __name__ == "__main__":
743
+ interface = create_interface()
744
+ interface.launch(share=True)
 
config.py CHANGED
@@ -22,3 +22,12 @@ DATASET_CONFIG = {
22
  'split': 'train'
23
  }
24
 
 
 
 
 
 
 
 
 
 
 
22
  'split': 'train'
23
  }
24
 
25
+ # Prediction configuration
26
+ PREDICTION_CONFIG = {
27
+ 'n_estimators': 100,
28
+ 'max_depth': None,
29
+ 'cv_folds': 5,
30
+ 'prediction_type': 'regression',
31
+ 'default_outcome': 'wab_aq',
32
+ 'save_path': 'results/treatment_predictor.joblib'
33
+ }
data_preprocessing.py CHANGED
@@ -1,593 +1,93 @@
1
  import numpy as np
2
  import pandas as pd
3
- from datasets import load_dataset
4
  from nilearn import input_data, connectome
5
  from nilearn.image import load_img
6
  import nibabel as nib
7
- import os
 
8
 
9
- def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
10
  """
11
- Process fMRI data to generate functional connectivity matrices
 
 
 
 
 
12
 
13
- Parameters:
14
- - dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
15
- - demo_data: Optional demographic data, required if providing NIfTI files
16
- - demo_types: Optional demographic data types, required if providing NIfTI files
 
 
 
 
 
 
 
 
 
17
 
18
- Returns:
19
- - X: Array of FC matrices
20
- - demo_data: Demographic data
21
- - demo_types: Demographic data types
22
- """
23
- print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
24
 
25
- # For SreekarB/OSFData dataset, the data will be loaded from dataset features
26
- if isinstance(dataset_or_niifiles, str):
27
- dataset_name = dataset_or_niifiles
28
- print(f"Loading data from dataset: {dataset_name}")
29
- try:
30
- # Try multiple approaches to load the dataset
31
- approaches = [
32
- lambda: load_dataset(dataset_name, split="train"),
33
- lambda: load_dataset(dataset_name), # Try without split
34
- lambda: load_dataset(dataset_name, split="train", trust_remote_code=True), # Try with trust_remote_code
35
- lambda: load_dataset(dataset_name.split("/")[-1], split="train") if "/" in dataset_name else None
36
- ]
37
-
38
- dataset = None
39
- last_error = None
40
-
41
- for i, approach in enumerate(approaches):
42
- if approach is None:
43
- continue
44
-
45
- try:
46
- print(f"Attempt {i+1} to load dataset...")
47
- dataset = approach()
48
- print(f"Successfully loaded dataset with approach {i+1}!")
49
- break
50
- except Exception as e:
51
- print(f"Attempt {i+1} failed: {e}")
52
- last_error = e
53
-
54
- if dataset is None:
55
- print(f"All attempts to load dataset failed. Last error: {last_error}")
56
- raise ValueError(f"Could not load dataset {dataset_name}")
57
- except Exception as e:
58
- print(f"Error during dataset loading: {e}")
59
- raise
60
-
61
- # Prepare demographics data from the dataset
62
- if demo_data is None:
63
- # Create demo_data from the dataset
64
- demo_df = pd.DataFrame({
65
- 'age': dataset['age'],
66
- 'gender': dataset['gender'],
67
- 'mpo': dataset['mpo'],
68
- 'wab_aq': dataset['wab_aq']
69
- })
70
-
71
- demo_data = [
72
- demo_df['age'].values,
73
- demo_df['gender'].values,
74
- demo_df['mpo'].values,
75
- demo_df['wab_aq'].values
76
- ]
77
-
78
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
79
-
80
- # Look for NIfTI files in P01_rs.nii format
81
- print("Searching for NIfTI files in dataset columns...")
82
- nii_files = []
83
-
84
- # Create a temp directory for downloads
85
- import tempfile
86
- from huggingface_hub import hf_hub_download
87
- import shutil
88
-
89
- temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
90
- print(f"Created temporary directory for NIfTI files: {temp_dir}")
91
-
92
- try:
93
- # First approach: Check if there are any columns containing file paths
94
- nii_columns = []
95
- for col in dataset.column_names:
96
- # Check if column name suggests NIfTI files
97
- if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
98
- nii_columns.append(col)
99
- # Or check if column contains file paths
100
- elif len(dataset) > 0:
101
- first_val = dataset[0][col]
102
- if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
103
- nii_columns.append(col)
104
-
105
- if nii_columns:
106
- print(f"Found columns that may contain NIfTI files: {nii_columns}")
107
-
108
- for col in nii_columns:
109
- print(f"Processing column '{col}'...")
110
-
111
- for i, item in enumerate(dataset[col]):
112
- if not isinstance(item, str):
113
- print(f"Item {i} in column {col} is not a string but {type(item)}")
114
- continue
115
-
116
- if not (item.endswith('.nii') or item.endswith('.nii.gz')):
117
- print(f"Item {i} in column {col} is not a NIfTI file: {item}")
118
- continue
119
-
120
- print(f"Downloading {item} from dataset {dataset_name}...")
121
-
122
- try:
123
- # Attempt to download with explicit filename
124
- file_path = hf_hub_download(
125
- repo_id=dataset_name,
126
- filename=item,
127
- repo_type="dataset",
128
- cache_dir=temp_dir
129
- )
130
- nii_files.append(file_path)
131
- print(f"✓ Successfully downloaded {item}")
132
- except Exception as e1:
133
- print(f"Error downloading with explicit filename: {e1}")
134
-
135
- # Second attempt: try with the item's basename
136
- try:
137
- basename = os.path.basename(item)
138
- print(f"Trying with basename: {basename}")
139
- file_path = hf_hub_download(
140
- repo_id=dataset_name,
141
- filename=basename,
142
- repo_type="dataset",
143
- cache_dir=temp_dir
144
- )
145
- nii_files.append(file_path)
146
- print(f"✓ Successfully downloaded {basename}")
147
- except Exception as e2:
148
- print(f"Error downloading with basename: {e2}")
149
-
150
- # Third attempt: check if it's a binary blob in the dataset
151
- try:
152
- if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
153
- print("Found binary data in dataset, saving to temporary file...")
154
- binary_data = dataset[i]['bytes']
155
- temp_file = os.path.join(temp_dir, basename)
156
- with open(temp_file, 'wb') as f:
157
- f.write(binary_data)
158
- nii_files.append(temp_file)
159
- print(f"✓ Saved binary data to {temp_file}")
160
- except Exception as e3:
161
- print(f"Error handling binary data: {e3}")
162
-
163
- # Last resort: look for the file locally
164
- local_path = os.path.join(os.getcwd(), item)
165
- if os.path.exists(local_path):
166
- nii_files.append(local_path)
167
- print(f"✓ Found {item} locally")
168
- else:
169
- print(f"❌ Warning: Could not find {item} anywhere")
170
-
171
- # Second approach: Try to find NIfTI files in dataset repository directly
172
- if not nii_files:
173
- print("No NIfTI files found in dataset columns. Trying direct repository search...")
174
-
175
- try:
176
- from huggingface_hub import list_repo_files, hf_hub_download
177
-
178
- # Try to list all files in the repository
179
- try:
180
- print("Listing all repository files...")
181
- all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
182
- print(f"Found {len(all_repo_files)} files in repository")
183
-
184
- # First prioritize P*_rs.nii files
185
- p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
186
-
187
- # Then include all other NIfTI files
188
- other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
189
-
190
- # Combine, with P*_rs.nii files first
191
- nii_repo_files = p_rs_files + other_nii_files
192
-
193
- if nii_repo_files:
194
- print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
195
-
196
- # Download each file
197
- for nii_file in nii_repo_files:
198
- try:
199
- file_path = hf_hub_download(
200
- repo_id=dataset_name,
201
- filename=nii_file,
202
- repo_type="dataset",
203
- cache_dir=temp_dir
204
- )
205
- nii_files.append(file_path)
206
- print(f"✓ Downloaded {nii_file}")
207
- except Exception as e:
208
- print(f"Error downloading {nii_file}: {e}")
209
- except Exception as e:
210
- print(f"Error listing repository files: {e}")
211
- print("Will try alternative approaches...")
212
-
213
- # If repo listing fails, try with common NIfTI file patterns directly
214
- if not nii_files:
215
- print("Trying common NIfTI file patterns...")
216
-
217
- # Focus specifically on P*_rs.nii pattern
218
- patterns = []
219
-
220
- # Generate P01_rs.nii through P30_rs.nii
221
- for i in range(1, 31): # Try subjects 1-30
222
- patterns.append(f"P{i:02d}_rs.nii")
223
-
224
- # Also try with .nii.gz extension
225
- for i in range(1, 31):
226
- patterns.append(f"P{i:02d}_rs.nii.gz")
227
-
228
- # Include a few other common patterns as fallbacks
229
- patterns.extend([
230
- "sub-01_task-rest_bold.nii.gz", # BIDS format
231
- "fmri.nii.gz", "bold.nii.gz",
232
- "rest.nii.gz"
233
- ])
234
-
235
- for pattern in patterns:
236
- try:
237
- print(f"Trying to download {pattern}...")
238
- file_path = hf_hub_download(
239
- repo_id=dataset_name,
240
- filename=pattern,
241
- repo_type="dataset",
242
- cache_dir=temp_dir
243
- )
244
- nii_files.append(file_path)
245
- print(f"✓ Successfully downloaded {pattern}")
246
- except Exception as e:
247
- print(f"× Failed to download {pattern}")
248
-
249
- # If we still couldn't find any files, check if data files are nested
250
- if not nii_files:
251
- print("Checking for nested data files...")
252
- nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
253
-
254
- for path in nested_paths:
255
- for pattern in patterns:
256
- nested_file = f"{path}{pattern}"
257
- try:
258
- print(f"Trying to download {nested_file}...")
259
- file_path = hf_hub_download(
260
- repo_id=dataset_name,
261
- filename=nested_file,
262
- repo_type="dataset",
263
- cache_dir=temp_dir
264
- )
265
- nii_files.append(file_path)
266
- print(f"✓ Successfully downloaded {nested_file}")
267
- # If we found one file in this directory, try to find all files in it
268
- try:
269
- all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
270
- nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
271
- print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
272
-
273
- for nii_file in nii_files_in_dir:
274
- if nii_file != nested_file: # Skip the one we already downloaded
275
- try:
276
- file_path = hf_hub_download(
277
- repo_id=dataset_name,
278
- filename=nii_file,
279
- repo_type="dataset",
280
- cache_dir=temp_dir
281
- )
282
- nii_files.append(file_path)
283
- print(f"✓ Downloaded {nii_file}")
284
- except Exception as e:
285
- print(f"Error downloading {nii_file}: {e}")
286
- except Exception as e:
287
- print(f"Error finding additional files in {path}: {e}")
288
- except Exception as e:
289
- pass
290
-
291
- except Exception as e:
292
- print(f"Error during repository exploration: {e}")
293
-
294
- # If we still don't have any files, try to search for P*_rs.nii pattern specifically
295
- if not nii_files:
296
- print("Trying to find files matching P*_rs.nii pattern specifically...")
297
-
298
- try:
299
- # List all files in the repository (if we haven't already)
300
- if not 'all_repo_files' in locals():
301
- from huggingface_hub import list_repo_files
302
- try:
303
- all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
304
- except Exception as e:
305
- print(f"Error listing repo files: {e}")
306
- all_repo_files = []
307
-
308
- # Look for files matching the pattern exactly (P*_rs.nii)
309
- pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
310
-
311
- # If we don't find any exact matches, try a more relaxed pattern
312
- if not pattern_files:
313
- pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
314
-
315
- if pattern_files:
316
- print(f"Found {len(pattern_files)} files matching rs.nii pattern")
317
-
318
- # Download each file
319
- for pattern_file in pattern_files:
320
- try:
321
- file_path = hf_hub_download(
322
- repo_id=dataset_name,
323
- filename=pattern_file,
324
- repo_type="dataset",
325
- cache_dir=temp_dir
326
- )
327
- nii_files.append(file_path)
328
- print(f"✓ Downloaded {pattern_file}")
329
- except Exception as e:
330
- print(f"Error downloading {pattern_file}: {e}")
331
- except Exception as e:
332
- print(f"Error searching for pattern files: {e}")
333
-
334
- print(f"Found total of {len(nii_files)} NIfTI files")
335
- except Exception as e:
336
- print(f"Unexpected error during NIfTI file search: {e}")
337
- import traceback
338
- traceback.print_exc()
339
-
340
- # If we found NIfTI files, process them to FC matrices
341
- if nii_files:
342
- print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
343
-
344
- # Load Power 264 atlas
345
- from nilearn import datasets
346
- power = datasets.fetch_coords_power_2011()
347
- coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
348
-
349
- masker = input_data.NiftiSpheresMasker(
350
- coords, radius=5,
351
- standardize=True,
352
- memory='nilearn_cache', memory_level=1,
353
- verbose=0,
354
- detrend=True,
355
- low_pass=0.1,
356
- high_pass=0.01,
357
- t_r=2.0 # Adjust TR according to your data
358
- )
359
-
360
- # Process fMRI data and compute FC matrices
361
- fc_matrices = []
362
- valid_files = 0
363
- total_files = len(nii_files)
364
-
365
- for nii_file in nii_files:
366
- try:
367
- print(f"Processing {nii_file}...")
368
- fmri_img = load_img(nii_file)
369
-
370
- # Check image dimensions
371
- if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
372
- print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
373
- continue
374
-
375
- try:
376
- # Explicitly handle warnings about empty spheres
377
- import warnings
378
- with warnings.catch_warnings():
379
- warnings.filterwarnings('ignore', message='.*empty.*')
380
- time_series = masker.fit_transform(fmri_img)
381
- except Exception as e:
382
- if "empty" in str(e):
383
- print(f"Warning: Some spheres are empty in {nii_file}. Using a different sphere radius.")
384
-
385
- # Extract the list of empty spheres for logging
386
- import re
387
- empty_spheres = re.findall(r"\[(.*?)\]", str(e))
388
- if empty_spheres:
389
- print(f"Empty spheres: {empty_spheres[0]}")
390
-
391
- # Try with a different radius
392
- alternate_masker = input_data.NiftiSpheresMasker(
393
- coords, radius=8, # Larger radius
394
- standardize=True,
395
- memory='nilearn_cache', memory_level=1,
396
- verbose=0,
397
- detrend=True,
398
- low_pass=0.1,
399
- high_pass=0.01,
400
- t_r=2.0
401
- )
402
- try:
403
- time_series = alternate_masker.fit_transform(fmri_img)
404
- print(f"Successfully extracted time series with larger radius")
405
- except Exception as e2:
406
- print(f"Error with alternate masker: {e2}")
407
- print(f"Skipping this file due to empty spheres")
408
- continue # Skip this file entirely
409
- else:
410
- print(f"Unknown error in masker: {e}")
411
- continue # Skip this file if there's any other error
412
-
413
- # Validate time series data
414
- if np.isnan(time_series).any() or np.isinf(time_series).any():
415
- print(f"Warning: {nii_file} contains NaN or Inf values after masking")
416
- # Replace NaNs with zeros for this file
417
- time_series = np.nan_to_num(time_series)
418
-
419
- correlation_measure = connectome.ConnectivityMeasure(
420
- kind='correlation',
421
- vectorize=False,
422
- discard_diagonal=False
423
- )
424
-
425
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
426
-
427
- # Check for invalid correlation values
428
- if np.isnan(fc_matrix).any():
429
- print(f"Warning: {nii_file} produced NaN correlation values")
430
- continue
431
-
432
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
433
- fc_triu = fc_matrix[triu_indices]
434
-
435
- # Fisher z-transform with proper bounds check
436
- # Clip correlation values to valid range for arctanh
437
- fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
438
- fc_triu = np.arctanh(fc_triu_clipped)
439
-
440
- fc_matrices.append(fc_triu)
441
- valid_files += 1
442
- print(f"Successfully processed {nii_file} to FC matrix")
443
-
444
- except Exception as e:
445
- print(f"Error processing {nii_file}: {e}")
446
-
447
- if fc_matrices:
448
- print(f"Successfully processed {valid_files} out of {total_files} files")
449
-
450
- # Ensure all matrices have the same dimensions
451
- dims = [m.shape[0] for m in fc_matrices]
452
- if len(set(dims)) > 1:
453
- print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
454
- # Use the most common dimension
455
- from collections import Counter
456
- most_common_dim = Counter(dims).most_common(1)[0][0]
457
- print(f"Using most common dimension: {most_common_dim}")
458
- fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
459
-
460
- X = np.array(fc_matrices)
461
-
462
- # Normalize the FC data
463
- mean_x = np.mean(X, axis=0)
464
- std_x = np.std(X, axis=0)
465
-
466
- # Handle zero standard deviation
467
- std_x[std_x == 0] = 1.0
468
-
469
- X = (X - mean_x) / std_x
470
- print(f"Created FC matrices with shape {X.shape}")
471
-
472
- # Make sure demo_data matches the number of FC matrices
473
- if len(demo_data[0]) != X.shape[0]:
474
- print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
475
- f"doesn't match number of FC matrices ({X.shape[0]})")
476
- # Adjust demo_data to match FC matrices
477
- indices = list(range(min(len(demo_data[0]), X.shape[0])))
478
- X = X[indices]
479
- demo_data = [d[indices] for d in demo_data]
480
-
481
- return X, demo_data, demo_types
482
-
483
- print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
484
- # Return a placeholder with the right demographics but empty FC
485
- n_subjects = len(dataset)
486
- n_rois = 264
487
- fc_dim = (n_rois * (n_rois - 1)) // 2
488
- X = np.zeros((n_subjects, fc_dim))
489
- print(f"Created placeholder FC matrices with shape {X.shape}")
490
- return X, demo_data, demo_types
491
-
492
- elif isinstance(dataset_or_niifiles, str):
493
- # Handle real dataset with actual fMRI data
494
- dataset = load_dataset(dataset_or_niifiles, split="train")
495
-
496
- # Load Power 264 atlas
497
- from nilearn import datasets
498
- power = datasets.fetch_coords_power_2011()
499
- coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
500
-
501
- masker = input_data.NiftiSpheresMasker(
502
- coords, radius=5,
503
- standardize=True,
504
- memory='nilearn_cache', memory_level=1,
505
- verbose=0,
506
- detrend=True,
507
- low_pass=0.1,
508
- high_pass=0.01,
509
- t_r=2.0 # Adjust TR according to your data
510
- )
511
 
512
- # Load demographic data if needed
513
- if demo_data is None:
514
- if 'demographics' in dataset.features:
515
- demo_df = pd.DataFrame(dataset['demographics'])
516
-
517
- demo_data = [
518
- demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
519
- demo_df['sex'].values if 'sex' in demo_df.columns else [],
520
- demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
521
- demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
522
- ]
523
-
524
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
525
-
526
- # Process fMRI data and compute FC matrices
527
- fc_matrices = []
528
- for nii_file in dataset['nii_files']:
529
- fmri_img = load_img(nii_file)
530
- time_series = masker.fit_transform(fmri_img)
531
-
532
- correlation_measure = connectome.ConnectivityMeasure(
533
- kind='correlation', vectorize=False, discard_diagonal=False
534
- )
535
-
536
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
537
-
538
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
539
- fc_triu = fc_matrix[triu_indices]
540
-
541
- fc_triu = np.arctanh(fc_triu) # Fisher z-transform
542
-
543
- fc_matrices.append(fc_triu)
544
-
545
- X = np.array(fc_matrices)
546
-
547
- elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
548
- # Handle a list of NIfTI files
549
- # Similar processing as above but with local files
550
- print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
551
-
552
- # Load Power 264 atlas
553
- from nilearn import datasets
554
- power = datasets.fetch_coords_power_2011()
555
- coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
556
-
557
- masker = input_data.NiftiSpheresMasker(
558
- coords, radius=5,
559
- standardize=True,
560
- memory='nilearn_cache', memory_level=1,
561
- verbose=0,
562
- detrend=True,
563
- low_pass=0.1,
564
- high_pass=0.01,
565
- t_r=2.0
566
- )
567
-
568
- fc_matrices = []
569
- for nii_file in dataset_or_niifiles:
570
- fmri_img = load_img(nii_file)
571
- time_series = masker.fit_transform(fmri_img)
572
-
573
- correlation_measure = connectome.ConnectivityMeasure(
574
- kind='correlation', vectorize=False, discard_diagonal=False
575
- )
576
-
577
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
578
-
579
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
580
- fc_triu = fc_matrix[triu_indices]
581
-
582
- fc_triu = np.arctanh(fc_triu) # Fisher z-transform
583
-
584
- fc_matrices.append(fc_triu)
585
-
586
- X = np.array(fc_matrices)
587
- else:
588
- raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
589
 
590
  # Normalize the FC data
591
  X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
592
 
593
- return X, demo_data, demo_types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import pandas as pd
 
3
  from nilearn import input_data, connectome
4
  from nilearn.image import load_img
5
  import nibabel as nib
6
+ from pathlib import Path
7
+ from config import PREPROCESS_CONFIG
8
 
9
+ def process_single_fmri(fmri_file):
10
  """
11
+ Process a single fMRI file to FC matrix
12
+ """
13
+ # Use Power 264 atlas
14
+ from nilearn import datasets
15
+ power = datasets.fetch_coords_power_2011()
16
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
17
 
18
+ # Create masker
19
+ masker = input_data.NiftiSpheresMasker(
20
+ coords,
21
+ radius=PREPROCESS_CONFIG['radius'],
22
+ standardize=True,
23
+ memory='nilearn_cache',
24
+ memory_level=1,
25
+ verbose=0,
26
+ detrend=True,
27
+ low_pass=PREPROCESS_CONFIG['low_pass'],
28
+ high_pass=PREPROCESS_CONFIG['high_pass'],
29
+ t_r=PREPROCESS_CONFIG['t_r']
30
+ )
31
 
32
+ # Load and process fMRI
33
+ fmri_img = load_img(fmri_file)
34
+ time_series = masker.fit_transform(fmri_img)
 
 
 
35
 
36
+ # Compute FC matrix
37
+ correlation_measure = connectome.ConnectivityMeasure(
38
+ kind='correlation',
39
+ vectorize=False,
40
+ discard_diagonal=False
41
+ )
42
+
43
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
44
+
45
+ # Get upper triangular part
46
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
47
+ fc_triu = fc_matrix[triu_indices]
48
+
49
+ # Fisher z-transform
50
+ fc_triu = np.arctanh(fc_triu)
51
+
52
+ return fc_triu
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
+ def preprocess_fmri_to_fc(nii_files, demo_data, demo_types):
55
+ """
56
+ Convert multiple fMRI files to FC matrices
57
+ """
58
+ fc_matrices = []
59
+
60
+ for nii_file in nii_files:
61
+ fc_triu = process_single_fmri(nii_file)
62
+ fc_matrices.append(fc_triu)
63
+
64
+ X = np.array(fc_matrices)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
 
66
  # Normalize the FC data
67
  X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
68
 
69
+ return X, demo_data, demo_types
70
+
71
+ def load_and_preprocess_data(data_dir, demographic_file):
72
+ """
73
+ Load and preprocess both fMRI data and demographics
74
+ """
75
+ # Load demographics
76
+ demo_df = pd.read_csv(demographic_file)
77
+
78
+ demo_data = [
79
+ demo_df['age_at_stroke'].values,
80
+ demo_df['sex'].values,
81
+ demo_df['months_post_stroke'].values,
82
+ demo_df['wab_score'].values
83
+ ]
84
+
85
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
86
+
87
+ # Load fMRI files
88
+ nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
89
+
90
+ # Process fMRI files to FC matrices
91
+ X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
92
+
93
+ return X, demo_data, demo_types
main.py CHANGED
@@ -1,291 +1,150 @@
1
  import os
2
- import sys
3
- # Add the src directory to the path so we can import from demovae
4
- sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
5
-
6
  import numpy as np
7
  import torch
8
  from pathlib import Path
9
- import nibabel as nib
10
- from data_preprocessing import preprocess_fmri_to_fc
11
- from src.demovae.sklearn import DemoVAE
12
- from analysis import analyze_fc_patterns
13
- from visualization import visualize_fc_analysis
14
- from config import MODEL_CONFIG, DATASET_CONFIG
15
  import pandas as pd
16
- import io
17
- from typing import List, Dict, Union, Tuple, Any
 
 
 
 
18
 
19
- def train_fc_vae(X, demo_data, demo_types, model_config):
 
 
 
 
 
 
20
  """
21
- Train a VAE model on functional connectivity matrices
22
  """
23
- n_rois = 264
24
- input_dim = (n_rois * (n_rois - 1)) // 2
 
 
 
 
25
 
26
- print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
 
 
27
 
28
- # Ensure X is a numpy array with correct data type
29
- if not isinstance(X, np.ndarray):
30
- print(f"Converting X from {type(X)} to numpy array")
31
- X = np.array(X, dtype=np.float32)
32
 
33
- # Ensure demo_data contains numpy arrays
34
- for i, d in enumerate(demo_data):
35
- if not isinstance(d, np.ndarray):
36
- print(f"Converting demographic {i} from {type(d)} to numpy array")
37
- demo_data[i] = np.array(d)
38
 
39
- # Check for NaN or Inf values
40
- if np.isnan(X).any() or np.isinf(X).any():
41
- print("Warning: X contains NaN or Inf values. Replacing with zeros.")
42
- X = np.nan_to_num(X)
43
 
44
- # Create the VAE model
45
- vae = DemoVAE(
46
- latent_dim=model_config['latent_dim'],
47
- nepochs=model_config['nepochs'],
48
- bsize=model_config['bsize'],
49
- loss_rec_mult=model_config.get('loss_rec_mult', 100),
50
- loss_decor_mult=model_config.get('loss_decor_mult', 10),
51
- lr=model_config.get('lr', 1e-4),
52
- use_cuda=torch.cuda.is_available()
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  )
54
 
55
- print("Fitting VAE model...")
56
- vae.fit(X, demo_data, demo_types)
57
 
58
- return vae, X, demo_data, demo_types
59
-
60
- def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
61
- """
62
- Load fMRI data and demographics from HuggingFace dataset or local files
63
- """
64
- if use_hf_dataset:
65
- # Load from HuggingFace Datasets
66
- from datasets import load_dataset
67
-
68
- print(f"Loading dataset from HuggingFace: {data_dir}")
69
- dataset = load_dataset(data_dir)
70
-
71
- print(f"Dataset columns: {dataset['train'].column_names}")
72
-
73
- # Get demographics directly from the dataset
74
- # Create a DataFrame from the dataset features
75
- demo_df = pd.DataFrame({
76
- 'ID': dataset['train']['ID'],
77
- 'wab_aq': dataset['train']['wab_aq'],
78
- 'age': dataset['train']['age'],
79
- 'mpo': dataset['train']['mpo'],
80
- 'education': dataset['train']['education'],
81
- 'gender': dataset['train']['gender'],
82
- 'handedness': dataset['train']['handedness']
83
- })
84
-
85
- print(f"Loaded demographic data with {len(demo_df)} subjects")
86
-
87
- # Extract demographic data matching our expected format
88
- # Map the dataset columns to our expected format
89
- demo_data = [
90
- demo_df['age'].values, # age at stroke -> age
91
- demo_df['gender'].values, # sex -> gender
92
- demo_df['mpo'].values, # months post stroke -> mpo
93
- demo_df['wab_aq'].values # wab score -> wab_aq
94
- ]
95
-
96
- # Check for FC matrices in the dataset
97
- fc_columns = []
98
- for col in dataset['train'].column_names:
99
- if col.startswith("fc_") or "_fc" in col:
100
- fc_columns.append(col)
101
-
102
- if fc_columns:
103
- print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
104
- # Extract FC matrices
105
- fc_matrices = []
106
- for fc_col in fc_columns:
107
- fc_matrices.append(dataset['train'][fc_col])
108
-
109
- # If we have FC matrices, return them directly
110
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
111
- return fc_matrices, demo_data, demo_types
112
-
113
- # If no FC matrices, look for .nii files
114
- nii_files = []
115
- for col in dataset['train'].column_names:
116
- if col.endswith(".nii.gz") or col.endswith(".nii"):
117
- nii_files.append(dataset['train'][col])
118
-
119
- if nii_files:
120
- print(f"Found {len(nii_files)} .nii files")
121
- else:
122
- print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
123
- # If no structured data is found, we can try to download raw files later
124
-
125
- else:
126
- # Original local file loading
127
- # Load demographics
128
- demo_df = pd.read_csv(demographic_file)
129
-
130
- demo_data = [
131
- demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
132
- demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
133
- demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
134
- demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
135
- ]
136
-
137
- # Load fMRI files
138
- nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
139
 
140
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
141
- return nii_files, demo_data, demo_types
142
-
143
- def run_fc_analysis(data_dir="SreekarB/OSFData",
144
- demographic_file=None,
145
- latent_dim=32,
146
- nepochs=1000,
147
- bsize=16,
148
- save_model=True,
149
- use_hf_dataset=True,
150
- return_data=False):
151
 
152
- # Update MODEL_CONFIG with user-specified parameters
153
- MODEL_CONFIG.update({
154
- 'latent_dim': latent_dim,
155
- 'nepochs': nepochs,
156
- 'bsize': bsize
157
- })
158
 
159
- try:
160
- # Load data
161
- print("Loading data...")
162
- nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
163
-
164
- # For SreekarB/OSFData, directly generate synthetic FC matrices
165
- if data_dir == "SreekarB/OSFData" and use_hf_dataset:
166
- print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
167
- X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
168
- # Check if we got FC matrices directly
169
- elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
170
- print("Using pre-computed FC matrices...")
171
- # Convert list of FC matrices to numpy array
172
- X = np.stack([np.array(fc) for fc in nii_files])
173
- else:
174
- # Prepare data by converting fMRI to FC matrices
175
- print("Converting fMRI data to FC matrices...")
176
- X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
177
-
178
- # Print shapes and data types
179
- print(f"X shape: {X.shape}, type: {type(X)}")
180
- for i, d in enumerate(demo_data):
181
- print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
182
-
183
- # Train VAE and get data
184
- print("Training VAE...")
185
- try:
186
- # Use the proper DemoVAE implementation from src/demovae/sklearn.py
187
- vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
188
-
189
- if save_model:
190
- print("Saving model...")
191
- os.makedirs('models', exist_ok=True)
192
- # Use the save method from DemoVAE
193
- vae.save('models/vae_model.pth')
194
- print("Model saved successfully.")
195
- except Exception as e:
196
- print(f"Error during VAE training: {e}")
197
- raise
198
-
199
- # Get latent representations
200
- print("Getting latent representations...")
201
- latents = vae.get_latents(X)
202
-
203
- # Analyze results
204
- print("Analyzing demographic relationships...")
205
- demographics = {
206
- 'age': demo_data[0],
207
- 'months_post_onset': demo_data[2],
208
- 'wab_aq': demo_data[3]
209
  }
210
- analysis_results = analyze_fc_patterns(latents, demographics)
211
-
212
- # Generate new FC matrix
213
- print("Generating new FC matrices...")
214
-
215
- # Get data types from original demographic data for proper conversion
216
- demo_dtypes = [type(d[0]) if len(d) > 0 else float for d in demo_data]
217
-
218
- # Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
219
- new_demographics = [
220
- np.array([60.0], dtype=np.float64), # age
221
- np.array(['M'], dtype=np.str_), # gender
222
- np.array([12.0], dtype=np.float64), # months post onset
223
- np.array([80.0], dtype=np.float64) # wab score
224
- ]
225
-
226
- # Verify the demographic data arrays match the expected types
227
- print("Demographic data types:")
228
- for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
229
- print(f" {name}: shape={data.shape}, dtype={data.dtype}")
230
-
231
- print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
232
- try:
233
- generated_fc = vae.transform(1, new_demographics, demo_types)
234
- except Exception as e:
235
- print(f"Error generating new FC matrix: {e}")
236
- # Try with a fallback approach
237
- print("Trying alternative generation approach...")
238
- # If specific gender is causing issues, try the first gender from training data
239
- new_demographics[1] = np.array([demo_data[1][0]])
240
- generated_fc = vae.transform(1, new_demographics, demo_types)
241
- reconstructed_fc = vae.transform(X, demo_data, demo_types)
242
-
243
- # Visualize results
244
- print("Creating visualizations...")
245
- fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
246
-
247
- # If requested, return additional data for accuracy calculations
248
- if return_data:
249
- results = {
250
- 'vae': vae,
251
- 'X': X,
252
- 'latents': latents,
253
- 'demographics': demographics,
254
- 'reconstructed_fc': reconstructed_fc,
255
- 'generated_fc': generated_fc,
256
- 'analysis_results': analysis_results
257
- }
258
- return fig, results
259
-
260
- return fig
261
-
262
- except Exception as e:
263
- import traceback
264
- print(f"Error in run_fc_analysis: {str(e)}")
265
- print(traceback.format_exc())
266
-
267
- # Create a dummy figure with error message
268
- import matplotlib.pyplot as plt
269
- fig = plt.figure(figsize=(10, 6))
270
- plt.text(0.5, 0.5, f"Error: {str(e)}",
271
- horizontalalignment='center', verticalalignment='center',
272
- fontsize=12, color='red')
273
- plt.axis('off')
274
-
275
- # Return the error figure and empty results if requested
276
- if return_data:
277
- return fig, None
278
-
279
- return fig
280
 
281
  if __name__ == "__main__":
282
  import argparse
283
 
284
- parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
285
- parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
286
- help='HuggingFace dataset ID or directory containing fMRI data')
287
- parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
288
  help='Path to demographic data CSV file')
 
 
289
  parser.add_argument('--latent_dim', type=int, default=32,
290
  help='Dimension of latent space')
291
  parser.add_argument('--nepochs', type=int, default=1000,
@@ -293,20 +152,16 @@ if __name__ == "__main__":
293
  parser.add_argument('--bsize', type=int, default=16,
294
  help='Batch size for training')
295
  parser.add_argument('--no_save', action='store_false',
296
- help='Do not save the model')
297
- parser.add_argument('--use_local', action='store_true',
298
- help='Use local data instead of HuggingFace dataset')
299
 
300
  args = parser.parse_args()
301
 
302
- fig = run_fc_analysis(
303
  data_dir=args.data_dir,
304
  demographic_file=args.demographic_file,
 
305
  latent_dim=args.latent_dim,
306
  nepochs=args.nepochs,
307
  bsize=args.bsize,
308
- save_model=args.no_save,
309
- use_hf_dataset=not args.use_local
310
  )
311
- fig.show()
312
-
 
1
  import os
 
 
 
 
2
  import numpy as np
3
  import torch
4
  from pathlib import Path
 
 
 
 
 
 
5
  import pandas as pd
6
+ from data_preprocessing import load_and_preprocess_data
7
+ from vae_model import DemoVAE
8
+ from rcf_prediction import AphasiaTreatmentPredictor
9
+ from visualization import plot_fc_matrices, plot_learning_curves
10
+ from config import MODEL_CONFIG
11
+ import matplotlib.pyplot as plt
12
 
13
+ def run_analysis(data_dir="data",
14
+ demographic_file="demographics.csv",
15
+ treatment_file="treatment_outcomes.csv",
16
+ latent_dim=32,
17
+ nepochs=1000,
18
+ bsize=16,
19
+ save_model=True):
20
  """
21
+ Run the complete analysis pipeline
22
  """
23
+ # Update MODEL_CONFIG with user-specified parameters
24
+ MODEL_CONFIG.update({
25
+ 'latent_dim': latent_dim,
26
+ 'nepochs': nepochs,
27
+ 'bsize': bsize
28
+ })
29
 
30
+ # Create output directories
31
+ os.makedirs('models', exist_ok=True)
32
+ os.makedirs('results', exist_ok=True)
33
 
34
+ # Load and preprocess data
35
+ print("Loading and preprocessing data...")
36
+ X, demo_data, demo_types = load_and_preprocess_data(data_dir, demographic_file)
 
37
 
38
+ # Load treatment outcomes
39
+ treatment_df = pd.read_csv(treatment_file)
40
+ treatment_outcomes = treatment_df['outcome_score'].values
 
 
41
 
42
+ # Initialize and train VAE
43
+ print("Training VAE...")
44
+ vae = DemoVAE(**MODEL_CONFIG)
45
+ train_losses, val_losses = vae.fit(X, demo_data, demo_types)
46
 
47
+ # Get latent representations
48
+ print("Extracting latent representations...")
49
+ latents = vae.get_latents(X)
50
+
51
+ # Initialize and train treatment predictor
52
+ print("Training treatment predictor...")
53
+ predictor = AphasiaTreatmentPredictor(n_estimators=100)
54
+
55
+ # Prepare demographics for predictor
56
+ demographics = {
57
+ 'age_at_stroke': demo_data[0],
58
+ 'sex': demo_data[1],
59
+ 'months_post_stroke': demo_data[2],
60
+ 'wab_score': demo_data[3]
61
+ }
62
+
63
+ # Cross-validate the predictor
64
+ print("Performing cross-validation...")
65
+ cv_mean, cv_std, predictions, prediction_stds = predictor.cross_validate(
66
+ latents=latents,
67
+ demographics=demographics,
68
+ treatment_outcomes=treatment_outcomes
69
  )
70
 
71
+ # Fit final predictor model
72
+ predictor.fit(latents, demographics, treatment_outcomes)
73
 
74
+ # Save models if requested
75
+ if save_model:
76
+ print("Saving models...")
77
+ vae.save('models/vae_model.pt')
78
+ torch.save({
79
+ 'predictor_state': predictor.rf_regressor,
80
+ 'feature_importance': predictor.feature_importance
81
+ }, 'models/predictor_model.pt')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
+ # Generate visualizations
84
+ print("Generating visualizations...")
 
 
 
 
 
 
 
 
 
85
 
86
+ # FC matrix visualization
87
+ reconstructed = vae.transform(X, demo_data, demo_types)
88
+ generated = vae.transform(1,
89
+ [d[:1] for d in demo_data],
90
+ demo_types)
91
+ fc_fig = plot_fc_matrices(X[0], reconstructed[0], generated[0])
92
 
93
+ # Learning curves
94
+ learning_fig = plot_learning_curves(train_losses, val_losses)
95
+
96
+ # Feature importance
97
+ importance_fig = predictor.plot_feature_importance()
98
+
99
+ # Prediction performance
100
+ performance_fig = plt.figure(figsize=(8, 6))
101
+ plt.scatter(treatment_outcomes, predictions)
102
+ plt.plot([min(treatment_outcomes), max(treatment_outcomes)],
103
+ [min(treatment_outcomes), max(treatment_outcomes)],
104
+ 'r--')
105
+ plt.fill_between(treatment_outcomes,
106
+ predictions - 2*prediction_stds,
107
+ predictions + 2*prediction_stds,
108
+ alpha=0.2, color='gray')
109
+ plt.xlabel('Actual Outcome')
110
+ plt.ylabel('Predicted Outcome')
111
+ plt.title(f'Treatment Outcome Prediction\nR² = {cv_mean:.3f} ± {cv_std:.3f}')
112
+ plt.tight_layout()
113
+
114
+ # Save results
115
+ print("Saving results...")
116
+ np.save('results/latents.npy', latents)
117
+ np.save('results/predictions.npy', predictions)
118
+ np.save('results/prediction_stds.npy', prediction_stds)
119
+
120
+ results = {
121
+ 'vae': vae,
122
+ 'predictor': predictor,
123
+ 'latents': latents,
124
+ 'cv_scores': (cv_mean, cv_std),
125
+ 'predictions': predictions,
126
+ 'prediction_stds': prediction_stds,
127
+ 'figures': {
128
+ 'fc_analysis': fc_fig,
129
+ 'learning_curves': learning_fig,
130
+ 'importance': importance_fig,
131
+ 'performance': performance_fig
 
 
 
 
 
 
 
 
 
 
 
132
  }
133
+ }
134
+
135
+ print("Analysis complete!")
136
+ return results
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  if __name__ == "__main__":
139
  import argparse
140
 
141
+ parser = argparse.ArgumentParser(description='Run Aphasia Treatment Analysis')
142
+ parser.add_argument('--data_dir', type=str, default='data',
143
+ help='Directory containing fMRI data')
144
+ parser.add_argument('--demographic_file', type=str, default='demographics.csv',
145
  help='Path to demographic data CSV file')
146
+ parser.add_argument('--treatment_file', type=str, default='treatment_outcomes.csv',
147
+ help='Path to treatment outcomes CSV file')
148
  parser.add_argument('--latent_dim', type=int, default=32,
149
  help='Dimension of latent space')
150
  parser.add_argument('--nepochs', type=int, default=1000,
 
152
  parser.add_argument('--bsize', type=int, default=16,
153
  help='Batch size for training')
154
  parser.add_argument('--no_save', action='store_false',
155
+ help='Do not save the models')
 
 
156
 
157
  args = parser.parse_args()
158
 
159
+ results = run_analysis(
160
  data_dir=args.data_dir,
161
  demographic_file=args.demographic_file,
162
+ treatment_file=args.treatment_file,
163
  latent_dim=args.latent_dim,
164
  nepochs=args.nepochs,
165
  bsize=args.bsize,
166
+ save_model=args.no_save
 
167
  )
 
 
rcf_prediction.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
3
+ from sklearn.model_selection import cross_val_score, KFold
4
+ import pandas as pd
5
+ from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score
6
+ import matplotlib.pyplot as plt
7
+ import os
8
+ import joblib
9
+ import logging
10
+
11
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
12
+ logger = logging.getLogger(__name__)
13
+
14
+ class AphasiaTreatmentPredictor:
15
+ def __init__(self, prediction_type="regression", n_estimators=100, max_depth=None, random_state=42):
16
+ """
17
+ Initialize the Treatment Predictor with Random Forest
18
+
19
+ Args:
20
+ prediction_type (str): "classification" or "regression" depending on outcome variable type
21
+ n_estimators (int): Number of trees in the forest
22
+ max_depth (int): Maximum depth of trees (None for unlimited)
23
+ random_state (int): Random seed for reproducibility
24
+ """
25
+ self.prediction_type = prediction_type
26
+ self.n_estimators = n_estimators
27
+ self.max_depth = max_depth
28
+ self.random_state = random_state
29
+ self.feature_importance = None
30
+ self.feature_names = None
31
+
32
+ if prediction_type == "classification":
33
+ self.model = RandomForestClassifier(
34
+ n_estimators=n_estimators,
35
+ max_depth=max_depth,
36
+ random_state=random_state
37
+ )
38
+ else: # regression
39
+ self.model = RandomForestRegressor(
40
+ n_estimators=n_estimators,
41
+ max_depth=max_depth,
42
+ random_state=random_state
43
+ )
44
+
45
+ def prepare_features(self, latents, demographics):
46
+ """
47
+ Combine latent features with demographics
48
+
49
+ Args:
50
+ latents (np.ndarray): Latent representations from VAE
51
+ demographics (dict or pd.DataFrame): Demographic information
52
+
53
+ Returns:
54
+ tuple: Combined features array and feature names
55
+ """
56
+ if isinstance(demographics, dict):
57
+ demo_df = pd.DataFrame(demographics)
58
+ else:
59
+ demo_df = demographics.copy()
60
+
61
+ # Get categorical columns
62
+ cat_columns = demo_df.select_dtypes(include=['object']).columns.tolist()
63
+
64
+ # Convert categorical variables to dummy variables
65
+ if cat_columns:
66
+ demo_df = pd.get_dummies(demo_df, columns=cat_columns)
67
+
68
+ # Get feature names
69
+ latent_names = [f'latent_{i}' for i in range(latents.shape[1])]
70
+ demo_names = demo_df.columns.tolist()
71
+ feature_names = latent_names + demo_names
72
+
73
+ # Combine latents with demographics
74
+ features = np.hstack([latents, demo_df.values])
75
+ return features, feature_names
76
+
77
+ def fit(self, latents, demographics, treatment_outcomes):
78
+ """
79
+ Fit the random forest model
80
+
81
+ Args:
82
+ latents (np.ndarray): Latent representations from VAE
83
+ demographics (dict or pd.DataFrame): Demographic information
84
+ treatment_outcomes (np.ndarray): Treatment outcome values to predict
85
+
86
+ Returns:
87
+ self: Trained model instance
88
+ """
89
+ X, feature_names = self.prepare_features(latents, demographics)
90
+ self.feature_names = feature_names
91
+
92
+ logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
93
+ self.model.fit(X, treatment_outcomes)
94
+
95
+ # Calculate feature importance
96
+ self.feature_importance = pd.DataFrame({
97
+ 'feature': feature_names,
98
+ 'importance': self.model.feature_importances_
99
+ }).sort_values('importance', ascending=False)
100
+
101
+ return self
102
+
103
+ def predict(self, latents, demographics):
104
+ """
105
+ Predict treatment outcomes for new patients
106
+
107
+ Args:
108
+ latents (np.ndarray): Latent representations from VAE
109
+ demographics (dict or pd.DataFrame): Demographic information
110
+
111
+ Returns:
112
+ tuple: Predictions and prediction uncertainty (std deviation)
113
+ """
114
+ X, _ = self.prepare_features(latents, demographics)
115
+ predictions = self.model.predict(X)
116
+
117
+ # Get prediction intervals using tree variance
118
+ if self.prediction_type == "regression":
119
+ tree_predictions = np.array([tree.predict(X)
120
+ for tree in self.model.estimators_])
121
+ prediction_std = np.std(tree_predictions, axis=0)
122
+ else: # classification
123
+ # For classification, use probability as a measure of confidence
124
+ proba = self.model.predict_proba(X)
125
+ # Use max probability as confidence measure
126
+ prediction_std = 1 - np.max(proba, axis=1)
127
+
128
+ return predictions, prediction_std
129
+
130
+ def predict_proba(self, latents, demographics):
131
+ """
132
+ Get probability estimates for classification
133
+
134
+ Args:
135
+ latents (np.ndarray): Latent representations from VAE
136
+ demographics (dict or pd.DataFrame): Demographic information
137
+
138
+ Returns:
139
+ np.ndarray: Probability estimates for each class
140
+ """
141
+ if self.prediction_type != "classification":
142
+ raise ValueError("Probability prediction only available for classification")
143
+
144
+ X, _ = self.prepare_features(latents, demographics)
145
+ return self.model.predict_proba(X)
146
+
147
+ def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
148
+ """
149
+ Perform cross-validation
150
+
151
+ Args:
152
+ latents (np.ndarray): Latent representations from VAE
153
+ demographics (dict or pd.DataFrame): Demographic information
154
+ treatment_outcomes (np.ndarray): Treatment outcome values to predict
155
+ n_splits (int): Number of folds for cross-validation
156
+
157
+ Returns:
158
+ dict: Cross-validation results
159
+ """
160
+ X, feature_names = self.prepare_features(latents, demographics)
161
+ self.feature_names = feature_names
162
+
163
+ logger.info(f"Running {n_splits}-fold cross-validation")
164
+
165
+ kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
166
+
167
+ cv_scores = []
168
+ predictions = np.zeros_like(treatment_outcomes)
169
+ prediction_stds = np.zeros_like(treatment_outcomes)
170
+ fold_metrics = []
171
+
172
+ for fold, (train_idx, test_idx) in enumerate(kf.split(X)):
173
+ X_train, X_test = X[train_idx], X[test_idx]
174
+ y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
175
+
176
+ # Clone the model for this fold
177
+ if self.prediction_type == "classification":
178
+ fold_model = RandomForestClassifier(
179
+ n_estimators=self.n_estimators,
180
+ max_depth=self.max_depth,
181
+ random_state=self.random_state
182
+ )
183
+ else:
184
+ fold_model = RandomForestRegressor(
185
+ n_estimators=self.n_estimators,
186
+ max_depth=self.max_depth,
187
+ random_state=self.random_state
188
+ )
189
+
190
+ # Train the model
191
+ fold_model.fit(X_train, y_train)
192
+
193
+ # Make predictions
194
+ pred = fold_model.predict(X_test)
195
+
196
+ # Store predictions
197
+ predictions[test_idx] = pred
198
+
199
+ # Calculate metrics
200
+ if self.prediction_type == "regression":
201
+ rmse = np.sqrt(mean_squared_error(y_test, pred))
202
+ r2 = r2_score(y_test, pred)
203
+ metrics = {
204
+ "r2": r2,
205
+ "rmse": rmse,
206
+ "mse": rmse**2
207
+ }
208
+
209
+ # Get prediction intervals using tree variance
210
+ tree_predictions = np.array([tree.predict(X_test)
211
+ for tree in fold_model.estimators_])
212
+ pred_std = np.std(tree_predictions, axis=0)
213
+ prediction_stds[test_idx] = pred_std
214
+
215
+ else: # classification
216
+ acc = accuracy_score(y_test, pred)
217
+ prec = precision_score(y_test, pred, average='weighted', zero_division=0)
218
+ rec = recall_score(y_test, pred, average='weighted', zero_division=0)
219
+ f1 = f1_score(y_test, pred, average='weighted', zero_division=0)
220
+ metrics = {
221
+ "accuracy": acc,
222
+ "precision": prec,
223
+ "recall": rec,
224
+ "f1": f1
225
+ }
226
+
227
+ # Use probability as a measure of confidence
228
+ proba = fold_model.predict_proba(X_test)
229
+ # Use max probability as confidence measure
230
+ pred_std = 1 - np.max(proba, axis=1)
231
+ prediction_stds[test_idx] = pred_std
232
+
233
+ fold_metrics.append(metrics)
234
+ logger.info(f"Fold {fold+1} metrics: {metrics}")
235
+
236
+ # Calculate average metrics
237
+ avg_metrics = {}
238
+ for key in fold_metrics[0].keys():
239
+ avg_metrics[key] = np.mean([fold[key] for fold in fold_metrics])
240
+
241
+ logger.info(f"Average CV metrics: {avg_metrics}")
242
+
243
+ # Train final model on all data
244
+ self.model.fit(X, treatment_outcomes)
245
+
246
+ # Calculate feature importance
247
+ self.feature_importance = pd.DataFrame({
248
+ 'feature': feature_names,
249
+ 'importance': self.model.feature_importances_
250
+ }).sort_values('importance', ascending=False)
251
+
252
+ return {
253
+ "mean_metrics": avg_metrics,
254
+ "fold_metrics": fold_metrics,
255
+ "predictions": predictions,
256
+ "prediction_stds": prediction_stds,
257
+ "feature_importance": self.feature_importance
258
+ }
259
+
260
+ def get_feature_importance(self):
261
+ """
262
+ Get feature importance from the trained model
263
+
264
+ Returns:
265
+ pd.DataFrame: Feature importance values
266
+ """
267
+ if self.feature_importance is None:
268
+ raise ValueError("Model must be trained first")
269
+
270
+ return self.feature_importance
271
+
272
+ def plot_feature_importance(self, top_n=10):
273
+ """
274
+ Plot feature importance
275
+
276
+ Args:
277
+ top_n (int): Number of top features to show
278
+
279
+ Returns:
280
+ matplotlib.figure.Figure: Feature importance plot
281
+ """
282
+ if self.feature_importance is None:
283
+ raise ValueError("Model must be trained first")
284
+
285
+ # Get top N features
286
+ top_features = self.feature_importance.head(top_n)
287
+
288
+ plt.figure(figsize=(10, 6))
289
+ plt.barh(range(len(top_features)),
290
+ top_features['importance'],
291
+ align='center')
292
+ plt.yticks(range(len(top_features)),
293
+ top_features['feature'])
294
+ plt.xlabel('Importance')
295
+ plt.ylabel('Features')
296
+ plt.title('Feature Importance in Treatment Outcome Prediction')
297
+ plt.tight_layout()
298
+ return plt.gcf()
299
+
300
+ def save_model(self, path="results/treatment_predictor.joblib"):
301
+ """
302
+ Save the trained model to disk
303
+
304
+ Args:
305
+ path (str): Path to save the model
306
+ """
307
+ # Create directory if it doesn't exist
308
+ os.makedirs(os.path.dirname(path), exist_ok=True)
309
+
310
+ # Save model and metadata
311
+ joblib.dump({
312
+ 'model': self.model,
313
+ 'feature_names': self.feature_names,
314
+ 'feature_importance': self.feature_importance,
315
+ 'prediction_type': self.prediction_type,
316
+ 'n_estimators': self.n_estimators,
317
+ 'max_depth': self.max_depth,
318
+ 'random_state': self.random_state
319
+ }, path)
320
+
321
+ logger.info(f"Model saved to {path}")
322
+
323
+ @classmethod
324
+ def load_model(cls, path="results/treatment_predictor.joblib"):
325
+ """
326
+ Load a trained model from disk
327
+
328
+ Args:
329
+ path (str): Path to load the model from
330
+
331
+ Returns:
332
+ AphasiaTreatmentPredictor: Loaded model instance
333
+ """
334
+ data = joblib.load(path)
335
+
336
+ # Create new instance
337
+ predictor = cls(
338
+ prediction_type=data['prediction_type'],
339
+ n_estimators=data['n_estimators'],
340
+ max_depth=data['max_depth'],
341
+ random_state=data['random_state']
342
+ )
343
+
344
+ # Restore model and metadata
345
+ predictor.model = data['model']
346
+ predictor.feature_names = data['feature_names']
347
+ predictor.feature_importance = data['feature_importance']
348
+
349
+ logger.info(f"Model loaded from {path}")
350
+ return predictor
351
+
352
+
353
+ def train_predictor_from_latents(latents, outcomes, demographics=None, prediction_type="regression", cv=5, **kwargs):
354
+ """
355
+ Train a treatment outcome predictor from VAE latent representations
356
+
357
+ Args:
358
+ latents (np.ndarray): Latent representations from VAE
359
+ outcomes (np.ndarray): Treatment outcome values
360
+ demographics (dict or pd.DataFrame, optional): Demographic information to include as features
361
+ prediction_type (str): "classification" or "regression"
362
+ cv (int): Number of folds for cross-validation
363
+ **kwargs: Additional parameters for the AphasiaTreatmentPredictor
364
+
365
+ Returns:
366
+ dict: Training results and trained model
367
+ """
368
+ logger.info(f"Training {prediction_type} model for treatment prediction")
369
+
370
+ # Create predictor
371
+ predictor = AphasiaTreatmentPredictor(prediction_type=prediction_type, **kwargs)
372
+
373
+ # Run cross-validation
374
+ cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
375
+
376
+ # Save the model
377
+ predictor.save_model()
378
+
379
+ return {
380
+ "predictor": predictor,
381
+ "cv_results": cv_results,
382
+ "feature_importance": predictor.get_feature_importance()
383
+ }
requirements.txt CHANGED
@@ -9,4 +9,6 @@ gradio>=2.0.0
9
  datasets>=1.11.0
10
  huggingface_hub>=0.15.0
11
  transformers>=4.15.0
 
 
12
 
 
9
  datasets>=1.11.0
10
  huggingface_hub>=0.15.0
11
  transformers>=4.15.0
12
+ seaborn>=0.11.2
13
+ joblib>=1.0.1
14
 
src/.DS_Store CHANGED
Binary files a/src/.DS_Store and b/src/.DS_Store differ
 
utils.py CHANGED
@@ -11,15 +11,6 @@ def to_cuda(x, use_cuda):
11
  def to_numpy(x):
12
  return x.detach().cpu().numpy()
13
 
14
- def fc_matrix_from_triu(triu_values, n_rois=264):
15
- fc_matrix = np.zeros((n_rois, n_rois))
16
- triu_indices = np.triu_indices(n_rois, k=1)
17
- triu_values = np.tanh(triu_values)
18
- fc_matrix[triu_indices] = triu_values
19
- fc_matrix = fc_matrix + fc_matrix.T
20
- np.fill_diagonal(fc_matrix, 1)
21
- return fc_matrix
22
-
23
  def rmse(a, b, mean=torch.mean):
24
  return mean((a-b)**2)**0.5
25
 
@@ -47,6 +38,7 @@ def decor_loss(z, demo, use_cuda=True):
47
  ps.append(p)
48
  losses = torch.stack(losses)
49
  return losses, ps
 
50
  def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
51
  demo_t = []
52
  demo_idx = 0
@@ -70,10 +62,13 @@ def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
70
  def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
71
  loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
72
  loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
 
73
  # Get linear predictors for demographics
74
  pred_w = []
75
  pred_i = []
76
  pred_stats = []
 
 
77
 
78
  for i, d, t in zip(range(len(demo)), demo, demo_types):
79
  print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
@@ -115,6 +110,9 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
115
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
116
 
117
  for e in range(nepochs):
 
 
 
118
  for bs in range(0, len(x), bsize):
119
  xb = x[bs:(bs+bsize)]
120
  db = demo_t[bs:(bs+bsize)]
@@ -128,59 +126,29 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
128
  loss_decor = sum(loss_decor)
129
  loss_rec = rmse(xb, y)
130
 
131
- # Sample demographics
132
- demo_gen = []
133
- for s, t in zip(pred_stats, demo_types):
134
- if t == 'continuous':
135
- mu, std = s
136
- dd = torch.randn(100).float()
137
- dd = dd*std+mu
138
- dd = to_cuda(dd, vae.use_cuda)
139
- demo_gen.append(dd)
140
- elif t == 'categorical':
141
- idx = np.random.randint(0, len(s))
142
- for i in range(len(s)):
143
- dd = torch.ones(100).float() if idx == i else torch.zeros(100).float()
144
- dd = to_cuda(dd, vae.use_cuda)
145
- demo_gen.append(dd)
146
-
147
- demo_gen = torch.stack(demo_gen).permute(1,0)
148
-
149
- # Generate
150
- z = vae.gen(100)
151
- y = vae.dec(z, demo_gen)
152
-
153
- # Regressor/classifier guidance loss
154
- losses_pred = []
155
- idcs = []
156
- dg_idx = 0
157
-
158
- for s, t in zip(pred_stats, demo_types):
159
- if t == 'continuous':
160
- yy = y@pred_w[dg_idx]+pred_i[dg_idx]
161
- loss = rmse(demo_gen[:,dg_idx], yy)
162
- losses_pred.append(loss)
163
- idcs.append(float(demo_gen[0,dg_idx]))
164
- dg_idx += 1
165
- elif t == 'categorical':
166
- loss = 0
167
- for i in range(len(s)):
168
- yy = y@pred_w[dg_idx]+pred_i[dg_idx]
169
- loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
170
- idcs.append(int(demo_gen[0,dg_idx]))
171
- dg_idx += 1
172
- losses_pred.append(loss)
173
-
174
  total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
175
- loss_rec_mult*loss_rec + loss_decor_mult*loss_decor +
176
- loss_pred_mult*sum(losses_pred))
177
 
178
  total_loss.backward()
179
  optim.step()
180
 
181
- if e%pperiod == 0 or e == nepochs-1:
182
- print(f'Epoch {e} ReconLoss {loss_rec:.4f} CovarianceLoss {loss_C:.4f} '
183
- f'MeanLoss {loss_mu:.4f} DecorLoss {loss_decor:.4f}')
184
- print(f'GuidanceTargets {idcs}')
185
- print(f'GuidanceLosses {[f"{loss:.4f}" for loss in losses_pred]}')
186
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
  def to_numpy(x):
12
  return x.detach().cpu().numpy()
13
 
 
 
 
 
 
 
 
 
 
14
  def rmse(a, b, mean=torch.mean):
15
  return mean((a-b)**2)**0.5
16
 
 
38
  ps.append(p)
39
  losses = torch.stack(losses)
40
  return losses, ps
41
+
42
  def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
43
  demo_t = []
44
  demo_idx = 0
 
62
  def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
63
  loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
64
  loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
65
+
66
  # Get linear predictors for demographics
67
  pred_w = []
68
  pred_i = []
69
  pred_stats = []
70
+ train_losses = []
71
+ val_losses = []
72
 
73
  for i, d, t in zip(range(len(demo)), demo, demo_types):
74
  print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
 
110
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
111
 
112
  for e in range(nepochs):
113
+ epoch_losses = []
114
+ vae.train()
115
+
116
  for bs in range(0, len(x), bsize):
117
  xb = x[bs:(bs+bsize)]
118
  db = demo_t[bs:(bs+bsize)]
 
126
  loss_decor = sum(loss_decor)
127
  loss_rec = rmse(xb, y)
128
 
129
+ # Calculate total loss
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
130
  total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
131
+ loss_rec_mult*loss_rec + loss_decor_mult*loss_decor)
 
132
 
133
  total_loss.backward()
134
  optim.step()
135
 
136
+ epoch_losses.append(total_loss.item())
137
+
138
+ # Record training loss
139
+ train_losses.append(np.mean(epoch_losses))
140
+
141
+ # Validation step
142
+ if e % pperiod == 0:
143
+ vae.eval()
144
+ with torch.no_grad():
145
+ z = vae.enc(x)
146
+ y = vae.dec(z, demo_t)
147
+ val_loss = rmse(x, y).item()
148
+ val_losses.append(val_loss)
149
+
150
+ print(f'Epoch {e}/{nepochs} - '
151
+ f'Train Loss: {train_losses[-1]:.4f} - '
152
+ f'Val Loss: {val_loss:.4f}')
153
+
154
+ return train_losses, val_losses
vae_model.py CHANGED
@@ -129,22 +129,3 @@ class DemoVAE(BaseEstimator):
129
  self.demo_dim = checkpoint['demo_dim']
130
  self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
131
  self.vae.load_state_dict(checkpoint['model_state_dict'])
132
-
133
- def train_fc_vae(X, demo_data, demo_types, model_config):
134
- n_rois = 264
135
- input_dim = (n_rois * (n_rois - 1)) // 2
136
-
137
- vae = DemoVAE(
138
- latent_dim=model_config['latent_dim'],
139
- nepochs=model_config['nepochs'],
140
- bsize=model_config['bsize'],
141
- loss_rec_mult=model_config['loss_rec_mult'],
142
- loss_decor_mult=model_config['loss_decor_mult'],
143
- lr=model_config['lr'],
144
- use_cuda=torch.cuda.is_available()
145
- )
146
-
147
- vae.fit(X, demo_data, demo_types)
148
-
149
- return vae, X, demo_data, demo_types
150
-
 
129
  self.demo_dim = checkpoint['demo_dim']
130
  self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
131
  self.vae.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
visualization.py CHANGED
@@ -1,44 +1,66 @@
1
  import matplotlib.pyplot as plt
2
  import numpy as np
3
- from utils import fc_matrix_from_triu
4
 
5
def visualize_fc_analysis(original_triu, reconstructed_triu, generated_triu, analysis_results=None):
    """Plot original/reconstructed/generated FC matrices and, optionally,
    latent-demographic correlation curves.

    Parameters
    ----------
    original_triu, reconstructed_triu, generated_triu : array-like
        Upper-triangle FC vectors; expanded to square matrices via
        fc_matrix_from_triu before display.
    analysis_results : dict or None
        Optional mapping demo_name -> {'p_values': ..., 'correlations': ...}.
        When given, a fourth panel plots each demographic's per-dimension
        correlations and counts dimensions with p < 0.05.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(15, 10))
    grid = plt.GridSpec(2, 3)

    panels = [
        ('Original FC', original_triu),
        ('Reconstructed FC', reconstructed_triu),
        ('Generated FC', generated_triu),
    ]
    for col, (title, triu_vec) in enumerate(panels):
        ax = fig.add_subplot(grid[0, col])
        image = ax.imshow(fc_matrix_from_triu(triu_vec), cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_title(title)
        plt.colorbar(image, ax=ax)

    if analysis_results is not None:
        corr_ax = fig.add_subplot(grid[1, :])
        for demo_name, results in analysis_results.items():
            # Count latent dimensions whose association is significant.
            n_significant = int(np.sum(np.array(results['p_values']) < 0.05))
            corr_ax.plot(np.array(results['correlations']),
                         label=f'{demo_name} (sig. dims: {n_significant})')
        corr_ax.set_xlabel('Latent Dimension')
        corr_ax.set_ylabel('Correlation Strength')
        corr_ax.set_title('Demographic Correlations with Latent Dimensions')
        corr_ax.legend()

    plt.tight_layout()
    return fig
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import matplotlib.pyplot as plt
2
  import numpy as np
 
3
 
4
def plot_fc_matrices(original, reconstructed, generated):
    """Plot original, reconstructed, and generated FC matrices side by side.

    All three panels share a fixed [-1, 1] color scale so correlation
    strengths are visually comparable across panels.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    titles = ('Original FC', 'Reconstructed FC', 'Generated FC')
    matrices = (original, reconstructed, generated)

    for ax, title, matrix in zip(axes, titles, matrices):
        image = ax.imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_title(title)
        plt.colorbar(image, ax=ax)

    plt.tight_layout()
    return fig
24
 
25
def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke, prediction_std=None):
    """Plot the predicted WAB-score trajectory from now to a post-treatment time point.

    Parameters
    ----------
    current_score : float
        WAB score at the current assessment (plotted at month 0).
    predicted_score : float
        Predicted WAB score at ``months_post_stroke`` months post treatment.
    months_post_stroke : float
        Horizon (in months) of the prediction.
    prediction_std : float or None
        Standard deviation of the prediction; when given, a +/-2 sigma
        (~95%) interval is drawn around the predicted point.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(10, 6))

    # Current and predicted points
    plt.scatter([0], [current_score], label='Current Status', color='blue', s=100)
    plt.scatter([months_post_stroke], [predicted_score],
                label='Predicted Outcome', color='red', s=100)

    # Straight-line trajectory between the two points
    plt.plot([0, months_post_stroke], [current_score, predicted_score],
             'g--', label='Predicted Trajectory')

    # Prediction interval. FIX: the previous fill_between call was given a
    # single x value, producing a zero-width (invisible) band; an errorbar
    # at the predicted point actually renders the +/-2 sigma interval.
    if prediction_std is not None:
        plt.errorbar([months_post_stroke], [predicted_score],
                     yerr=2 * prediction_std, color='red', alpha=0.6,
                     capsize=8, fmt='none',
                     label='95% Prediction Interval')

    plt.xlabel('Months Post Treatment')
    plt.ylabel('WAB Score')
    plt.title('Predicted Treatment Trajectory')
    plt.legend()
    plt.grid(True)

    return fig
53
+
54
def plot_learning_curves(train_losses, val_losses):
    """Plot VAE training and validation loss per epoch.

    Parameters
    ----------
    train_losses, val_losses : sequence of float
        Per-epoch loss values.

    Returns
    -------
    matplotlib.figure.Figure
    """
    fig = plt.figure(figsize=(10, 6))

    for losses, curve_name in ((train_losses, 'Training Loss'),
                               (val_losses, 'Validation Loss')):
        plt.plot(losses, label=curve_name)

    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.title('VAE Learning Curves')
    plt.legend()
    plt.grid(True)

    return fig