SreekarB committed on
Commit
b507484
·
verified ·
1 Parent(s): 7dd36eb

Upload 11 files

Browse files
app.py CHANGED
@@ -2,6 +2,9 @@ import gradio as gr
2
  from main import run_analysis
3
  from rcf_prediction import AphasiaTreatmentPredictor
4
  import numpy as np
 
 
 
5
  import matplotlib.pyplot as plt
6
  from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
7
  from visualization import plot_fc_matrices, plot_learning_curves
@@ -2091,13 +2094,14 @@ def create_interface():
2091
  def handle_fc_visualization():
2092
  """Generate FC visualization using stored data or synthetic data"""
2093
  try:
 
 
 
 
2094
  # Check if we have trained VAE and data
2095
  if app_state.get('vae_trained', False) and app_state.get('vae') is not None:
2096
  logger.info("Visualizing FC matrices from trained VAE")
2097
 
2098
- # Get visualization data
2099
- from visualization import plot_fc_matrices
2100
-
2101
  # If we have stored original and reconstructed matrices, use them
2102
  if app_state.get('original_fc') is not None and app_state.get('reconstructed_fc') is not None:
2103
  original = app_state['original_fc']
@@ -2208,6 +2212,10 @@ def create_interface():
2208
  # Tab 2: Random Forest Training Handler
2209
  def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
2210
  """Train the Random Forest model using the VAE latent representations"""
 
 
 
 
2211
  # Check if VAE has been trained or if we can use synthetic data
2212
  if not app_state.get('vae_trained', False) or app_state.get('latents') is None:
2213
  # Instead of error, create synthetic data for demonstration
@@ -2273,8 +2281,6 @@ def create_interface():
2273
 
2274
  # Train Random Forest predictor
2275
  from rcf_prediction import AphasiaTreatmentPredictor
2276
- import pandas as pd
2277
- import numpy as np
2278
 
2279
  # Get treatment outcomes data
2280
  # Check if we already created synthetic data
@@ -2429,6 +2435,9 @@ def create_interface():
2429
  def handle_treatment_prediction(fmri_file, age, sex, months, wab):
2430
  """Predict treatment outcome for a new patient"""
2431
  try:
 
 
 
2432
  # First, check if we have saved models we can use
2433
  rf_model_path = "results/treatment_predictor.joblib"
2434
  rf_available = os.path.exists(rf_model_path)
 
2
  from main import run_analysis
3
  from rcf_prediction import AphasiaTreatmentPredictor
4
  import numpy as np
5
+ # Configure matplotlib for headless environment
6
+ import matplotlib
7
+ matplotlib.use('Agg') # Use non-interactive backend
8
  import matplotlib.pyplot as plt
9
  from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri
10
  from visualization import plot_fc_matrices, plot_learning_curves
 
2094
  def handle_fc_visualization():
2095
  """Generate FC visualization using stored data or synthetic data"""
2096
  try:
2097
+ # Import necessary packages
2098
+ import numpy as np
2099
+ from visualization import plot_fc_matrices
2100
+
2101
  # Check if we have trained VAE and data
2102
  if app_state.get('vae_trained', False) and app_state.get('vae') is not None:
2103
  logger.info("Visualizing FC matrices from trained VAE")
2104
 
 
 
 
2105
  # If we have stored original and reconstructed matrices, use them
2106
  if app_state.get('original_fc') is not None and app_state.get('reconstructed_fc') is not None:
2107
  original = app_state['original_fc']
 
2212
  # Tab 2: Random Forest Training Handler
2213
  def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
2214
  """Train the Random Forest model using the VAE latent representations"""
2215
+ # Import necessary packages
2216
+ import numpy as np
2217
+ import pandas as pd
2218
+
2219
  # Check if VAE has been trained or if we can use synthetic data
2220
  if not app_state.get('vae_trained', False) or app_state.get('latents') is None:
2221
  # Instead of error, create synthetic data for demonstration
 
2281
 
2282
  # Train Random Forest predictor
2283
  from rcf_prediction import AphasiaTreatmentPredictor
 
 
2284
 
2285
  # Get treatment outcomes data
2286
  # Check if we already created synthetic data
 
2435
  def handle_treatment_prediction(fmri_file, age, sex, months, wab):
2436
  """Predict treatment outcome for a new patient"""
2437
  try:
2438
+ # Import necessary packages
2439
+ import numpy as np
2440
+
2441
  # First, check if we have saved models we can use
2442
  rf_model_path = "results/treatment_predictor.joblib"
2443
  rf_available = os.path.exists(rf_model_path)
demo_fc_visualization.py CHANGED
@@ -3,6 +3,9 @@ Demo script to visualize FC matrices from real fMRI data using nilearn's built-i
3
  """
4
 
5
  import numpy as np
 
 
 
6
  import matplotlib.pyplot as plt
7
  from nilearn import datasets
8
  from nilearn import input_data, connectome
 
3
  """
4
 
5
  import numpy as np
6
+ # Configure matplotlib for headless environment
7
+ import matplotlib
8
+ matplotlib.use('Agg') # Use non-interactive backend
9
  import matplotlib.pyplot as plt
10
  from nilearn import datasets
11
  from nilearn import input_data, connectome
direct_fc_visualization.py CHANGED
@@ -7,6 +7,9 @@ This script creates and visualizes FC matrices directly, without relying on fMRI
7
 
8
  import os
9
  import numpy as np
 
 
 
10
  import matplotlib.pyplot as plt
11
  from visualization import vector_to_matrix
12
 
 
7
 
8
  import os
9
  import numpy as np
10
+ # Configure matplotlib for headless environment
11
+ import matplotlib
12
+ matplotlib.use('Agg') # Use non-interactive backend
13
  import matplotlib.pyplot as plt
14
  from visualization import vector_to_matrix
15
 
fc_visualization.py CHANGED
@@ -6,6 +6,9 @@ independently from the prediction pipeline.
6
  """
7
 
8
  import numpy as np
 
 
 
9
  import matplotlib.pyplot as plt
10
  from pathlib import Path
11
  import argparse
 
6
  """
7
 
8
  import numpy as np
9
+ # Configure matplotlib for headless environment
10
+ import matplotlib
11
+ matplotlib.use('Agg') # Use non-interactive backend
12
  import matplotlib.pyplot as plt
13
  from pathlib import Path
14
  import argparse
huggingface_fc_visualization.py CHANGED
@@ -4,6 +4,9 @@ Script to visualize FC matrices from HuggingFace dataset, comparing original FC
4
 
5
  import os
6
  import numpy as np
 
 
 
7
  import matplotlib.pyplot as plt
8
  from datasets import load_dataset
9
  from fc_visualization import FCVisualizer
 
4
 
5
  import os
6
  import numpy as np
7
+ # Configure matplotlib for headless environment
8
+ import matplotlib
9
+ matplotlib.use('Agg') # Use non-interactive backend
10
  import matplotlib.pyplot as plt
11
  from datasets import load_dataset
12
  from fc_visualization import FCVisualizer
main.py CHANGED
@@ -8,6 +8,9 @@ from vae_model import DemoVAE
8
  from rcf_prediction import AphasiaTreatmentPredictor
9
  from visualization import plot_fc_matrices, plot_learning_curves
10
  from config import MODEL_CONFIG, PREDICTION_CONFIG
 
 
 
11
  import matplotlib.pyplot as plt
12
 
13
  def run_analysis(data_dir="data",
 
8
  from rcf_prediction import AphasiaTreatmentPredictor
9
  from visualization import plot_fc_matrices, plot_learning_curves
10
  from config import MODEL_CONFIG, PREDICTION_CONFIG
11
+ # Configure matplotlib for headless environment
12
+ import matplotlib
13
+ matplotlib.use('Agg') # Use non-interactive backend
14
  import matplotlib.pyplot as plt
15
 
16
  def run_analysis(data_dir="data",
rcf_prediction.py CHANGED
@@ -3,6 +3,9 @@ from sklearn.ensemble import RandomForestRegressor
3
  from sklearn.model_selection import cross_val_score, KFold
4
  import pandas as pd
5
  from sklearn.metrics import mean_squared_error, r2_score
 
 
 
6
  import matplotlib.pyplot as plt
7
  import os
8
  import joblib
@@ -82,6 +85,11 @@ class AphasiaTreatmentPredictor:
82
  self.feature_names = feature_names
83
 
84
  logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
 
 
 
 
 
85
  self.model.fit(X, treatment_outcomes)
86
 
87
  # Calculate feature importance
@@ -90,6 +98,7 @@ class AphasiaTreatmentPredictor:
90
  'importance': self.model.feature_importances_
91
  }).sort_values('importance', ascending=False)
92
 
 
93
  return self
94
 
95
  def predict(self, latents, demographics):
@@ -137,12 +146,14 @@ class AphasiaTreatmentPredictor:
137
  n_splits = adjusted_n_splits
138
 
139
  logger.info(f"Running {n_splits}-fold cross-validation on {sample_count} samples")
 
140
 
141
  # Use stratified KFold for regression to ensure balanced folds
142
  # or LeaveOneOut for very small datasets
143
  if sample_count <= 5:
144
  from sklearn.model_selection import LeaveOneOut
145
  logger.warning(f"Using Leave-One-Out CV for small dataset with {sample_count} samples")
 
146
  kf = LeaveOneOut()
147
  cv_iterator = kf.split(X)
148
  else:
@@ -158,11 +169,14 @@ class AphasiaTreatmentPredictor:
158
  X_train, X_test = X[train_idx], X[test_idx]
159
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
160
 
 
 
161
  # Clone the model for this fold
162
  fold_model = RandomForestRegressor(
163
  n_estimators=self.n_estimators,
164
  max_depth=self.max_depth,
165
- random_state=self.random_state
 
166
  )
167
 
168
  # Train the model
@@ -183,6 +197,7 @@ class AphasiaTreatmentPredictor:
183
  else:
184
  r2 = np.nan
185
  logger.warning(f"Fold {fold+1}: R² not calculated (insufficient samples or variance)")
 
186
 
187
  # MSE can always be calculated
188
  mse = rmse**2
@@ -213,6 +228,12 @@ class AphasiaTreatmentPredictor:
213
  fold_metrics.append(metrics)
214
  logger.info(f"Fold {fold+1} metrics: {metrics}")
215
 
 
 
 
 
 
 
216
  # Calculate average metrics
217
  avg_metrics = {}
218
  for key in fold_metrics[0].keys():
@@ -225,7 +246,15 @@ class AphasiaTreatmentPredictor:
225
 
226
  logger.info(f"Average CV metrics: {avg_metrics}")
227
 
 
 
 
 
 
 
228
  # Train final model on all data
 
 
229
  self.model.fit(X, treatment_outcomes)
230
 
231
  # Calculate feature importance
 
3
  from sklearn.model_selection import cross_val_score, KFold
4
  import pandas as pd
5
  from sklearn.metrics import mean_squared_error, r2_score
6
+ # Configure matplotlib for headless environment
7
+ import matplotlib
8
+ matplotlib.use('Agg') # Use non-interactive backend
9
  import matplotlib.pyplot as plt
10
  import os
11
  import joblib
 
85
  self.feature_names = feature_names
86
 
87
  logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
88
+ print(f"Random Forest: Building {self.n_estimators} trees...")
89
+
90
+ # Track progress during fit with verbose
91
+ # Set verbose to 2 for detailed per-tree progress
92
+ self.model.verbose = 1
93
  self.model.fit(X, treatment_outcomes)
94
 
95
  # Calculate feature importance
 
98
  'importance': self.model.feature_importances_
99
  }).sort_values('importance', ascending=False)
100
 
101
+ print(f"Random Forest: Training complete. Top features: {', '.join(self.feature_importance['feature'].head(3).tolist())}")
102
  return self
103
 
104
  def predict(self, latents, demographics):
 
146
  n_splits = adjusted_n_splits
147
 
148
  logger.info(f"Running {n_splits}-fold cross-validation on {sample_count} samples")
149
+ print(f"Random Forest: Starting {n_splits}-fold cross-validation with {sample_count} samples")
150
 
151
  # Use stratified KFold for regression to ensure balanced folds
152
  # or LeaveOneOut for very small datasets
153
  if sample_count <= 5:
154
  from sklearn.model_selection import LeaveOneOut
155
  logger.warning(f"Using Leave-One-Out CV for small dataset with {sample_count} samples")
156
+ print(f"Random Forest: Using Leave-One-Out cross-validation due to small sample size ({sample_count})")
157
  kf = LeaveOneOut()
158
  cv_iterator = kf.split(X)
159
  else:
 
169
  X_train, X_test = X[train_idx], X[test_idx]
170
  y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
171
 
172
+ print(f"Random Forest: Training fold {fold+1}/{n_splits} - {len(X_train)} training samples, {len(X_test)} test samples")
173
+
174
  # Clone the model for this fold
175
  fold_model = RandomForestRegressor(
176
  n_estimators=self.n_estimators,
177
  max_depth=self.max_depth,
178
+ random_state=self.random_state,
179
+ verbose=1 # Add verbosity
180
  )
181
 
182
  # Train the model
 
197
  else:
198
  r2 = np.nan
199
  logger.warning(f"Fold {fold+1}: R² not calculated (insufficient samples or variance)")
200
+ print(f"Random Forest: Fold {fold+1} - R² not calculated (insufficient samples or variance)")
201
 
202
  # MSE can always be calculated
203
  mse = rmse**2
 
228
  fold_metrics.append(metrics)
229
  logger.info(f"Fold {fold+1} metrics: {metrics}")
230
 
231
+ # Print a more user-friendly version of the fold results
232
+ r2_val = metrics.get('r2', np.nan)
233
+ rmse_val = metrics.get('rmse', np.nan)
234
+ r2_text = f"R² = {r2_val:.4f}" if not np.isnan(r2_val) else "R² = N/A"
235
+ print(f"Random Forest: Fold {fold+1} results - {r2_text}, RMSE = {rmse_val:.4f}")
236
+
237
  # Calculate average metrics
238
  avg_metrics = {}
239
  for key in fold_metrics[0].keys():
 
246
 
247
  logger.info(f"Average CV metrics: {avg_metrics}")
248
 
249
+ # Print a summary of cross-validation performance
250
+ r2_avg = avg_metrics.get('r2', np.nan)
251
+ rmse_avg = avg_metrics.get('rmse', np.nan)
252
+ r2_text = f"R² = {r2_avg:.4f}" if not np.isnan(r2_avg) else "R² = N/A"
253
+ print(f"Random Forest: Cross-validation complete - Average {r2_text}, RMSE = {rmse_avg:.4f}")
254
+
255
  # Train final model on all data
256
+ print(f"Random Forest: Training final model on all {len(X)} samples...")
257
+ self.model.verbose = 1
258
  self.model.fit(X, treatment_outcomes)
259
 
260
  # Calculate feature importance
utils.py CHANGED
@@ -156,7 +156,11 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
156
  epoch_losses.append(total_loss.item())
157
 
158
  # Record training loss
159
- train_losses.append(np.mean(epoch_losses))
 
 
 
 
160
 
161
  # Validation step
162
  if e % pperiod == 0:
@@ -167,8 +171,6 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
167
  val_loss = rmse(x, y).item()
168
  val_losses.append(val_loss)
169
 
170
- print(f'Epoch {e}/{nepochs} - '
171
- f'Train Loss: {train_losses[-1]:.4f} - '
172
- f'Val Loss: {val_loss:.4f}')
173
 
174
  return train_losses, val_losses
 
156
  epoch_losses.append(total_loss.item())
157
 
158
  # Record training loss
159
+ epoch_loss = np.mean(epoch_losses)
160
+ train_losses.append(epoch_loss)
161
+
162
+ # Print progress for every epoch
163
+ print(f'Epoch {e+1}/{nepochs} - Train Loss: {epoch_loss:.4f}')
164
 
165
  # Validation step
166
  if e % pperiod == 0:
 
171
  val_loss = rmse(x, y).item()
172
  val_losses.append(val_loss)
173
 
174
+ print(f' Validation - Val Loss: {val_loss:.4f}')
 
 
175
 
176
  return train_losses, val_losses
visualization.py CHANGED
@@ -1,3 +1,6 @@
 
 
 
1
  import matplotlib.pyplot as plt
2
  import numpy as np
3
 
 
1
+ # Configure matplotlib for headless environment
2
+ import matplotlib
3
+ matplotlib.use('Agg') # Use non-interactive backend
4
  import matplotlib.pyplot as plt
5
  import numpy as np
6
 
visualize_fc.py CHANGED
@@ -6,6 +6,9 @@ Standalone script to visualize FC matrices using the VAE.
6
  import os
7
  import sys
8
  import numpy as np
 
 
 
9
  import matplotlib.pyplot as plt
10
  from main import run_fc_analysis
11
  from config import PREDICTION_CONFIG
 
6
  import os
7
  import sys
8
  import numpy as np
9
+ # Configure matplotlib for headless environment
10
+ import matplotlib
11
+ matplotlib.use('Agg') # Use non-interactive backend
12
  import matplotlib.pyplot as plt
13
  from main import run_fc_analysis
14
  from config import PREDICTION_CONFIG