SreekarB committed on
Commit
d526dee
·
verified ·
1 Parent(s): a4c8f0c

Upload 5 files

Browse files
Files changed (3) hide show
  1. app.py +29 -15
  2. main.py +108 -35
  3. vae_model.py +23 -3
app.py CHANGED
@@ -504,23 +504,37 @@ class AphasiaPredictionApp:
504
 
505
  # Prepare prediction visualization if available
506
  if self.predictor and predictor_cv_results:
507
- # Get the outcome variable data
508
- if outcome_variable == 'wab_aq':
509
- outcomes = demographics['wab_aq']
510
- elif outcome_variable == 'age':
511
- outcomes = demographics['age']
512
- elif outcome_variable == 'mpo' or outcome_variable == 'months_post_onset':
513
- outcomes = demographics['mpo']
514
- else:
515
- # Try to find the outcome in demographics data
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
516
  outcomes = None
517
- for key in demographics:
518
- if outcome_variable.lower() in key.lower():
519
- outcomes = demographics[key]
520
- break
521
 
522
- # Create plots
523
- if 'prediction_stds' in predictor_cv_results and 'predictions' in predictor_cv_results:
524
  # Create prediction plots
525
  prediction_fig = self.create_prediction_plots(
526
  latents,
 
504
 
505
  # Prepare prediction visualization if available
506
  if self.predictor and predictor_cv_results:
507
+ try:
508
+ # Get the outcome variable data
509
+ outcomes = None
510
+ if demographics:
511
+ if outcome_variable == 'wab_aq' and 'wab_aq' in demographics:
512
+ outcomes = demographics['wab_aq']
513
+ elif outcome_variable == 'age' and 'age' in demographics:
514
+ outcomes = demographics['age']
515
+ elif (outcome_variable == 'mpo' or outcome_variable == 'months_post_onset') and 'mpo' in demographics:
516
+ outcomes = demographics['mpo']
517
+ else:
518
+ # Try to find the outcome in demographics data
519
+ for key in demographics:
520
+ if outcome_variable.lower() in key.lower():
521
+ outcomes = demographics[key]
522
+ logger.info(f"Found matching outcome variable: {key}")
523
+ break
524
+
525
+ if outcomes is None:
526
+ logger.warning(f"Could not find outcome variable '{outcome_variable}' in demographics")
527
+ # Create a dummy array to prevent errors
528
+ if 'predictions' in predictor_cv_results:
529
+ outcomes = np.zeros_like(predictor_cv_results['predictions'])
530
+ else:
531
+ logger.warning("Cannot create prediction plots without outcome data")
532
+ except Exception as e:
533
+ logger.error(f"Error getting outcome variable: {e}")
534
  outcomes = None
 
 
 
 
535
 
536
+ # Create plots if we have the necessary data
537
+ if outcomes is not None and 'prediction_stds' in predictor_cv_results and 'predictions' in predictor_cv_results:
538
  # Create prediction plots
539
  prediction_fig = self.create_prediction_plots(
540
  latents,
main.py CHANGED
@@ -84,7 +84,13 @@ def run_analysis(data_dir="data",
84
  # Initialize and train VAE
85
  print("Training VAE...")
86
  vae = DemoVAE(**MODEL_CONFIG)
87
- train_losses, val_losses = vae.fit(X, demo_data, demo_types)
 
 
 
 
 
 
88
 
89
  # Get latent representations
90
  print("Extracting latent representations...")
@@ -116,18 +122,28 @@ def run_analysis(data_dir="data",
116
  )
117
 
118
  # Extract results from CV
119
- mean_metrics = cv_results["mean_metrics"]
120
- fold_metrics = cv_results["fold_metrics"]
121
- predictions = cv_results["predictions"]
122
- prediction_stds = cv_results["prediction_stds"]
123
 
124
  # For regression, get R2 metrics, otherwise use accuracy
125
- if predictor.prediction_type == "regression":
126
- cv_mean = mean_metrics["r2"]
127
- cv_std = np.std([fold["r2"] for fold in fold_metrics])
128
- else:
129
- cv_mean = mean_metrics["accuracy"]
130
- cv_std = np.std([fold["accuracy"] for fold in fold_metrics])
 
 
 
 
 
 
 
 
 
 
131
 
132
  # Fit final predictor model
133
  predictor.fit(latents, demographics, treatment_outcomes)
@@ -145,31 +161,84 @@ def run_analysis(data_dir="data",
145
  print("Generating visualizations...")
146
 
147
  # FC matrix visualization
148
- reconstructed = vae.transform(X, demo_data, demo_types)
149
- generated = vae.transform(1,
150
- [d[:1] for d in demo_data],
151
- demo_types)
152
- fc_fig = plot_fc_matrices(X[0], reconstructed[0], generated[0])
 
 
 
 
 
 
 
153
 
154
  # Learning curves
155
- learning_fig = plot_learning_curves(train_losses, val_losses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
156
 
157
  # Feature importance
158
- importance_fig = predictor.plot_feature_importance()
 
 
 
 
 
 
 
159
 
160
  # Prediction performance
161
  performance_fig = plt.figure(figsize=(8, 6))
162
- plt.scatter(treatment_outcomes, predictions)
163
- plt.plot([min(treatment_outcomes), max(treatment_outcomes)],
164
- [min(treatment_outcomes), max(treatment_outcomes)],
165
- 'r--')
166
- plt.fill_between(treatment_outcomes,
167
- predictions - 2*prediction_stds,
168
- predictions + 2*prediction_stds,
169
- alpha=0.2, color='gray')
170
- plt.xlabel('Actual Outcome')
171
- plt.ylabel('Predicted Outcome')
172
- plt.title(f'Treatment Outcome Prediction\nR² = {cv_mean:.3f} ± {cv_std:.3f}')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
173
  plt.tight_layout()
174
 
175
  # Save results
@@ -178,6 +247,15 @@ def run_analysis(data_dir="data",
178
  np.save('results/predictions.npy', predictions)
179
  np.save('results/prediction_stds.npy', prediction_stds)
180
 
 
 
 
 
 
 
 
 
 
181
  results = {
182
  'vae': vae,
183
  'predictor': predictor,
@@ -186,12 +264,7 @@ def run_analysis(data_dir="data",
186
  'cv_scores': (cv_mean, cv_std),
187
  'predictions': predictions,
188
  'prediction_stds': prediction_stds,
189
- 'predictor_cv_results': {
190
- 'mean_metrics': mean_metrics,
191
- 'fold_metrics': fold_metrics,
192
- 'predictions': predictions,
193
- 'prediction_stds': prediction_stds
194
- },
195
  'figures': {
196
  'vae': fc_fig, # Changed to match app.py expectations
197
  'fc_analysis': fc_fig,
 
84
  # Initialize and train VAE
85
  print("Training VAE...")
86
  vae = DemoVAE(**MODEL_CONFIG)
87
+ try:
88
+ train_losses, val_losses = vae.fit(X, demo_data, demo_types)
89
+ print(f"VAE training complete. Final train loss: {train_losses[-1]:.4f}, final validation loss: {val_losses[-1]:.4f}")
90
+ except Exception as e:
91
+ print(f"Error during VAE training: {e}")
92
+ print("Using empty lists for losses as fallback")
93
+ train_losses, val_losses = [], []
94
 
95
  # Get latent representations
96
  print("Extracting latent representations...")
 
122
  )
123
 
124
  # Extract results from CV
125
+ mean_metrics = cv_results.get("mean_metrics", {})
126
+ fold_metrics = cv_results.get("fold_metrics", [])
127
+ predictions = cv_results.get("predictions", np.zeros_like(treatment_outcomes))
128
+ prediction_stds = cv_results.get("prediction_stds", np.zeros_like(treatment_outcomes))
129
 
130
  # For regression, get R2 metrics, otherwise use accuracy
131
+ try:
132
+ if predictor.prediction_type == "regression":
133
+ cv_mean = mean_metrics.get("r2", 0.0)
134
+ if fold_metrics and "r2" in fold_metrics[0]:
135
+ cv_std = np.std([fold.get("r2", 0.0) for fold in fold_metrics])
136
+ else:
137
+ cv_std = 0.0
138
+ else:
139
+ cv_mean = mean_metrics.get("accuracy", 0.0)
140
+ if fold_metrics and "accuracy" in fold_metrics[0]:
141
+ cv_std = np.std([fold.get("accuracy", 0.0) for fold in fold_metrics])
142
+ else:
143
+ cv_std = 0.0
144
+ except Exception as e:
145
+ print(f"Error calculating CV metrics: {e}")
146
+ cv_mean, cv_std = 0.0, 0.0
147
 
148
  # Fit final predictor model
149
  predictor.fit(latents, demographics, treatment_outcomes)
 
161
  print("Generating visualizations...")
162
 
163
  # FC matrix visualization
164
+ try:
165
+ reconstructed = vae.transform(X, demo_data, demo_types)
166
+ generated = vae.transform(1,
167
+ [d[:1] for d in demo_data],
168
+ demo_types)
169
+ fc_fig = plot_fc_matrices(X[0], reconstructed[0], generated[0])
170
+ except Exception as e:
171
+ print(f"Error creating FC visualization: {e}")
172
+ fc_fig = plt.figure(figsize=(15, 5))
173
+ plt.text(0.5, 0.5, "FC visualization unavailable",
174
+ ha='center', va='center', transform=plt.gca().transAxes)
175
+ plt.tight_layout()
176
 
177
  # Learning curves
178
+ try:
179
+ if train_losses and val_losses:
180
+ learning_fig = plot_learning_curves(train_losses, val_losses)
181
+ else:
182
+ print("No training history available for learning curves")
183
+ learning_fig = plt.figure(figsize=(10, 6))
184
+ plt.text(0.5, 0.5, "Learning curve data unavailable",
185
+ ha='center', va='center', transform=plt.gca().transAxes)
186
+ plt.tight_layout()
187
+ except Exception as e:
188
+ print(f"Error creating learning curve plot: {e}")
189
+ learning_fig = plt.figure(figsize=(10, 6))
190
+ plt.text(0.5, 0.5, "Error creating learning curves",
191
+ ha='center', va='center', transform=plt.gca().transAxes)
192
+ plt.tight_layout()
193
 
194
  # Feature importance
195
+ try:
196
+ importance_fig = predictor.plot_feature_importance()
197
+ except Exception as e:
198
+ print(f"Error creating feature importance plot: {e}")
199
+ importance_fig = plt.figure(figsize=(8, 6))
200
+ plt.text(0.5, 0.5, "Feature importance unavailable",
201
+ ha='center', va='center', transform=plt.gca().transAxes)
202
+ plt.tight_layout()
203
 
204
  # Prediction performance
205
  performance_fig = plt.figure(figsize=(8, 6))
206
+
207
+ # Check if we have valid predictions
208
+ if len(treatment_outcomes) > 0 and len(predictions) == len(treatment_outcomes):
209
+ try:
210
+ # Only create scatter plot if we have matching data
211
+ plt.scatter(treatment_outcomes, predictions)
212
+
213
+ # Reference line
214
+ min_val = min(np.min(treatment_outcomes), np.min(predictions))
215
+ max_val = max(np.max(treatment_outcomes), np.max(predictions))
216
+ plt.plot([min_val, max_val], [min_val, max_val], 'r--')
217
+
218
+ # Confidence band
219
+ plt.fill_between(treatment_outcomes,
220
+ predictions - 2*prediction_stds,
221
+ predictions + 2*prediction_stds,
222
+ alpha=0.2, color='gray')
223
+
224
+ # Labels
225
+ plt.xlabel('Actual Outcome')
226
+ plt.ylabel('Predicted Outcome')
227
+
228
+ # Title with metrics
229
+ if predictor.prediction_type == "regression":
230
+ plt.title(f'Treatment Outcome Prediction\nR² = {cv_mean:.3f} ± {cv_std:.3f}')
231
+ else:
232
+ plt.title(f'Treatment Outcome Prediction\nAccuracy = {cv_mean:.3f} ± {cv_std:.3f}')
233
+ except Exception as e:
234
+ print(f"Error creating performance plot: {e}")
235
+ plt.text(0.5, 0.5, "Error creating plot",
236
+ ha='center', va='center', transform=plt.gca().transAxes)
237
+ else:
238
+ # Handle case with no data
239
+ plt.text(0.5, 0.5, "No prediction data available",
240
+ ha='center', va='center', transform=plt.gca().transAxes)
241
+
242
  plt.tight_layout()
243
 
244
  # Save results
 
247
  np.save('results/predictions.npy', predictions)
248
  np.save('results/prediction_stds.npy', prediction_stds)
249
 
250
+ # Prepare predictor_cv_results with appropriate default values if missing
251
+ predictor_cv_results = {
252
+ 'mean_metrics': mean_metrics if mean_metrics else {},
253
+ 'fold_metrics': fold_metrics if fold_metrics else [],
254
+ 'predictions': predictions if len(predictions) > 0 else np.zeros(0),
255
+ 'prediction_stds': prediction_stds if len(prediction_stds) > 0 else np.zeros(0)
256
+ }
257
+
258
+ # Construct the final results dictionary
259
  results = {
260
  'vae': vae,
261
  'predictor': predictor,
 
264
  'cv_scores': (cv_mean, cv_std),
265
  'predictions': predictions,
266
  'prediction_stds': prediction_stds,
267
+ 'predictor_cv_results': predictor_cv_results,
 
 
 
 
 
268
  'figures': {
269
  'vae': fc_fig, # Changed to match app.py expectations
270
  'fc_analysis': fc_fig,
vae_model.py CHANGED
@@ -89,7 +89,7 @@ class DemoVAE(BaseEstimator):
89
  self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)
90
 
91
  # Train VAE
92
- train_vae(
93
  self.vae, x, demo, demo_types,
94
  self.nepochs, self.pperiod, self.bsize,
95
  self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
@@ -97,7 +97,13 @@ class DemoVAE(BaseEstimator):
97
  self.lr, self.weight_decay, self.alpha, self.LR_C,
98
  self
99
  )
100
- return self
 
 
 
 
 
 
101
 
102
  def transform(self, x, demo, demo_types):
103
  if isinstance(x, int):
@@ -113,13 +119,19 @@ class DemoVAE(BaseEstimator):
113
  return to_numpy(z)
114
 
115
  def save(self, path):
 
 
 
116
  torch.save({
117
  'model_state_dict': self.vae.state_dict(),
118
  'params': self.get_params(),
119
  'pred_stats': self.pred_stats,
120
  'input_dim': self.input_dim,
121
- 'demo_dim': self.demo_dim
 
 
122
  }, path)
 
123
 
124
  def load(self, path):
125
  checkpoint = torch.load(path)
@@ -129,3 +141,11 @@ class DemoVAE(BaseEstimator):
129
  self.demo_dim = checkpoint['demo_dim']
130
  self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
131
  self.vae.load_state_dict(checkpoint['model_state_dict'])
 
 
 
 
 
 
 
 
 
89
  self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)
90
 
91
  # Train VAE
92
+ train_losses, val_losses = train_vae(
93
  self.vae, x, demo, demo_types,
94
  self.nepochs, self.pperiod, self.bsize,
95
  self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
 
97
  self.lr, self.weight_decay, self.alpha, self.LR_C,
98
  self
99
  )
100
+
101
+ # Store the losses for later visualization
102
+ self.train_losses = train_losses
103
+ self.val_losses = val_losses
104
+
105
+ # Return the losses for immediate use
106
+ return train_losses, val_losses
107
 
108
  def transform(self, x, demo, demo_types):
109
  if isinstance(x, int):
 
119
  return to_numpy(z)
120
 
121
  def save(self, path):
122
+ train_losses = getattr(self, 'train_losses', [])
123
+ val_losses = getattr(self, 'val_losses', [])
124
+
125
  torch.save({
126
  'model_state_dict': self.vae.state_dict(),
127
  'params': self.get_params(),
128
  'pred_stats': self.pred_stats,
129
  'input_dim': self.input_dim,
130
+ 'demo_dim': self.demo_dim,
131
+ 'train_losses': train_losses,
132
+ 'val_losses': val_losses
133
  }, path)
134
+ print(f"Saved VAE model to {path}")
135
 
136
  def load(self, path):
137
  checkpoint = torch.load(path)
 
141
  self.demo_dim = checkpoint['demo_dim']
142
  self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
143
  self.vae.load_state_dict(checkpoint['model_state_dict'])
144
+
145
+ # Load training history if available
146
+ if 'train_losses' in checkpoint:
147
+ self.train_losses = checkpoint['train_losses']
148
+ if 'val_losses' in checkpoint:
149
+ self.val_losses = checkpoint['val_losses']
150
+
151
+ print(f"Loaded VAE model from {path}")