SreekarB committed on
Commit 1c47445 · verified · 1 Parent(s): dbe81c1

Upload 3 files

Files changed (3):
  1. app.py +141 -655
  2. data_preprocessing.py +578 -78
  3. main.py +271 -126
app.py CHANGED
@@ -1,518 +1,10 @@
1
  import gradio as gr
2
- from main import run_analysis
3
- from rcf_prediction import AphasiaTreatmentPredictor
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
7
- from visualization import plot_fc_matrices, plot_learning_curves
8
  import os
9
- from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, f1_score
 
10
  import json
11
  import pickle
12
- import pandas as pd
13
- import seaborn as sns
14
- import logging
15
- from config import MODEL_CONFIG, PREDICTION_CONFIG
16
-
17
- logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
18
- logger = logging.getLogger(__name__)
19
-
20
- class AphasiaPredictionApp:
21
- def __init__(self):
22
- self.vae = None
23
- self.predictor = None
24
- self.trained = False
25
- self.latent_dim = MODEL_CONFIG['latent_dim']
26
-
27
- def train_models(self, data_dir, latent_dim, nepochs, bsize):
28
- """
29
- Train VAE and Random Forest models
30
- """
31
- # Train VAE and Random Forest
32
- logger.info(f"Training models with data from {data_dir}")
33
- logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
34
-
35
- # Default prediction parameters from our config
36
- prediction_type = PREDICTION_CONFIG.get('prediction_type', 'regression')
37
- outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq')
38
- logger.info(f"Prediction: type={prediction_type}, outcome={outcome_variable}")
39
-
40
- figures = {}
41
-
42
- try:
43
- # Run the full analysis pipeline
44
- results = run_analysis(
45
- data_dir=data_dir,
46
- demographic_file="demographics.csv",
47
- treatment_file="treatment_outcomes.csv",
48
- latent_dim=latent_dim,
49
- nepochs=nepochs,
50
- bsize=bsize,
51
- save_model=True
52
- )
53
-
54
- # Get the VAE figure from results
55
- vae_fig = results.get('figures', {}).get('vae')
56
-
57
- figures['vae'] = vae_fig
58
-
59
- if results:
60
- self.vae = results.get('vae')
61
- self.predictor = results.get('predictor')
62
- latents = results.get('latents')
63
- demographics = results.get('demographics')
64
- predictor_cv_results = results.get('predictor_cv_results')
65
-
66
- # Store the latent dimension
67
- self.latent_dim = latent_dim
68
-
69
- # Mark models as trained
70
- self.trained = True
71
-
72
- # Prepare prediction visualization if available
73
- if self.predictor and predictor_cv_results:
74
- # Get the outcome variable data
75
- if outcome_variable == 'wab_aq':
76
- outcomes = demographics['wab_aq']
77
- elif outcome_variable == 'age':
78
- outcomes = demographics['age']
79
- elif outcome_variable == 'months_post_onset':
80
- outcomes = demographics['months_post_onset']
81
- else:
82
- # Try to find the outcome in demographics data
83
- outcomes = None
84
- for key in demographics:
85
- if outcome_variable.lower() in key.lower():
86
- outcomes = demographics[key]
87
- break
88
-
89
- # Create plots
90
- if 'prediction_stds' in predictor_cv_results and 'predictions' in predictor_cv_results:
91
- # Create prediction plots
92
- prediction_fig = self.create_prediction_plots(
93
- latents,
94
- demographics,
95
- outcomes,
96
- predictor_cv_results['predictions'],
97
- predictor_cv_results['prediction_stds']
98
- )
99
- figures['prediction'] = prediction_fig
100
-
101
- # Create feature importance plot if available
102
- try:
103
- feature_importance = self.predictor.get_feature_importance()
104
- if feature_importance is not None:
105
- importance_fig = self.create_importance_plot(feature_importance)
106
- figures['importance'] = importance_fig
107
- except Exception as e:
108
- logger.warning(f"Could not create feature importance plot: {e}")
109
-
110
- logger.info("Training completed successfully")
111
-
112
- # Create learning curve plots if available
113
- if 'fold_metrics' in predictor_cv_results:
114
- learning_fig = self.create_learning_curve_plot(
115
- predictor_cv_results['fold_metrics']
116
- )
117
- figures['learning'] = learning_fig
118
-
119
- except Exception as e:
120
- logger.error(f"Error in training: {str(e)}")
121
- error_fig = plt.figure(figsize=(10, 6))
122
- plt.text(0.5, 0.5, f"Error: {str(e)}",
123
- horizontalalignment='center', verticalalignment='center',
124
- fontsize=12, color='red')
125
- plt.axis('off')
126
- figures['error'] = error_fig
127
-
128
- return figures
129
-
130
- def predict_treatment(self, fmri_file=None, age=50, sex="M",
131
- months_post_stroke=12, wab_score=50, fc_matrix=None):
132
- """
133
- Predict treatment outcome for a patient
134
-
135
- Args:
136
- fmri_file: Path to patient's fMRI file
137
- age: Patient's age at stroke
138
- sex: Patient's sex (M/F)
139
- months_post_stroke: Months since stroke
140
- wab_score: Current WAB score
141
- fc_matrix: Pre-processed FC matrix (if fMRI file not provided)
142
-
143
- Returns:
144
- Prediction results and visualization
145
- """
146
- if not self.trained:
147
- return "Please train the models first!", None
148
-
149
- try:
150
- # Process fMRI to FC matrix if provided
151
- if fmri_file and not fc_matrix:
152
- logger.info(f"Processing fMRI file: {fmri_file}")
153
- # Use the single fMRI processing function
154
- fc_matrix = process_single_fmri(fmri_file)
155
-
156
- if fc_matrix is None:
157
- return "Please provide either an fMRI file or an FC matrix", None
158
-
159
- # Ensure FC matrix is properly shaped
160
- if isinstance(fc_matrix, list):
161
- fc_matrix = np.array(fc_matrix)
162
-
163
- # Get latent representation
164
- logger.info("Extracting latent representation from FC matrix")
165
- if len(fc_matrix.shape) == 2: # If matrix is 2D (e.g., 264x264)
166
- # Convert to flattened upper triangular form
167
- n = fc_matrix.shape[0]
168
- indices = np.triu_indices(n, k=1)
169
- fc_flattened = fc_matrix[indices]
170
- fc_flattened = fc_flattened.reshape(1, -1)
171
- latent = self.vae.get_latents(fc_flattened)
172
- else:
173
- # Assume already flattened
174
- latent = self.vae.get_latents(fc_matrix.reshape(1, -1))
175
-
176
- # Prepare demographics
177
- demographics = {
178
- 'age': np.array([float(age)]),
179
- 'gender': np.array([sex]),
180
- 'months_post_onset': np.array([float(months_post_stroke)]),
181
- 'wab_aq': np.array([float(wab_score)])
182
- }
183
-
184
- logger.info("Making prediction")
185
- # Make prediction
186
- if self.predictor is None:
187
- return "Predictor model not trained", None
188
-
189
- # Make prediction using the model's predict method
190
- prediction, prediction_std = self.predictor.predict(latent, demographics)
191
-
192
- # Create visualization
193
- fig = self.plot_treatment_trajectory(
194
- current_score=wab_score,
195
- predicted_score=prediction[0],
196
- months_post_stroke=months_post_stroke,
197
- prediction_std=prediction_std[0]
198
- )
199
-
200
- result_text = f"Predicted treatment outcome: {prediction[0]:.2f} ± {2*prediction_std[0]:.2f}"
201
- logger.info(result_text)
202
-
203
- return result_text, fig
204
-
205
- except Exception as e:
206
- error_msg = f"Error in prediction: {str(e)}"
207
- logger.error(error_msg)
208
- error_fig = plt.figure(figsize=(10, 6))
209
- plt.text(0.5, 0.5, error_msg,
210
- horizontalalignment='center', verticalalignment='center',
211
- fontsize=12, color='red')
212
- plt.axis('off')
213
- return error_msg, error_fig
214
-
215
- def plot_treatment_trajectory(self, current_score, predicted_score,
216
- months_post_stroke, prediction_std,
217
- treatment_duration=6):
218
- """
219
- Create a visualization of predicted treatment trajectory
220
-
221
- Args:
222
- current_score: Current WAB score
223
- predicted_score: Predicted WAB score after treatment
224
- months_post_stroke: Current months post stroke
225
- prediction_std: Standard deviation of prediction
226
- treatment_duration: Duration of treatment in months
227
-
228
- Returns:
229
- matplotlib figure
230
- """
231
- fig = plt.figure(figsize=(10, 6))
232
-
233
- # X-axis: months
234
- x = np.array([months_post_stroke, months_post_stroke + treatment_duration])
235
-
236
- # Y-axis: WAB scores
237
- y = np.array([current_score, predicted_score])
238
-
239
- # Plot the trajectory
240
- plt.plot(x, y, 'bo-', linewidth=2, label='Predicted Trajectory')
241
-
242
- # Add confidence interval
243
- plt.fill_between(
244
- x,
245
- [y[0], y[1] - 2*prediction_std],
246
- [y[0], y[1] + 2*prediction_std],
247
- alpha=0.2, color='blue', label='95% Confidence Interval'
248
- )
249
-
250
- # Add reference lines
251
- if current_score < predicted_score:
252
- improvement = predicted_score - current_score
253
- plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5,
254
- label=f'Current WAB = {current_score:.1f}')
255
- plt.axhline(y=predicted_score, color='g', linestyle='--', alpha=0.5,
256
- label=f'Predicted WAB = {predicted_score:.1f} (+{improvement:.1f})')
257
- else:
258
- decline = current_score - predicted_score
259
- plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5,
260
- label=f'Current WAB = {current_score:.1f}')
261
- plt.axhline(y=predicted_score, color='orange', linestyle='--', alpha=0.5,
262
- label=f'Predicted WAB = {predicted_score:.1f} (-{decline:.1f})')
263
-
264
- # Add labels and title
265
- plt.xlabel('Months Post Stroke')
266
- plt.ylabel('WAB Score')
267
- plt.title('Predicted Treatment Trajectory')
268
- plt.legend(loc='best')
269
-
270
- # Set y-axis limits
271
- plt.ylim([0, 100])
272
-
273
- plt.tight_layout()
274
- return fig
275
-
276
- def create_prediction_plots(self, latents, demographics, y_true, y_pred, y_std):
277
- """Create prediction performance plots"""
278
- fig = plt.figure(figsize=(12, 8))
279
-
280
- # Create a 2x2 grid for plots
281
- gs = plt.GridSpec(2, 2, figure=fig)
282
-
283
- # Plot predicted vs actual values
284
- ax1 = fig.add_subplot(gs[0, 0])
285
-
286
- if self.predictor.prediction_type == 'regression':
287
- # Regression: scatter plot
288
- ax1.scatter(y_true, y_pred, alpha=0.7)
289
-
290
- # Add perfect prediction line
291
- min_val = min(np.min(y_true), np.min(y_pred))
292
- max_val = max(np.max(y_true), np.max(y_pred))
293
- ax1.plot([min_val, max_val], [min_val, max_val], 'r--')
294
-
295
- ax1.set_xlabel('Actual Values')
296
- ax1.set_ylabel('Predicted Values')
297
- ax1.set_title('Predicted vs. Actual Values')
298
-
299
- # Add R² to the plot
300
- r2 = r2_score(y_true, y_pred)
301
- ax1.text(0.05, 0.95, f'R² = {r2:.4f}', transform=ax1.transAxes,
302
- bbox=dict(facecolor='white', alpha=0.5))
303
-
304
- # Plot residuals
305
- ax2 = fig.add_subplot(gs[0, 1])
306
- residuals = y_true - y_pred
307
- ax2.scatter(y_pred, residuals, alpha=0.7)
308
- ax2.axhline(y=0, color='r', linestyle='--')
309
- ax2.set_xlabel('Predicted Values')
310
- ax2.set_ylabel('Residuals')
311
- ax2.set_title('Residual Plot')
312
-
313
- # Plot prediction errors
314
- ax3 = fig.add_subplot(gs[1, 0])
315
- ax3.errorbar(range(len(y_pred)), y_pred, yerr=2*y_std, fmt='o', alpha=0.7,
316
- label='Predicted ± 2σ')
317
- ax3.plot(range(len(y_true)), y_true, 'rx', alpha=0.7, label='Actual')
318
- ax3.set_xlabel('Sample Index')
319
- ax3.set_ylabel('Value')
320
- ax3.set_title('Prediction with Error Bars')
321
- ax3.legend()
322
-
323
- # Plot error distribution
324
- ax4 = fig.add_subplot(gs[1, 1])
325
- ax4.hist(residuals, bins=20, alpha=0.7)
326
- ax4.axvline(x=0, color='r', linestyle='--')
327
- ax4.set_xlabel('Prediction Error')
328
- ax4.set_ylabel('Frequency')
329
- ax4.set_title('Error Distribution')
330
-
331
- else: # classification
332
- # Convert to integer classes if they're strings
333
- if isinstance(y_true[0], str) or isinstance(y_pred[0], str):
334
- # Create mapping of class labels to integers
335
- classes = sorted(list(set(list(y_true) + list(y_pred))))
336
- class_to_int = {c: i for i, c in enumerate(classes)}
337
-
338
- y_true_int = np.array([class_to_int[c] for c in y_true])
339
- y_pred_int = np.array([class_to_int[c] for c in y_pred])
340
- else:
341
- y_true_int = y_true
342
- y_pred_int = y_pred
343
- classes = sorted(list(set(list(y_true_int) + list(y_pred_int))))
344
-
345
- # Confusion matrix
346
- from sklearn.metrics import confusion_matrix
347
- cm = confusion_matrix(y_true_int, y_pred_int)
348
-
349
- # Plot confusion matrix
350
- sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes,
351
- yticklabels=classes, ax=ax1)
352
- ax1.set_xlabel('Predicted')
353
- ax1.set_ylabel('True')
354
- ax1.set_title('Confusion Matrix')
355
-
356
- # Class distribution
357
- ax2 = fig.add_subplot(gs[0, 1])
358
- unique_classes, true_counts = np.unique(y_true_int, return_counts=True)
359
- unique_classes, pred_counts = np.unique(y_pred_int, return_counts=True)
360
-
361
- # Create class distribution DataFrame
362
- class_dist = pd.DataFrame({
363
- 'Class': classes,
364
- 'True': 0,
365
- 'Predicted': 0
366
- })
367
-
368
- for c, count in zip(unique_classes, true_counts):
369
- class_dist.loc[class_dist['Class'] == classes[c], 'True'] = count
370
-
371
- for c, count in zip(unique_classes, pred_counts):
372
- class_dist.loc[class_dist['Class'] == classes[c], 'Predicted'] = count
373
-
374
- # Plot class distribution
375
- ax2.bar(class_dist['Class'].astype(str), class_dist['True'], label='True', alpha=0.7)
376
- ax2.bar(class_dist['Class'].astype(str), class_dist['Predicted'], label='Predicted', alpha=0.5)
377
- ax2.set_xlabel('Class')
378
- ax2.set_ylabel('Count')
379
- ax2.set_title('Class Distribution')
380
- ax2.legend()
381
-
382
- # Performance metrics
383
- ax3 = fig.add_subplot(gs[1, 0])
384
- ax3.axis('off')
385
-
386
- # Calculate metrics
387
- from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
388
- acc = accuracy_score(y_true_int, y_pred_int)
389
- prec = precision_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
390
- rec = recall_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
391
- f1 = f1_score(y_true_int, y_pred_int, average='weighted', zero_division=0)
392
-
393
- metrics_text = (
394
- f"Classification Metrics:\n\n"
395
- f"Accuracy: {acc:.4f}\n"
396
- f"Precision: {prec:.4f}\n"
397
- f"Recall: {rec:.4f}\n"
398
- f"F1 Score: {f1:.4f}"
399
- )
400
-
401
- ax3.text(0.5, 0.5, metrics_text, ha='center', va='center', fontsize=12)
402
-
403
- # Confidence distribution
404
- ax4 = fig.add_subplot(gs[1, 1])
405
- ax4.hist(1 - y_std, bins=20, alpha=0.7)
406
- ax4.set_xlabel('Prediction Confidence')
407
- ax4.set_ylabel('Frequency')
408
- ax4.set_title('Confidence Distribution')
409
-
410
- plt.tight_layout()
411
- return fig
412
-
413
- def create_importance_plot(self, feature_importance, top_n=15):
414
- """Create feature importance plot"""
415
- # If feature_importance is a DataFrame, use it directly
416
- if isinstance(feature_importance, pd.DataFrame):
417
- importance_df = feature_importance
418
- else:
419
- # Create DataFrame
420
- importance_df = pd.DataFrame({
421
- 'feature': [f'Feature {i}' for i in range(len(feature_importance))],
422
- 'importance': feature_importance
423
- })
424
-
425
- # Get top N features
426
- top_features = importance_df.sort_values('importance', ascending=False).head(top_n)
427
-
428
- # Create plot
429
- fig = plt.figure(figsize=(10, 6))
430
- plt.barh(range(len(top_features)), top_features['importance'], align='center')
431
- plt.yticks(range(len(top_features)), top_features['feature'])
432
- plt.xlabel('Importance')
433
- plt.ylabel('Features')
434
- plt.title(f'Top {top_n} Features by Importance')
435
- plt.tight_layout()
436
-
437
- return fig
438
-
439
- def create_learning_curve_plot(self, fold_metrics):
440
- """Create learning curve plots from cross-validation results"""
441
- fig = plt.figure(figsize=(12, 6))
442
-
443
- # Create a grid for plots
444
- if self.predictor.prediction_type == 'regression':
445
- # For regression, show R² and RMSE
446
- ax1 = plt.subplot(1, 2, 1)
447
- ax2 = plt.subplot(1, 2, 2)
448
-
449
- # Plot R² for each fold
450
- for i, metrics in enumerate(fold_metrics):
451
- ax1.plot(i+1, metrics['r2'], 'bo')
452
-
453
- # Plot average R²
454
- avg_r2 = np.mean([m['r2'] for m in fold_metrics])
455
- ax1.axhline(y=avg_r2, color='r', linestyle='--',
456
- label=f'Average R² = {avg_r2:.4f}')
457
-
458
- ax1.set_xlabel('Fold')
459
- ax1.set_ylabel('R²')
460
- ax1.set_title('R² by Fold')
461
- ax1.set_xticks(range(1, len(fold_metrics)+1))
462
- ax1.legend()
463
-
464
- # Plot RMSE for each fold
465
- for i, metrics in enumerate(fold_metrics):
466
- ax2.plot(i+1, metrics['rmse'], 'go')
467
-
468
- # Plot average RMSE
469
- avg_rmse = np.mean([m['rmse'] for m in fold_metrics])
470
- ax2.axhline(y=avg_rmse, color='r', linestyle='--',
471
- label=f'Average RMSE = {avg_rmse:.4f}')
472
-
473
- ax2.set_xlabel('Fold')
474
- ax2.set_ylabel('RMSE')
475
- ax2.set_title('RMSE by Fold')
476
- ax2.set_xticks(range(1, len(fold_metrics)+1))
477
- ax2.legend()
478
-
479
- else: # classification
480
- # For classification, show accuracy and F1
481
- ax1 = plt.subplot(1, 2, 1)
482
- ax2 = plt.subplot(1, 2, 2)
483
-
484
- # Plot accuracy for each fold
485
- for i, metrics in enumerate(fold_metrics):
486
- ax1.plot(i+1, metrics['accuracy'], 'bo')
487
-
488
- # Plot average accuracy
489
- avg_acc = np.mean([m['accuracy'] for m in fold_metrics])
490
- ax1.axhline(y=avg_acc, color='r', linestyle='--',
491
- label=f'Average Accuracy = {avg_acc:.4f}')
492
-
493
- ax1.set_xlabel('Fold')
494
- ax1.set_ylabel('Accuracy')
495
- ax1.set_title('Accuracy by Fold')
496
- ax1.set_xticks(range(1, len(fold_metrics)+1))
497
- ax1.legend()
498
-
499
- # Plot F1 for each fold
500
- for i, metrics in enumerate(fold_metrics):
501
- ax2.plot(i+1, metrics['f1'], 'go')
502
-
503
- # Plot average F1
504
- avg_f1 = np.mean([m['f1'] for m in fold_metrics])
505
- ax2.axhline(y=avg_f1, color='r', linestyle='--',
506
- label=f'Average F1 = {avg_f1:.4f}')
507
-
508
- ax2.set_xlabel('Fold')
509
- ax2.set_ylabel('F1 Score')
510
- ax2.set_title('F1 Score by Fold')
511
- ax2.set_xticks(range(1, len(fold_metrics)+1))
512
- ax2.legend()
513
-
514
- plt.tight_layout()
515
- return fig
516
 
517
  def calculate_fc_accuracy(original_fc, reconstructed_fc):
518
  """
@@ -576,169 +68,163 @@ def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
576
 
577
  return os.path.join('results', file_path)
578
 
579
- # Make sure directory exists for saving results
580
- os.makedirs('results', exist_ok=True)
581
 
582
  def create_interface():
583
- """Create the Gradio interface"""
584
- app = AphasiaPredictionApp()
585
-
586
- with gr.Blocks(title="Aphasia Treatment Trajectory Prediction") as interface:
587
- gr.Markdown("# Aphasia Treatment Trajectory Prediction")
 
588
 
589
- with gr.Tabs():
590
- # Training Tab
591
- with gr.Tab("Train Models"):
592
- with gr.Row():
593
- with gr.Column(scale=1):
594
- data_dir = gr.Textbox(
595
- label="Data Directory",
596
- value="SreekarB/OSFData"
597
- )
598
- latent_dim = gr.Slider(
599
- minimum=8, maximum=64, step=8,
600
- label="Latent Dimensions", value=32
601
- )
602
- nepochs = gr.Slider(
603
- minimum=100, maximum=5000, step=100,
604
- label="Number of Epochs", value=200 # Reduced for faster demos
605
- )
606
-
607
- with gr.Column(scale=1):
608
- bsize = gr.Slider(
609
- minimum=8, maximum=64, step=8,
610
- label="Batch Size", value=16
611
- )
612
- use_hf_dataset = gr.Checkbox(
613
- label="Use HuggingFace Dataset", value=True
614
- )
615
- with gr.Group("Prediction Options"):
616
- prediction_type = gr.Radio(
617
- label="Prediction Type",
618
- choices=["regression", "classification"],
619
- value="regression"
620
- )
621
- outcome_variable = gr.Dropdown(
622
- label="Outcome Variable",
623
- choices=["wab_aq", "age", "months_post_onset"],
624
- value="wab_aq"
625
- )
626
-
627
- train_btn = gr.Button("Train Models", variant="primary")
628
-
629
- with gr.Row():
630
- fc_plot = gr.Plot(label="FC Analysis")
631
 
632
- with gr.Row():
633
- with gr.Column(scale=1):
634
- importance_plot = gr.Plot(label="Feature Importance")
635
- with gr.Column(scale=1):
636
- prediction_plot = gr.Plot(label="Prediction Performance")
637
 
638
- with gr.Row():
639
- learning_plot = gr.Plot(label="Cross-validation Results")
640
-
641
- # Prediction Tab
642
- with gr.Tab("Predict Treatment"):
643
- with gr.Row():
644
- with gr.Column(scale=1):
645
- fmri_file = gr.File(label="Patient fMRI Data")
646
- with gr.Column(scale=1):
647
- with gr.Group("Patient Demographics"):
648
- age = gr.Number(label="Age at Stroke", value=60)
649
- sex = gr.Dropdown(choices=["M", "F"], label="Sex", value="M")
650
- months = gr.Number(label="Months Post Stroke", value=12)
651
- wab = gr.Number(label="Current WAB Score", value=50)
652
-
653
- predict_btn = gr.Button("Predict Treatment Outcome", variant="primary")
654
-
655
- with gr.Row():
656
- prediction_text = gr.Textbox(label="Prediction Result")
657
-
658
- with gr.Row():
659
- trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory")
660
-
661
- # Connect components
662
- train_outputs = {
663
- 'vae': fc_plot,
664
- 'importance': importance_plot,
665
- 'prediction': prediction_plot,
666
- 'learning': learning_plot
667
- }
668
-
669
- # Handle train button click
670
- def handle_train(data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
671
- prediction_type, outcome_variable):
672
- # Ensure we have the necessary files before training
673
- # This is a placeholder - in a real app you'd validate these files exist
674
- demographic_file = os.path.join(data_dir, "demographics.csv")
675
- treatment_file = os.path.join(data_dir, "treatment_outcomes.csv")
676
-
677
- results = app.train_models(
678
- data_dir=data_dir,
679
- latent_dim=latent_dim,
680
- nepochs=nepochs,
681
- bsize=bsize
682
- )
683
-
684
- # Return plots in the expected order
685
- return [
686
- results.get('vae', None),
687
- results.get('importance', None),
688
- results.get('prediction', None),
689
- results.get('learning', None)
690
- ]
691
-
692
- train_btn.click(
693
- fn=handle_train,
694
- inputs=[data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
695
- prediction_type, outcome_variable],
696
- outputs=[fc_plot, importance_plot, prediction_plot, learning_plot]
697
  )
698
 
699
- predict_btn.click(
700
- fn=app.predict_treatment,
701
- inputs=[fmri_file, age, sex, months, wab],
702
- outputs=[prediction_text, trajectory_plot]
703
  )
704
 
705
  # Add examples
706
  gr.Examples(
707
  examples=[
708
- ["SreekarB/OSFData", 32, 200, 16, True, "regression", "wab_aq"], # Standard training
709
- ["SreekarB/OSFData", 16, 100, 8, True, "classification", "wab_aq"] # Faster training with classification
710
  ],
711
- inputs=[data_dir, latent_dim, nepochs, bsize, use_hf_dataset,
712
- prediction_type, outcome_variable],
713
  )
714
 
715
- # Add explanation
716
  gr.Markdown("""
717
- ## How to use this tool
718
-
719
- 1. **Train Models Tab**: First train the VAE and Random Forest models using your dataset
720
- - Use the default SreekarB/OSFData dataset or specify your own data source
721
- - Adjust parameters like latent dimensions and training epochs
722
- - Choose regression or classification prediction type
723
- - Select which variable to predict (WAB score by default)
724
-
725
- 2. **Predict Treatment Tab**: Use the trained models to predict treatment outcomes
726
- - Upload a patient's fMRI scan or use synthetic data
727
- - Enter the patient's demographic information
728
- - Click "Predict Treatment Outcome" to see the projected treatment trajectory
729
- - The visualization shows the predicted outcome with confidence intervals
730
-
731
- ## Interpreting Results
732
 
733
- - The **Feature Importance** plot shows which latent dimensions and demographic variables most strongly predict treatment outcomes
734
- - The **Prediction Performance** plot shows how well the model predicts known outcomes
735
- - The **Treatment Trajectory** shows the projected change in WAB score over the course of treatment
 
 
736
 
737
- Note: For optimal results, train with at least 500 epochs and latent dimension of 32 or higher.
738
  """)
739
 
740
- return interface
741
 
742
  if __name__ == "__main__":
743
- interface = create_interface()
744
- interface.launch(share=True)
 
 
1
  import gradio as gr
2
+ from main import run_fc_analysis
3
  import os
4
+ import numpy as np
5
+ from sklearn.metrics import mean_squared_error, r2_score
6
  import json
7
  import pickle
8
 
9
  def calculate_fc_accuracy(original_fc, reconstructed_fc):
10
  """
 
68
 
69
  return os.path.join('results', file_path)
70
 
71
+ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
72
+ """Run the full VAE analysis pipeline with accuracy metrics"""
73
+ # Run the original analysis
74
+ fig, results = run_fc_analysis(
75
+ data_dir=data_source,
76
+ demographic_file=None, # We're now getting demographics directly from the dataset
77
+ latent_dim=latent_dim,
78
+ nepochs=nepochs,
79
+ bsize=bsize,
80
+ save_model=True,
81
+ use_hf_dataset=use_hf_dataset,
82
+ return_data=True # New parameter to return data, will need to update main.py
83
+ )
84
+
85
+ if results:
86
+ vae = results.get('vae')
87
+ X = results.get('X')
88
+ latents = results.get('latents')
89
+ demographics = results.get('demographics')
90
+ reconstructed_fc = results.get('reconstructed_fc')
91
+ generated_fc = results.get('generated_fc')
92
+
93
+ # Calculate accuracy metrics
94
+ accuracy_metrics = {}
95
+ if X is not None and reconstructed_fc is not None:
96
+ for i in range(min(5, len(X))): # Calculate for up to 5 samples
97
+ metrics = calculate_fc_accuracy(X[i], reconstructed_fc[i])
98
+ accuracy_metrics[f"Subject_{i+1}"] = metrics
99
+
100
+ # Average metrics across subjects
101
+ avg_metrics = {}
102
+ for metric in ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]:
103
+ avg_metrics[metric] = np.mean([subject_metrics[metric]
104
+ for subject_metrics in accuracy_metrics.values()])
105
+ accuracy_metrics["Average"] = avg_metrics
106
+
107
+ # Save latent representations if available
108
+ if latents is not None and demographics is not None:
109
+ latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
110
+ print(f"Saved latents to {latents_path}")
111
+
112
+ # Prepare status message with accuracy metrics
113
+ if accuracy_metrics:
114
+ avg = accuracy_metrics["Average"]
115
+ status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
116
+ f"Reconstruction Accuracy Metrics (Average):\n"
117
+ f"• MSE: {avg['MSE']:.6f}\n"
118
+ f"• RMSE: {avg['RMSE']:.6f}\n"
119
+ f"• R²: {avg['R²']:.6f}\n"
120
+ f"• Correlation: {avg['Correlation']:.6f}\n"
121
+ f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n"
122
+ f"Latent representations saved to results/latents_dim{latent_dim}.pkl")
123
+ else:
124
+ status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."
125
+ else:
126
+ status = "Analysis complete, but no results were returned for accuracy calculation."
127
+
128
+ return fig, status
129
 
130
  def create_interface():
131
+ with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
132
+ gr.Markdown("""
133
+ # Aphasia fMRI to FC Analysis using VAE
134
+
135
+ This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
136
+
137
+ ## Dataset Information
138
+ By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
139
+ - ID: Subject identifier
140
+ - wab_aq: Aphasia severity score
141
+ - age: Age of the subject
142
+ - mpo: Months post onset
143
+ - education: Years of education
144
+ - gender: Subject gender
145
+ - handedness: Subject handedness (ignored in the analysis)
146
+ """)
147
 
148
+ with gr.Row():
149
+ with gr.Column(scale=1):
150
+ # Configuration parameters
151
+ data_source = gr.Textbox(
152
+ label="Data Source (HF Dataset ID or Local Directory)",
153
+ value="SreekarB/OSFData"
154
+ )
155
+ latent_dim = gr.Slider(
156
+ minimum=8, maximum=64, step=8,
157
+ label="Latent Dimensions", value=32
158
+ )
159
+ nepochs = gr.Slider(
160
+ minimum=100, maximum=5000, step=100,
161
+ label="Number of Epochs", value=200 # Reduced for faster demos
162
+ )
163
+ bsize = gr.Slider(
164
+ minimum=8, maximum=64, step=8,
165
+ label="Batch Size", value=16
166
+ )
167
+ use_hf_dataset = gr.Checkbox(
168
+ label="Use HuggingFace Dataset", value=True
169
+ )
 
170
 
171
+ # Training button
172
+ train_button = gr.Button("Start Training", variant="primary")
173
+ status_text = gr.Textbox(label="Status", value="Ready to start training")
174
 
175
+ with gr.Column(scale=2):
176
+ # Output plot
177
+ output_plot = gr.Plot(label="Analysis Results")
178
+ accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")
179
+
180
+ # Link the training button to the analysis function
181
+ train_button.click(
182
+ fn=gradio_fc_analysis,
183
+ inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
184
+ outputs=[output_plot, status_text]
185
  )
186
 
187
+ # Custom function to update the accuracy box
188
+ def update_accuracy_display(status_text):
189
+ if "Accuracy Metrics" in status_text:
190
+ # Extract the accuracy metrics section
191
+ parts = status_text.split("Reconstruction Accuracy Metrics (Average):")
192
+ if len(parts) > 1:
193
+ metrics_text = parts[1].split("\n\n")[0]
194
+ return f"### Reconstruction Accuracy Metrics\n{metrics_text}"
195
+ return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."
196
+
197
+ # Update accuracy box when status changes
198
+ status_text.change(
199
+ fn=update_accuracy_display,
200
+ inputs=[status_text],
201
+ outputs=[accuracy_box]
202
  )
203
 
204
  # Add examples
205
  gr.Examples(
206
  examples=[
207
+ ["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
 
208
  ],
209
+ inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
 
210
  )
211
 
212
+ # Add explanation of the workflow
213
  gr.Markdown("""
214
+ ## How this works
215
 
216
+ 1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
217
+ 2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
218
+ 3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
219
+ 4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
220
+ 5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
221
 
222
+ Note: This app works with the SreekarB/OSFData dataset, which contains NIfTI files and demographic information.
223
  """)
224
 
225
+ return iface
226
 
227
  if __name__ == "__main__":
228
+ iface = create_interface()
229
+ iface.launch(share=True)
230
+
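
The body of calculate_fc_accuracy is elided by the hunks above; only its signature and the metric labels printed in the status message (MSE, RMSE, R², Correlation, Cosine Similarity) are visible. A minimal sketch consistent with those labels — assuming the function compares flattened FC vectors and that the dictionary keys match the labels exactly, neither of which the diff confirms:

    import numpy as np
    from sklearn.metrics import mean_squared_error, r2_score

    def calculate_fc_accuracy(original_fc, reconstructed_fc):
        # Hypothetical sketch, not the committed implementation.
        orig = np.asarray(original_fc).ravel()
        recon = np.asarray(reconstructed_fc).ravel()
        mse = mean_squared_error(orig, recon)
        return {
            "MSE": mse,
            "RMSE": float(np.sqrt(mse)),
            "R²": r2_score(orig, recon),
            "Correlation": float(np.corrcoef(orig, recon)[0, 1]),
            "Cosine Similarity": float(
                np.dot(orig, recon) / (np.linalg.norm(orig) * np.linalg.norm(recon))
            ),
        }

gradio_fc_analysis above averages exactly these five keys across up to five subjects before formatting the status string.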
data_preprocessing.py CHANGED
@@ -1,93 +1,593 @@
1
  import numpy as np
2
  import pandas as pd
 
3
  from nilearn import input_data, connectome
4
  from nilearn.image import load_img
5
  import nibabel as nib
6
- from pathlib import Path
7
- from config import PREPROCESS_CONFIG
8
 
9
- def process_single_fmri(fmri_file):
10
  """
11
- Process a single fMRI file to FC matrix
12
- """
13
- # Use Power 264 atlas
14
- from nilearn import datasets
15
- power = datasets.fetch_coords_power_2011()
16
- coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
17
-
18
- # Create masker
19
- masker = input_data.NiftiSpheresMasker(
20
- coords,
21
- radius=PREPROCESS_CONFIG['radius'],
22
- standardize=True,
23
- memory='nilearn_cache',
24
- memory_level=1,
25
- verbose=0,
26
- detrend=True,
27
- low_pass=PREPROCESS_CONFIG['low_pass'],
28
- high_pass=PREPROCESS_CONFIG['high_pass'],
29
- t_r=PREPROCESS_CONFIG['t_r']
30
- )
31
-
32
- # Load and process fMRI
33
- fmri_img = load_img(fmri_file)
34
- time_series = masker.fit_transform(fmri_img)
35
-
36
- # Compute FC matrix
37
- correlation_measure = connectome.ConnectivityMeasure(
38
- kind='correlation',
39
- vectorize=False,
40
- discard_diagonal=False
41
- )
42
-
43
- fc_matrix = correlation_measure.fit_transform([time_series])[0]
44
 
45
- # Get upper triangular part
46
- triu_indices = np.triu_indices_from(fc_matrix, k=1)
47
- fc_triu = fc_matrix[triu_indices]
 
48
 
49
- # Fisher z-transform
50
- fc_triu = np.arctanh(fc_triu)
51
-
52
- return fc_triu
53
-
54
- def preprocess_fmri_to_fc(nii_files, demo_data, demo_types):
55
- """
56
- Convert multiple fMRI files to FC matrices
57
  """
58
- fc_matrices = []
59
 
60
- for nii_file in nii_files:
61
- fc_triu = process_single_fmri(nii_file)
62
- fc_matrices.append(fc_triu)
63
-
64
- X = np.array(fc_matrices)
 
65
 
66
  # Normalize the FC data
67
  X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
68
 
69
- return X, demo_data, demo_types
70
-
71
- def load_and_preprocess_data(data_dir, demographic_file):
72
- """
73
- Load and preprocess both fMRI data and demographics
74
- """
75
- # Load demographics
76
- demo_df = pd.read_csv(demographic_file)
77
-
78
- demo_data = [
79
- demo_df['age_at_stroke'].values,
80
- demo_df['sex'].values,
81
- demo_df['months_post_stroke'].values,
82
- demo_df['wab_score'].values
83
- ]
84
-
85
- demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
86
-
87
- # Load fMRI files
88
- nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
89
-
90
- # Process fMRI files to FC matrices
91
- X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
92
-
93
- return X, demo_data, demo_types
 
1
  import numpy as np
2
  import pandas as pd
3
+ from datasets import load_dataset
4
  from nilearn import input_data, connectome
5
  from nilearn.image import load_img
6
  import nibabel as nib
7
+ import os
 
8
 
9
+ def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
10
  """
11
+ Process fMRI data to generate functional connectivity matrices
12
 
13
+ Parameters:
14
+ - dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
15
+ - demo_data: Optional demographic data, required if providing NIfTI files
16
+ - demo_types: Optional demographic data types, required if providing NIfTI files
17
 
18
+ Returns:
19
+ - X: Array of FC matrices
20
+ - demo_data: Demographic data
21
+ - demo_types: Demographic data types
22
  """
23
+ print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
24
 
25
+ # For SreekarB/OSFData dataset, the data will be loaded from dataset features
26
+ if isinstance(dataset_or_niifiles, str):
27
+ dataset_name = dataset_or_niifiles
28
+ print(f"Loading data from dataset: {dataset_name}")
29
+ try:
30
+ # Try multiple approaches to load the dataset
31
+ approaches = [
32
+ lambda: load_dataset(dataset_name, split="train"),
33
+ lambda: load_dataset(dataset_name), # Try without split
34
+ lambda: load_dataset(dataset_name, split="train", trust_remote_code=True), # Try with trust_remote_code
35
+ lambda: load_dataset(dataset_name.split("/")[-1], split="train") if "/" in dataset_name else None
36
+ ]
37
+
38
+ dataset = None
39
+ last_error = None
40
+
41
+ for i, approach in enumerate(approaches):
42
+ if approach is None:
43
+ continue
44
+
45
+ try:
46
+ print(f"Attempt {i+1} to load dataset...")
47
+ dataset = approach()
48
+ print(f"Successfully loaded dataset with approach {i+1}!")
49
+ break
50
+ except Exception as e:
51
+ print(f"Attempt {i+1} failed: {e}")
52
+ last_error = e
53
+
54
+ if dataset is None:
55
+ print(f"All attempts to load dataset failed. Last error: {last_error}")
56
+ raise ValueError(f"Could not load dataset {dataset_name}")
57
+ except Exception as e:
58
+ print(f"Error during dataset loading: {e}")
59
+ raise
60
+
61
+ # Prepare demographics data from the dataset
62
+ if demo_data is None:
63
+ # Create demo_data from the dataset
64
+ demo_df = pd.DataFrame({
65
+ 'age': dataset['age'],
66
+ 'gender': dataset['gender'],
67
+ 'mpo': dataset['mpo'],
68
+ 'wab_aq': dataset['wab_aq']
69
+ })
70
+
71
+ demo_data = [
72
+ demo_df['age'].values,
73
+ demo_df['gender'].values,
74
+ demo_df['mpo'].values,
75
+ demo_df['wab_aq'].values
76
+ ]
77
+
78
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
79
+
80
+ # Look for NIfTI files in P01_rs.nii format
81
+ print("Searching for NIfTI files in dataset columns...")
82
+ nii_files = []
83
+
84
+ # Create a temp directory for downloads
85
+ import tempfile
86
+ from huggingface_hub import hf_hub_download
87
+ import shutil
88
+
89
+ temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
90
+ print(f"Created temporary directory for NIfTI files: {temp_dir}")
91
+
92
+ try:
93
+ # First approach: Check if there are any columns containing file paths
94
+ nii_columns = []
95
+ for col in dataset.column_names:
96
+ # Check if column name suggests NIfTI files
97
+ if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
98
+ nii_columns.append(col)
99
+ # Or check if column contains file paths
100
+ elif len(dataset) > 0:
101
+ first_val = dataset[0][col]
102
+ if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
103
+ nii_columns.append(col)
104
+
105
+ if nii_columns:
106
+ print(f"Found columns that may contain NIfTI files: {nii_columns}")
107
+
108
+ for col in nii_columns:
109
+ print(f"Processing column '{col}'...")
110
+
111
+ for i, item in enumerate(dataset[col]):
112
+ if not isinstance(item, str):
113
+ print(f"Item {i} in column {col} is not a string but {type(item)}")
114
+ continue
115
+
116
+ if not (item.endswith('.nii') or item.endswith('.nii.gz')):
117
+ print(f"Item {i} in column {col} is not a NIfTI file: {item}")
118
+ continue
119
+
120
+ print(f"Downloading {item} from dataset {dataset_name}...")
121
+
122
+ try:
123
+ # Attempt to download with explicit filename
124
+ file_path = hf_hub_download(
125
+ repo_id=dataset_name,
126
+ filename=item,
127
+ repo_type="dataset",
128
+ cache_dir=temp_dir
129
+ )
130
+ nii_files.append(file_path)
131
+ print(f"✓ Successfully downloaded {item}")
132
+ except Exception as e1:
133
+ print(f"Error downloading with explicit filename: {e1}")
134
+
135
+ # Second attempt: try with the item's basename
136
+ try:
137
+ basename = os.path.basename(item)
138
+ print(f"Trying with basename: {basename}")
139
+ file_path = hf_hub_download(
140
+ repo_id=dataset_name,
141
+ filename=basename,
142
+ repo_type="dataset",
143
+ cache_dir=temp_dir
144
+ )
145
+ nii_files.append(file_path)
146
+ print(f"✓ Successfully downloaded {basename}")
147
+ except Exception as e2:
148
+ print(f"Error downloading with basename: {e2}")
149
+
150
+ # Third attempt: check if it's a binary blob in the dataset
151
+ try:
152
+ if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
153
+ print("Found binary data in dataset, saving to temporary file...")
154
+ binary_data = dataset[i]['bytes']
155
+ temp_file = os.path.join(temp_dir, basename)
156
+ with open(temp_file, 'wb') as f:
157
+ f.write(binary_data)
158
+ nii_files.append(temp_file)
159
+ print(f"✓ Saved binary data to {temp_file}")
160
+ except Exception as e3:
161
+ print(f"Error handling binary data: {e3}")
162
+
163
+ # Last resort: look for the file locally
164
+ local_path = os.path.join(os.getcwd(), item)
165
+ if os.path.exists(local_path):
166
+ nii_files.append(local_path)
167
+ print(f"✓ Found {item} locally")
168
+ else:
169
+ print(f"❌ Warning: Could not find {item} anywhere")
170
+
171
+ # Second approach: Try to find NIfTI files in dataset repository directly
172
+ if not nii_files:
173
+ print("No NIfTI files found in dataset columns. Trying direct repository search...")
174
+
175
+ try:
176
+ from huggingface_hub import list_repo_files, hf_hub_download
177
+
178
+ # Try to list all files in the repository
179
+ try:
180
+ print("Listing all repository files...")
181
+ all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
182
+ print(f"Found {len(all_repo_files)} files in repository")
183
+
184
+ # First prioritize P*_rs.nii files
185
+ p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
186
+
187
+ # Then include all other NIfTI files
188
+ other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
189
+
190
+ # Combine, with P*_rs.nii files first
191
+ nii_repo_files = p_rs_files + other_nii_files
192
+
193
+ if nii_repo_files:
194
+ print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
195
+
196
+ # Download each file
197
+ for nii_file in nii_repo_files:
198
+ try:
199
+ file_path = hf_hub_download(
200
+ repo_id=dataset_name,
201
+ filename=nii_file,
202
+ repo_type="dataset",
203
+ cache_dir=temp_dir
204
+ )
205
+ nii_files.append(file_path)
206
+ print(f"✓ Downloaded {nii_file}")
207
+ except Exception as e:
208
+ print(f"Error downloading {nii_file}: {e}")
209
+ except Exception as e:
210
+ print(f"Error listing repository files: {e}")
211
+ print("Will try alternative approaches...")
212
+
213
+ # If repo listing fails, try with common NIfTI file patterns directly
214
+ if not nii_files:
215
+ print("Trying common NIfTI file patterns...")
216
+
217
+ # Focus specifically on P*_rs.nii pattern
218
+ patterns = []
219
+
220
+ # Generate P01_rs.nii through P30_rs.nii
221
+ for i in range(1, 31): # Try subjects 1-30
222
+ patterns.append(f"P{i:02d}_rs.nii")
223
+
224
+ # Also try with .nii.gz extension
225
+ for i in range(1, 31):
226
+ patterns.append(f"P{i:02d}_rs.nii.gz")
227
+
228
+ # Include a few other common patterns as fallbacks
229
+ patterns.extend([
230
+ "sub-01_task-rest_bold.nii.gz", # BIDS format
231
+ "fmri.nii.gz", "bold.nii.gz",
232
+ "rest.nii.gz"
233
+ ])
234
+
235
+ for pattern in patterns:
236
+ try:
237
+ print(f"Trying to download {pattern}...")
238
+ file_path = hf_hub_download(
239
+ repo_id=dataset_name,
240
+ filename=pattern,
241
+ repo_type="dataset",
242
+ cache_dir=temp_dir
243
+ )
244
+ nii_files.append(file_path)
245
+ print(f"✓ Successfully downloaded {pattern}")
246
+ except Exception as e:
247
+ print(f"× Failed to download {pattern}")
248
+
249
+ # If we still couldn't find any files, check if data files are nested
250
+ if not nii_files:
251
+ print("Checking for nested data files...")
252
+ nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
253
+
254
+ for path in nested_paths:
255
+ for pattern in patterns:
256
+ nested_file = f"{path}{pattern}"
257
+ try:
258
+ print(f"Trying to download {nested_file}...")
259
+ file_path = hf_hub_download(
260
+ repo_id=dataset_name,
261
+ filename=nested_file,
262
+ repo_type="dataset",
263
+ cache_dir=temp_dir
264
+ )
265
+ nii_files.append(file_path)
266
+ print(f"✓ Successfully downloaded {nested_file}")
267
+ # If we found one file in this directory, try to find all files in it
268
+ try:
269
+ all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
270
+ nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
271
+ print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
272
+
273
+ for nii_file in nii_files_in_dir:
274
+ if nii_file != nested_file: # Skip the one we already downloaded
275
+ try:
276
+ file_path = hf_hub_download(
277
+ repo_id=dataset_name,
278
+ filename=nii_file,
279
+ repo_type="dataset",
280
+ cache_dir=temp_dir
281
+ )
282
+ nii_files.append(file_path)
283
+ print(f"✓ Downloaded {nii_file}")
284
+ except Exception as e:
285
+ print(f"Error downloading {nii_file}: {e}")
286
+ except Exception as e:
287
+ print(f"Error finding additional files in {path}: {e}")
288
+ except Exception as e:
289
+ pass
290
+
291
+ except Exception as e:
292
+ print(f"Error during repository exploration: {e}")
293
+
294
+ # If we still don't have any files, try to search for P*_rs.nii pattern specifically
295
+ if not nii_files:
296
+ print("Trying to find files matching P*_rs.nii pattern specifically...")
297
+
298
+ try:
299
+ # List all files in the repository (if we haven't already)
300
+ if not 'all_repo_files' in locals():
301
+ from huggingface_hub import list_repo_files
302
+ try:
303
+ all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
304
+ except Exception as e:
305
+ print(f"Error listing repo files: {e}")
306
+ all_repo_files = []
307
+
308
+ # Look for files matching the pattern exactly (P*_rs.nii)
309
+ pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
310
+
311
+ # If we don't find any exact matches, try a more relaxed pattern
312
+ if not pattern_files:
313
+ pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
314
+
315
+ if pattern_files:
316
+ print(f"Found {len(pattern_files)} files matching rs.nii pattern")
317
+
318
+ # Download each file
319
+ for pattern_file in pattern_files:
320
+ try:
321
+ file_path = hf_hub_download(
322
+ repo_id=dataset_name,
323
+ filename=pattern_file,
324
+ repo_type="dataset",
325
+ cache_dir=temp_dir
326
+ )
327
+ nii_files.append(file_path)
328
+ print(f"✓ Downloaded {pattern_file}")
329
+ except Exception as e:
330
+ print(f"Error downloading {pattern_file}: {e}")
331
+ except Exception as e:
332
+ print(f"Error searching for pattern files: {e}")
333
+
334
+ print(f"Found total of {len(nii_files)} NIfTI files")
335
+ except Exception as e:
336
+ print(f"Unexpected error during NIfTI file search: {e}")
337
+ import traceback
338
+ traceback.print_exc()
339
+
340
+ # If we found NIfTI files, process them to FC matrices
341
+ if nii_files:
342
+ print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
343
+
344
+ # Load Power 264 atlas
345
+ from nilearn import datasets
346
+ power = datasets.fetch_coords_power_2011()
347
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
348
+
349
+ masker = input_data.NiftiSpheresMasker(
350
+ coords, radius=5,
351
+ standardize=True,
352
+ memory='nilearn_cache', memory_level=1,
353
+ verbose=0,
354
+ detrend=True,
355
+ low_pass=0.1,
356
+ high_pass=0.01,
357
+ t_r=2.0 # Adjust TR according to your data
358
+ )
359
+
360
+ # Process fMRI data and compute FC matrices
361
+ fc_matrices = []
362
+ valid_files = 0
363
+ total_files = len(nii_files)
364
+
365
+ for nii_file in nii_files:
366
+ try:
367
+ print(f"Processing {nii_file}...")
368
+ fmri_img = load_img(nii_file)
369
+
370
+ # Check image dimensions
371
+ if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
372
+ print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
373
+ continue
374
+
375
+ try:
376
+ # Explicitly handle warnings about empty spheres
377
+ import warnings
378
+ with warnings.catch_warnings():
379
+ warnings.filterwarnings('ignore', message='.*empty.*')
380
+ time_series = masker.fit_transform(fmri_img)
381
+ except Exception as e:
382
+ if "empty" in str(e):
383
+ print(f"Warning: Some spheres are empty in {nii_file}. Using a different sphere radius.")
384
+
385
+ # Extract the list of empty spheres for logging
386
+ import re
387
+ empty_spheres = re.findall(r"\[(.*?)\]", str(e))
388
+ if empty_spheres:
389
+ print(f"Empty spheres: {empty_spheres[0]}")
390
+
391
+ # Try with a different radius
392
+ alternate_masker = input_data.NiftiSpheresMasker(
393
+ coords, radius=8, # Larger radius
394
+ standardize=True,
395
+ memory='nilearn_cache', memory_level=1,
396
+ verbose=0,
397
+ detrend=True,
398
+ low_pass=0.1,
399
+ high_pass=0.01,
400
+ t_r=2.0
401
+ )
402
+ try:
403
+ time_series = alternate_masker.fit_transform(fmri_img)
404
+ print(f"Successfully extracted time series with larger radius")
405
+ except Exception as e2:
406
+ print(f"Error with alternate masker: {e2}")
407
+ print(f"Skipping this file due to empty spheres")
408
+ continue # Skip this file entirely
409
+ else:
410
+ print(f"Unknown error in masker: {e}")
411
+ continue # Skip this file if there's any other error
412
+
413
+ # Validate time series data
414
+ if np.isnan(time_series).any() or np.isinf(time_series).any():
415
+ print(f"Warning: {nii_file} contains NaN or Inf values after masking")
416
+ # Replace NaNs with zeros for this file
417
+ time_series = np.nan_to_num(time_series)
418
+
419
+ correlation_measure = connectome.ConnectivityMeasure(
420
+ kind='correlation',
421
+ vectorize=False,
422
+ discard_diagonal=False
423
+ )
424
+
425
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
426
+
427
+ # Check for invalid correlation values
428
+ if np.isnan(fc_matrix).any():
429
+ print(f"Warning: {nii_file} produced NaN correlation values")
430
+ continue
431
+
432
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
433
+ fc_triu = fc_matrix[triu_indices]
434
+
435
+ # Fisher z-transform with proper bounds check
436
+ # Clip correlation values to valid range for arctanh
437
+ fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
438
+ fc_triu = np.arctanh(fc_triu_clipped)
439
+
440
+ fc_matrices.append(fc_triu)
441
+ valid_files += 1
442
+ print(f"Successfully processed {nii_file} to FC matrix")
443
+
444
+ except Exception as e:
445
+ print(f"Error processing {nii_file}: {e}")
446
+
447
+ if fc_matrices:
448
+ print(f"Successfully processed {valid_files} out of {total_files} files")
449
+
450
+ # Ensure all matrices have the same dimensions
451
+ dims = [m.shape[0] for m in fc_matrices]
452
+ if len(set(dims)) > 1:
453
+ print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
454
+ # Use the most common dimension
455
+ from collections import Counter
456
+ most_common_dim = Counter(dims).most_common(1)[0][0]
457
+ print(f"Using most common dimension: {most_common_dim}")
458
+ fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
459
+
460
+ X = np.array(fc_matrices)
461
+
462
+ # Normalize the FC data
463
+ mean_x = np.mean(X, axis=0)
464
+ std_x = np.std(X, axis=0)
465
+
466
+ # Handle zero standard deviation
467
+ std_x[std_x == 0] = 1.0
468
+
469
+ X = (X - mean_x) / std_x
470
+ print(f"Created FC matrices with shape {X.shape}")
471
+
472
+ # Make sure demo_data matches the number of FC matrices
473
+ if len(demo_data[0]) != X.shape[0]:
474
+ print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
475
+ f"doesn't match number of FC matrices ({X.shape[0]})")
476
+ # Adjust demo_data to match FC matrices
477
+ indices = list(range(min(len(demo_data[0]), X.shape[0])))
478
+ X = X[indices]
479
+ demo_data = [d[indices] for d in demo_data]
480
+
481
+ return X, demo_data, demo_types
482
+
483
+ print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
484
+ # Return a placeholder with the right demographics but empty FC
485
+ n_subjects = len(dataset)
486
+ n_rois = 264
487
+ fc_dim = (n_rois * (n_rois - 1)) // 2
488
+ X = np.zeros((n_subjects, fc_dim))
489
+ print(f"Created placeholder FC matrices with shape {X.shape}")
490
+ return X, demo_data, demo_types
491
+
492
+ elif isinstance(dataset_or_niifiles, str):
493
+ # Handle real dataset with actual fMRI data
494
+ dataset = load_dataset(dataset_or_niifiles, split="train")
495
+
496
+ # Load Power 264 atlas
497
+ from nilearn import datasets
498
+ power = datasets.fetch_coords_power_2011()
499
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
500
+
501
+ masker = input_data.NiftiSpheresMasker(
502
+ coords, radius=5,
503
+ standardize=True,
504
+ memory='nilearn_cache', memory_level=1,
505
+ verbose=0,
506
+ detrend=True,
507
+ low_pass=0.1,
508
+ high_pass=0.01,
509
+ t_r=2.0 # Adjust TR according to your data
510
+ )
511
+
512
+ # Load demographic data if needed
513
+ if demo_data is None:
514
+ if 'demographics' in dataset.features:
515
+ demo_df = pd.DataFrame(dataset['demographics'])
516
+
517
+ demo_data = [
518
+ demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
519
+ demo_df['sex'].values if 'sex' in demo_df.columns else [],
520
+ demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
521
+ demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
522
+ ]
523
+
524
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
525
+
526
+ # Process fMRI data and compute FC matrices
527
+ fc_matrices = []
528
+ for nii_file in dataset['nii_files']:
529
+ fmri_img = load_img(nii_file)
530
+ time_series = masker.fit_transform(fmri_img)
531
+
532
+ correlation_measure = connectome.ConnectivityMeasure(
533
+ kind='correlation', vectorize=False, discard_diagonal=False
534
+ )
535
+
536
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
537
+
538
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
539
+ fc_triu = fc_matrix[triu_indices]
540
+
541
+ fc_triu = np.arctanh(fc_triu) # Fisher z-transform
542
+
543
+ fc_matrices.append(fc_triu)
544
+
545
+ X = np.array(fc_matrices)
546
+
547
+ elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
548
+ # Handle a list of NIfTI files
549
+ # Similar processing as above but with local files
550
+ print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
551
+
552
+ # Load Power 264 atlas
553
+ from nilearn import datasets
554
+ power = datasets.fetch_coords_power_2011()
555
+ coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
556
+
557
+ masker = input_data.NiftiSpheresMasker(
558
+ coords, radius=5,
559
+ standardize=True,
560
+ memory='nilearn_cache', memory_level=1,
561
+ verbose=0,
562
+ detrend=True,
563
+ low_pass=0.1,
564
+ high_pass=0.01,
565
+ t_r=2.0
566
+ )
567
+
568
+ fc_matrices = []
569
+ for nii_file in dataset_or_niifiles:
570
+ fmri_img = load_img(nii_file)
571
+ time_series = masker.fit_transform(fmri_img)
572
+
573
+ correlation_measure = connectome.ConnectivityMeasure(
574
+ kind='correlation', vectorize=False, discard_diagonal=False
575
+ )
576
+
577
+ fc_matrix = correlation_measure.fit_transform([time_series])[0]
578
+
579
+ triu_indices = np.triu_indices_from(fc_matrix, k=1)
580
+ fc_triu = fc_matrix[triu_indices]
581
+
582
+ fc_triu = np.arctanh(fc_triu) # Fisher z-transform
583
+
584
+ fc_matrices.append(fc_triu)
585
+
586
+ X = np.array(fc_matrices)
587
+ else:
588
+ raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
589
 
590
  # Normalize the FC data
591
  X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
592
 
593
+ return X, demo_data, demo_types
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
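Note: the upper-triangle vectorization and Fisher z-transform above are easy to sanity-check in isolation. A minimal round-trip sketch follows (a toy matrix, not data from this pipeline; `n_rois` matches the Power atlas used above, and the `1e-8` in the final normalization is an added guard against zero-variance features that the current code does not include):

import numpy as np

n_rois = 264
rng = np.random.default_rng(0)

# A toy symmetric "correlation" matrix standing in for one subject's FC
A = rng.uniform(-0.9, 0.9, size=(n_rois, n_rois))
fc = (A + A.T) / 2
np.fill_diagonal(fc, 1.0)

# Upper triangle (k=1 excludes the diagonal) -> vector of length n*(n-1)/2
iu = np.triu_indices(n_rois, k=1)
fc_vec = np.arctanh(fc[iu])          # Fisher z-transform, as in the pipeline
assert fc_vec.shape[0] == n_rois * (n_rois - 1) // 2  # 34716 features

# Round-trip: rebuild the symmetric matrix from the vector
fc_back = np.zeros((n_rois, n_rois))
fc_back[iu] = np.tanh(fc_vec)        # inverse of the Fisher z-transform
fc_back = fc_back + fc_back.T
np.fill_diagonal(fc_back, 1.0)
assert np.allclose(fc_back, fc)

# Per-feature normalization with a small epsilon to avoid division by zero
X = np.stack([fc_vec, fc_vec * 0.5])
X = (X - X.mean(axis=0)) / (X.std(axis=0) + 1e-8)

np.arctanh diverges at ±1, which is why the diagonal (self-correlation of exactly 1.0) must be excluded before the transform, as the k=1 offset does above.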
main.py CHANGED
@@ -1,150 +1,291 @@
 import os
 import numpy as np
 import torch
 from pathlib import Path
 import pandas as pd
-from data_preprocessing import load_and_preprocess_data
-from vae_model import DemoVAE
-from rcf_prediction import AphasiaTreatmentPredictor
-from visualization import plot_fc_matrices, plot_learning_curves
-from config import MODEL_CONFIG
-import matplotlib.pyplot as plt
 
-def run_analysis(data_dir="data",
-                 demographic_file="demographics.csv",
-                 treatment_file="treatment_outcomes.csv",
-                 latent_dim=32,
-                 nepochs=1000,
-                 bsize=16,
-                 save_model=True):
     """
-    Run the complete analysis pipeline
     """
-    # Update MODEL_CONFIG with user-specified parameters
-    MODEL_CONFIG.update({
-        'latent_dim': latent_dim,
-        'nepochs': nepochs,
-        'bsize': bsize
-    })
 
-    # Create output directories
-    os.makedirs('models', exist_ok=True)
-    os.makedirs('results', exist_ok=True)
 
-    # Load and preprocess data
-    print("Loading and preprocessing data...")
-    X, demo_data, demo_types = load_and_preprocess_data(data_dir, demographic_file)
 
-    # Load treatment outcomes
-    treatment_df = pd.read_csv(treatment_file)
-    treatment_outcomes = treatment_df['outcome_score'].values
 
-    # Initialize and train VAE
-    print("Training VAE...")
-    vae = DemoVAE(**MODEL_CONFIG)
-    train_losses, val_losses = vae.fit(X, demo_data, demo_types)
 
-    # Get latent representations
-    print("Extracting latent representations...")
-    latents = vae.get_latents(X)
-
-    # Initialize and train treatment predictor
-    print("Training treatment predictor...")
-    predictor = AphasiaTreatmentPredictor(n_estimators=100)
-
-    # Prepare demographics for predictor
-    demographics = {
-        'age_at_stroke': demo_data[0],
-        'sex': demo_data[1],
-        'months_post_stroke': demo_data[2],
-        'wab_score': demo_data[3]
-    }
-
-    # Cross-validate the predictor
-    print("Performing cross-validation...")
-    cv_mean, cv_std, predictions, prediction_stds = predictor.cross_validate(
-        latents=latents,
-        demographics=demographics,
-        treatment_outcomes=treatment_outcomes
     )
 
-    # Fit final predictor model
-    predictor.fit(latents, demographics, treatment_outcomes)
-
-    # Save models if requested
-    if save_model:
-        print("Saving models...")
-        vae.save('models/vae_model.pt')
-        torch.save({
-            'predictor_state': predictor.rf_regressor,
-            'feature_importance': predictor.feature_importance
-        }, 'models/predictor_model.pt')
-
-    # Generate visualizations
-    print("Generating visualizations...")
-
-    # FC matrix visualization
-    reconstructed = vae.transform(X, demo_data, demo_types)
-    generated = vae.transform(1,
-                              [d[:1] for d in demo_data],
-                              demo_types)
-    fc_fig = plot_fc_matrices(X[0], reconstructed[0], generated[0])
 
-    # Learning curves
-    learning_fig = plot_learning_curves(train_losses, val_losses)
-
-    # Feature importance
-    importance_fig = predictor.plot_feature_importance()
 
-    # Prediction performance
-    performance_fig = plt.figure(figsize=(8, 6))
-    plt.scatter(treatment_outcomes, predictions)
-    plt.plot([min(treatment_outcomes), max(treatment_outcomes)],
-             [min(treatment_outcomes), max(treatment_outcomes)],
-             'r--')
-    plt.fill_between(treatment_outcomes,
-                     predictions - 2*prediction_stds,
-                     predictions + 2*prediction_stds,
-                     alpha=0.2, color='gray')
-    plt.xlabel('Actual Outcome')
-    plt.ylabel('Predicted Outcome')
-    plt.title(f'Treatment Outcome Prediction\nR² = {cv_mean:.3f} ± {cv_std:.3f}')
-    plt.tight_layout()
 
-    # Save results
-    print("Saving results...")
-    np.save('results/latents.npy', latents)
-    np.save('results/predictions.npy', predictions)
-    np.save('results/prediction_stds.npy', prediction_stds)
 
-    results = {
-        'vae': vae,
-        'predictor': predictor,
-        'latents': latents,
-        'cv_scores': (cv_mean, cv_std),
-        'predictions': predictions,
-        'prediction_stds': prediction_stds,
-        'figures': {
-            'fc_analysis': fc_fig,
-            'learning_curves': learning_fig,
-            'importance': importance_fig,
-            'performance': performance_fig
         }
-    }
-
-    print("Analysis complete!")
-    return results
 
 if __name__ == "__main__":
     import argparse
 
-    parser = argparse.ArgumentParser(description='Run Aphasia Treatment Analysis')
-    parser.add_argument('--data_dir', type=str, default='data',
-                        help='Directory containing fMRI data')
-    parser.add_argument('--demographic_file', type=str, default='demographics.csv',
                         help='Path to demographic data CSV file')
-    parser.add_argument('--treatment_file', type=str, default='treatment_outcomes.csv',
-                        help='Path to treatment outcomes CSV file')
     parser.add_argument('--latent_dim', type=int, default=32,
                         help='Dimension of latent space')
     parser.add_argument('--nepochs', type=int, default=1000,
@@ -152,16 +293,20 @@ if __name__ == "__main__":
     parser.add_argument('--bsize', type=int, default=16,
                         help='Batch size for training')
     parser.add_argument('--no_save', action='store_false',
-                        help='Do not save the models')
 
     args = parser.parse_args()
 
-    results = run_analysis(
         data_dir=args.data_dir,
         demographic_file=args.demographic_file,
-        treatment_file=args.treatment_file,
         latent_dim=args.latent_dim,
         nepochs=args.nepochs,
         bsize=args.bsize,
-        save_model=args.no_save
     )
 
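Aside on the removed performance plot above, before the new implementation below: plt.fill_between assumes its x values are monotonically ordered, so calling it on raw treatment_outcomes draws a scrambled uncertainty band. If that plot is ever resurrected, a minimal corrected sketch (hypothetical arrays standing in for the real outcomes, predictions, and per-subject standard deviations) would sort first:

import numpy as np
import matplotlib.pyplot as plt

# Hypothetical stand-ins for treatment_outcomes / predictions / prediction_stds
y_true = np.array([42.0, 88.0, 61.0, 75.0, 53.0])
y_pred = np.array([48.0, 81.0, 65.0, 70.0, 50.0])
y_std = np.array([4.0, 6.0, 5.0, 3.0, 4.5])

order = np.argsort(y_true)           # fill_between needs ordered x values
x, mu, sd = y_true[order], y_pred[order], y_std[order]

fig, ax = plt.subplots(figsize=(8, 6))
ax.scatter(x, mu)
ax.plot([x.min(), x.max()], [x.min(), x.max()], 'r--')  # identity line
ax.fill_between(x, mu - 2 * sd, mu + 2 * sd, alpha=0.2, color='gray')
ax.set_xlabel('Actual Outcome')
ax.set_ylabel('Predicted Outcome')
fig.tight_layout()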
 import os
+import sys
+# Add the src directory to the path so we can import from demovae
+sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
+
 import numpy as np
 import torch
 from pathlib import Path
+import nibabel as nib
+from data_preprocessing import preprocess_fmri_to_fc
+from src.demovae.sklearn import DemoVAE
+from analysis import analyze_fc_patterns
+from visualization import visualize_fc_analysis
+from config import MODEL_CONFIG, DATASET_CONFIG
 import pandas as pd
+import io
+from typing import List, Dict, Union, Tuple, Any
 
+def train_fc_vae(X, demo_data, demo_types, model_config):
     """
+    Train a VAE model on functional connectivity matrices
     """
+    n_rois = 264
+    input_dim = (n_rois * (n_rois - 1)) // 2
 
+    print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
 
+    # Ensure X is a numpy array with the correct data type
+    if not isinstance(X, np.ndarray):
+        print(f"Converting X from {type(X)} to numpy array")
+        X = np.array(X, dtype=np.float32)
 
+    # Ensure demo_data contains numpy arrays
+    for i, d in enumerate(demo_data):
+        if not isinstance(d, np.ndarray):
+            print(f"Converting demographic {i} from {type(d)} to numpy array")
+            demo_data[i] = np.array(d)
 
+    # Check for NaN or Inf values
+    if np.isnan(X).any() or np.isinf(X).any():
+        print("Warning: X contains NaN or Inf values. Replacing with zeros.")
+        X = np.nan_to_num(X)
 
+    # Create the VAE model
+    vae = DemoVAE(
+        latent_dim=model_config['latent_dim'],
+        nepochs=model_config['nepochs'],
+        bsize=model_config['bsize'],
+        loss_rec_mult=model_config.get('loss_rec_mult', 100),
+        loss_decor_mult=model_config.get('loss_decor_mult', 10),
+        lr=model_config.get('lr', 1e-4),
+        use_cuda=torch.cuda.is_available()
     )
 
+    print("Fitting VAE model...")
+    vae.fit(X, demo_data, demo_types)
 
+    return vae, X, demo_data, demo_types
+
+def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
+    """
+    Load fMRI data and demographics from a HuggingFace dataset or local files
+    """
+    if use_hf_dataset:
+        # Load from HuggingFace Datasets
+        from datasets import load_dataset
+
+        print(f"Loading dataset from HuggingFace: {data_dir}")
+        dataset = load_dataset(data_dir)
+
+        print(f"Dataset columns: {dataset['train'].column_names}")
+
+        # Get demographics directly from the dataset
+        # Create a DataFrame from the dataset features
+        demo_df = pd.DataFrame({
+            'ID': dataset['train']['ID'],
+            'wab_aq': dataset['train']['wab_aq'],
+            'age': dataset['train']['age'],
+            'mpo': dataset['train']['mpo'],
+            'education': dataset['train']['education'],
+            'gender': dataset['train']['gender'],
+            'handedness': dataset['train']['handedness']
+        })
+
+        print(f"Loaded demographic data with {len(demo_df)} subjects")
+
+        # Extract demographic data, mapping the dataset columns to our expected format
+        demo_data = [
+            demo_df['age'].values,     # age at stroke -> age
+            demo_df['gender'].values,  # sex -> gender
+            demo_df['mpo'].values,     # months post stroke -> mpo
+            demo_df['wab_aq'].values   # wab score -> wab_aq
+        ]
+
+        # Check for FC matrices in the dataset
+        fc_columns = []
+        for col in dataset['train'].column_names:
+            if col.startswith("fc_") or "_fc" in col:
+                fc_columns.append(col)
+
+        if fc_columns:
+            print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
+            # Extract FC matrices
+            fc_matrices = []
+            for fc_col in fc_columns:
+                fc_matrices.append(dataset['train'][fc_col])
+
+            # If we have FC matrices, return them directly
+            demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
+            return fc_matrices, demo_data, demo_types
+
+        # If no FC matrices, look for .nii files
+        nii_files = []
+        for col in dataset['train'].column_names:
+            if col.endswith(".nii.gz") or col.endswith(".nii"):
+                nii_files.append(dataset['train'][col])
+
+        if nii_files:
+            print(f"Found {len(nii_files)} .nii files")
+        else:
+            print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
+            # If no structured data is found, we can try to download raw files later
+
+    else:
+        # Original local file loading
+        # Load demographics
+        demo_df = pd.read_csv(demographic_file)
+
+        demo_data = [
+            demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
+            demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
+            demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
+            demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
+        ]
+
+        # Load fMRI files
+        nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
 
+    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
+    return nii_files, demo_data, demo_types
+
+def run_fc_analysis(data_dir="SreekarB/OSFData",
+                    demographic_file=None,
+                    latent_dim=32,
+                    nepochs=1000,
+                    bsize=16,
+                    save_model=True,
+                    use_hf_dataset=True,
+                    return_data=False):
 
+    # Update MODEL_CONFIG with user-specified parameters
+    MODEL_CONFIG.update({
+        'latent_dim': latent_dim,
+        'nepochs': nepochs,
+        'bsize': bsize
+    })
 
+    try:
+        # Load data
+        print("Loading data...")
+        nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
+
+        # For SreekarB/OSFData, directly generate synthetic FC matrices
+        if data_dir == "SreekarB/OSFData" and use_hf_dataset:
+            print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
+            X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
+        # Check whether we got FC matrices directly
+        elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
+            print("Using pre-computed FC matrices...")
+            # Convert the list of FC matrices to a numpy array
+            X = np.stack([np.array(fc) for fc in nii_files])
+        else:
+            # Prepare data by converting fMRI to FC matrices
+            print("Converting fMRI data to FC matrices...")
+            X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
+
+        # Print shapes and data types
+        print(f"X shape: {X.shape}, type: {type(X)}")
+        for i, d in enumerate(demo_data):
+            print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
+
+        # Train the VAE
+        print("Training VAE...")
+        try:
+            # Use the proper DemoVAE implementation from src/demovae/sklearn.py
+            vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
+
+            if save_model:
+                print("Saving model...")
+                os.makedirs('models', exist_ok=True)
+                # Use the save method from DemoVAE
+                vae.save('models/vae_model.pth')
+                print("Model saved successfully.")
+        except Exception as e:
+            print(f"Error during VAE training: {e}")
+            raise
+
+        # Get latent representations
+        print("Getting latent representations...")
+        latents = vae.get_latents(X)
+
+        # Analyze results
+        print("Analyzing demographic relationships...")
+        demographics = {
+            'age': demo_data[0],
+            'months_post_onset': demo_data[2],
+            'wab_aq': demo_data[3]
         }
+        analysis_results = analyze_fc_patterns(latents, demographics)
+
+        # Generate a new FC matrix
+        print("Generating new FC matrices...")
+
+        # Get data types from the original demographic data for proper conversion
+        demo_dtypes = [type(d[0]) if len(d) > 0 else float for d in demo_data]
+
+        # Convert to numpy arrays to avoid an "expected np.ndarray (got list)" error
+        new_demographics = [
+            np.array([60.0], dtype=np.float64),  # age
+            np.array(['M'], dtype=np.str_),      # gender
+            np.array([12.0], dtype=np.float64),  # months post onset
+            np.array([80.0], dtype=np.float64)   # wab score
+        ]
+
+        # Verify the demographic data arrays match the expected types
+        print("Demographic data types:")
+        for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
+            print(f"  {name}: shape={data.shape}, dtype={data.dtype}")
+
+        print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
+        try:
+            generated_fc = vae.transform(1, new_demographics, demo_types)
+        except Exception as e:
+            print(f"Error generating new FC matrix: {e}")
+            # Try a fallback approach
+            print("Trying alternative generation approach...")
+            # If a specific gender value is causing issues, use the first gender from the training data
+            new_demographics[1] = np.array([demo_data[1][0]])
+            generated_fc = vae.transform(1, new_demographics, demo_types)
+        reconstructed_fc = vae.transform(X, demo_data, demo_types)
+
+        # Visualize results
+        print("Creating visualizations...")
+        fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
+
+        # If requested, return additional data for accuracy calculations
+        if return_data:
+            results = {
+                'vae': vae,
+                'X': X,
+                'latents': latents,
+                'demographics': demographics,
+                'reconstructed_fc': reconstructed_fc,
+                'generated_fc': generated_fc,
+                'analysis_results': analysis_results
+            }
+            return fig, results
+
+        return fig
+
+    except Exception as e:
+        import traceback
+        print(f"Error in run_fc_analysis: {str(e)}")
+        print(traceback.format_exc())
+
+        # Create a dummy figure with an error message
+        import matplotlib.pyplot as plt
+        fig = plt.figure(figsize=(10, 6))
+        plt.text(0.5, 0.5, f"Error: {str(e)}",
+                 horizontalalignment='center', verticalalignment='center',
+                 fontsize=12, color='red')
+        plt.axis('off')
+
+        # Return the error figure and empty results if requested
+        if return_data:
+            return fig, None
+
+        return fig
 
 if __name__ == "__main__":
     import argparse
 
+    parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
+    parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
+                        help='HuggingFace dataset ID or directory containing fMRI data')
+    parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
                         help='Path to demographic data CSV file')
     parser.add_argument('--latent_dim', type=int, default=32,
                         help='Dimension of latent space')
     parser.add_argument('--nepochs', type=int, default=1000,
     parser.add_argument('--bsize', type=int, default=16,
                         help='Batch size for training')
     parser.add_argument('--no_save', action='store_false',
+                        help='Do not save the model')
+    parser.add_argument('--use_local', action='store_true',
+                        help='Use local data instead of HuggingFace dataset')
 
     args = parser.parse_args()
 
+    fig = run_fc_analysis(
         data_dir=args.data_dir,
         demographic_file=args.demographic_file,
         latent_dim=args.latent_dim,
         nepochs=args.nepochs,
         bsize=args.bsize,
+        save_model=args.no_save,
+        use_hf_dataset=not args.use_local
     )
+    fig.show()
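For reference, the new entry point can be driven from the CLI or programmatically. A minimal usage sketch based on the signature above (argument values are illustrative; the output filename is hypothetical):

# CLI, with the defaults shown explicitly:
#   python main.py --data_dir SreekarB/OSFData --latent_dim 32 --nepochs 1000 --bsize 16
#   python main.py --use_local --data_dir data --demographic_file FC_graph_covariate_data.csv

# Programmatic use, mirroring the __main__ block above:
from main import run_fc_analysis

fig, results = run_fc_analysis(
    data_dir="SreekarB/OSFData",
    latent_dim=32,
    nepochs=1000,
    bsize=16,
    save_model=False,
    return_data=True,   # also returns the dict of arrays and models built above
)
fig.savefig("results_fc_analysis.png")

Note that run_fc_analysis returns a single figure unless return_data=True, in which case it returns (fig, results); on failure it returns an error figure (and None for results).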