SreekarB commited on
Commit
55c1385
·
verified ·
1 Parent(s): 055f005

Upload 5 files

Browse files
Files changed (4) hide show
  1. app.py +60 -16
  2. main.py +53 -9
  3. utils.py +32 -9
  4. visualization.py +107 -12
app.py CHANGED
@@ -1763,23 +1763,67 @@ def create_learning_figure(vae):
1763
  """Helper function to create VAE learning curve figure"""
1764
  plt.close('all') # Close previous figures
1765
 
1766
- if hasattr(vae, 'train_losses') and hasattr(vae, 'val_losses') and vae.train_losses:
1767
- logger.info(f"Creating learning curve with {len(vae.train_losses)} loss points")
1768
- fig = plot_learning_curves(vae.train_losses, vae.val_losses)
1769
- # Force rendering
1770
- fig.canvas.draw()
1771
- logger.info("Successfully created learning curve figure")
1772
- return fig
1773
  else:
1774
- logger.warning("No loss data found in VAE model - creating empty learning figure")
1775
- fig = plt.figure(figsize=(10, 6))
1776
- plt.title("No learning curve data available")
1777
- plt.xlabel("Epoch")
1778
- plt.ylabel("Loss")
1779
- plt.text(0.5, 0.5, "Learning curve data unavailable",
1780
- ha='center', va='center', transform=plt.gca().transAxes)
1781
- fig.canvas.draw()
1782
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1783
 
1784
  def find_real_nifti_files(max_samples=2):
1785
  """Find real NIfTI files in the dataset or local directories, limited to the specified number"""
 
def create_learning_figure(vae):
    """Helper function to create VAE learning curve figure.

    Builds one of three figures depending on what loss history the model
    carries: the full train/validation curve, a training-only curve with a
    synthesized validation line, or a labelled placeholder.

    Parameters
    ----------
    vae : object
        Model object that may expose ``train_losses`` / ``val_losses``
        (list/tuple of per-epoch losses) recorded during training.

    Returns
    -------
    matplotlib.figure.Figure
        A fully rendered figure (``canvas.draw()`` already called).
    """
    plt.close('all')  # Close previous figures

    # First check if loss data exists in the VAE object
    train_available = (hasattr(vae, 'train_losses')
                       and isinstance(vae.train_losses, (list, tuple))
                       and len(vae.train_losses) > 0)
    val_available = (hasattr(vae, 'val_losses')
                     and isinstance(vae.val_losses, (list, tuple))
                     and len(vae.val_losses) > 0)

    # Log the status for debugging
    if train_available:
        logger.info(f"Found training losses: {len(vae.train_losses)} points")
    else:
        logger.warning("No training loss data found in VAE model")

    if val_available:
        logger.info(f"Found validation losses: {len(vae.val_losses)} points")
    else:
        logger.warning("No validation loss data found in VAE model")

    # Best case: both histories present -> real learning curve.
    if train_available and val_available:
        logger.info(f"Creating learning curve with {len(vae.train_losses)} loss points")
        try:
            fig = plot_learning_curves(vae.train_losses, vae.val_losses)
            fig.canvas.draw()  # Force rendering
            logger.info("Successfully created learning curve figure")
            return fig
        except Exception as e:
            # Fall through to the placeholder figure below.
            logger.error(f"Error creating learning curve: {e}")
    elif train_available:
        # Training losses only: plot them against a synthetic validation
        # series (training offset by 10%) so the shared plotting helper
        # can be reused; the figure is annotated to flag the substitution.
        logger.info("Creating learning curve with training losses only")
        try:
            dummy_val = [t * 1.1 for t in vae.train_losses]
            fig = plot_learning_curves(vae.train_losses, dummy_val)
            plt.title("VAE Learning Curve (Training Only)")
            plt.figtext(0.5, 0.01, "Note: Validation data unavailable",
                        ha='center', fontsize=10, color='red')
            fig.canvas.draw()
            logger.info("Created partial learning curve with training data only")
            return fig
        except Exception as e:
            # Fall through to the placeholder figure below.
            logger.error(f"Error creating partial learning curve: {e}")

    # Placeholder: reached when no usable loss data exists or plotting failed.
    logger.warning("No complete loss data found - creating placeholder learning figure")
    fig = plt.figure(figsize=(10, 6))
    plt.title("VAE Learning Curve Data Unavailable", color='darkred')
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.text(0.5, 0.5, "Learning curves will appear here after training",
             ha='center', va='center', transform=plt.gca().transAxes,
             fontsize=14)
    plt.text(0.5, 0.4, "Try using more training epochs to see learning progress",
             ha='center', va='center', transform=plt.gca().transAxes,
             fontsize=12, color='darkblue')
    plt.grid(True, alpha=0.3)
    plt.axis('on')
    fig.canvas.draw()
    return fig
 
1828
  def find_real_nifti_files(max_samples=2):
1829
  """Find real NIfTI files in the dataset or local directories, limited to the specified number"""
main.py CHANGED
@@ -129,10 +129,21 @@ def run_analysis(data_dir="data",
129
 
130
  # Format demographics for predictor and results
131
  demographics = {}
 
 
132
  demo_keys = ['age_at_stroke', 'sex', 'months_post_stroke', 'wab_score']
 
 
 
133
  for i, key in enumerate(demo_keys):
134
  if i < len(demo_data):
135
  demographics[key] = demo_data[i]
 
 
 
 
 
 
136
 
137
  # Generate reconstructions and synthetic FC
138
  try:
@@ -185,19 +196,52 @@ def run_analysis(data_dir="data",
185
 
186
  # Learning curves
187
  try:
188
- if train_losses and val_losses:
189
- learning_fig = plot_learning_curves(train_losses, val_losses)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
  else:
191
- print("No training history available for learning curves")
192
- learning_fig = plt.figure(figsize=(10, 6))
193
- plt.text(0.5, 0.5, "Learning curve data unavailable",
194
- ha='center', va='center', transform=plt.gca().transAxes)
195
- plt.tight_layout()
 
 
 
 
 
 
 
 
196
  except Exception as e:
 
197
  print(f"Error creating learning curve plot: {e}")
 
 
 
198
  learning_fig = plt.figure(figsize=(10, 6))
199
- plt.text(0.5, 0.5, "Error creating learning curves",
200
- ha='center', va='center', transform=plt.gca().transAxes)
 
 
201
  plt.tight_layout()
202
 
203
  # Initialize results dictionary
 
129
 
130
  # Format demographics for predictor and results
131
  demographics = {}
132
+
133
+ # Define both standard and alternative keys
134
  demo_keys = ['age_at_stroke', 'sex', 'months_post_stroke', 'wab_score']
135
+ alternate_keys = {'age_at_stroke': 'age', 'months_post_stroke': 'mpo', 'wab_score': 'wab_aq'}
136
+
137
+ # Map demographic data to consistent keys
138
  for i, key in enumerate(demo_keys):
139
  if i < len(demo_data):
140
  demographics[key] = demo_data[i]
141
+ # Also add alternate versions of the key for compatibility
142
+ if key in alternate_keys:
143
+ demographics[alternate_keys[key]] = demo_data[i]
144
+
145
+ # Print the keys available in demographics for debugging
146
+ print(f"Demographics keys available: {list(demographics.keys())}")
147
 
148
  # Generate reconstructions and synthetic FC
149
  try:
 
196
 
197
  # Learning curves
198
  try:
199
+ print("Creating learning curve visualization...")
200
+
201
+ # Check if losses are stored in the VAE object first (most reliable source)
202
+ if hasattr(vae, 'train_losses') and hasattr(vae, 'val_losses'):
203
+ if len(vae.train_losses) > 0 and len(vae.val_losses) > 0:
204
+ print(f"Using learning curves from VAE object: {len(vae.train_losses)} train, {len(vae.val_losses)} validation points")
205
+ learning_fig = plot_learning_curves(vae.train_losses, vae.val_losses)
206
+ else:
207
+ # Fall back to the losses passed directly
208
+ if train_losses and val_losses:
209
+ print(f"Using passed learning curves: {len(train_losses)} train, {len(val_losses)} validation points")
210
+ learning_fig = plot_learning_curves(train_losses, val_losses)
211
+ else:
212
+ # Create a placeholder
213
+ print("No training history available for learning curves")
214
+ learning_fig = plt.figure(figsize=(10, 6))
215
+ plt.text(0.5, 0.5, "Learning curve data unavailable",
216
+ ha='center', va='center', transform=plt.gca().transAxes,
217
+ fontsize=14, color='darkred')
218
+ plt.axis('off')
219
+ plt.tight_layout()
220
  else:
221
+ # Fall back to the losses passed directly
222
+ if train_losses and val_losses:
223
+ print(f"Using passed learning curves: {len(train_losses)} train, {len(val_losses)} validation points")
224
+ learning_fig = plot_learning_curves(train_losses, val_losses)
225
+ else:
226
+ # Create a placeholder
227
+ print("No training history available for learning curves")
228
+ learning_fig = plt.figure(figsize=(10, 6))
229
+ plt.text(0.5, 0.5, "Learning curve data unavailable",
230
+ ha='center', va='center', transform=plt.gca().transAxes,
231
+ fontsize=14, color='darkred')
232
+ plt.axis('off')
233
+ plt.tight_layout()
234
  except Exception as e:
235
+ import traceback
236
  print(f"Error creating learning curve plot: {e}")
237
+ print(f"Traceback: {traceback.format_exc()}")
238
+
239
+ # Create a more informative error display
240
  learning_fig = plt.figure(figsize=(10, 6))
241
+ plt.text(0.5, 0.5, f"Error creating learning curves: {str(e)}",
242
+ ha='center', va='center', transform=plt.gca().transAxes,
243
+ fontsize=12, color='darkred')
244
+ plt.axis('off')
245
  plt.tight_layout()
246
 
247
  # Initialize results dictionary
utils.py CHANGED
@@ -129,6 +129,17 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
129
  ce = torch.nn.CrossEntropyLoss()
130
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
131
 
 
 
 
 
 
 
 
 
 
 
 
132
  for e in range(nepochs):
133
  epoch_losses = []
134
  vae.train()
@@ -162,15 +173,27 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
162
  # Print progress for every epoch
163
  print(f'Epoch {e+1}/{nepochs} - Train Loss: {epoch_loss:.4f}')
164
 
165
- # Validation step
166
- if e % pperiod == 0:
167
- vae.eval()
168
- with torch.no_grad():
169
- z = vae.enc(x)
170
- y = vae.dec(z, demo_t)
171
- val_loss = rmse(x, y).item()
172
- val_losses.append(val_loss)
173
-
 
174
  print(f' Validation - Val Loss: {val_loss:.4f}')
175
 
 
 
 
 
 
 
 
 
 
 
 
176
  return train_losses, val_losses
 
129
  ce = torch.nn.CrossEntropyLoss()
130
  optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
131
 
132
+ # Calculate initial validation loss
133
+ print("Calculating initial validation metrics...")
134
+ vae.eval()
135
+ with torch.no_grad():
136
+ z_val = vae.enc(x)
137
+ y_val = vae.dec(z_val, demo_t)
138
+ initial_val_loss = rmse(x, y_val).item()
139
+ val_losses.append(initial_val_loss)
140
+ print(f"Initial validation loss: {initial_val_loss:.4f}")
141
+
142
+ # Main training loop
143
  for e in range(nepochs):
144
  epoch_losses = []
145
  vae.train()
 
173
  # Print progress for every epoch
174
  print(f'Epoch {e+1}/{nepochs} - Train Loss: {epoch_loss:.4f}')
175
 
176
+ # Validation step (perform at every epoch to have full data for plotting)
177
+ vae.eval()
178
+ with torch.no_grad():
179
+ z = vae.enc(x)
180
+ y = vae.dec(z, demo_t)
181
+ val_loss = rmse(x, y).item()
182
+ val_losses.append(val_loss)
183
+
184
+ # Only print detailed validation logs at pperiod intervals
185
+ if (e + 1) % pperiod == 0:
186
  print(f' Validation - Val Loss: {val_loss:.4f}')
187
 
188
+ # Make sure losses are converted to regular Python lists (for serialization)
189
+ train_losses = [float(loss) for loss in train_losses]
190
+ val_losses = [float(loss) for loss in val_losses]
191
+
192
+ print(f"Training complete - Final train loss: {train_losses[-1]:.4f}, Val loss: {val_losses[-1]:.4f}")
193
+ print(f"Loss history recorded: {len(train_losses)} train points, {len(val_losses)} validation points")
194
+
195
+ # Store the losses in the return object for future reference
196
+ ret_obj.train_losses = train_losses
197
+ ret_obj.val_losses = val_losses
198
+
199
  return train_losses, val_losses
visualization.py CHANGED
@@ -312,15 +312,110 @@ def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke
312
  return fig
313
 
314
  def plot_learning_curves(train_losses, val_losses):
315
- """Plot VAE learning curves"""
316
- fig = plt.figure(figsize=(10, 6))
317
-
318
- plt.plot(train_losses, label='Training Loss')
319
- plt.plot(val_losses, label='Validation Loss')
320
- plt.xlabel('Epoch')
321
- plt.ylabel('Loss')
322
- plt.title('VAE Learning Curves')
323
- plt.legend()
324
- plt.grid(True)
325
-
326
- return fig
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
  return fig
313
 
314
  def plot_learning_curves(train_losses, val_losses):
315
+ """Plot VAE learning curves with enhanced visualization"""
316
+ try:
317
+ # Convert to numpy arrays for safe handling
318
+ train_np = np.array(train_losses)
319
+ val_np = np.array(val_losses)
320
+
321
+ # Check for NaN values
322
+ if np.any(np.isnan(train_np)) or np.any(np.isnan(val_np)):
323
+ print("WARNING: Learning curves contain NaN values, replacing with zeros")
324
+ train_np = np.nan_to_num(train_np)
325
+ val_np = np.nan_to_num(val_np)
326
+
327
+ # Create figure
328
+ fig = plt.figure(figsize=(12, 6))
329
+
330
+ # Add improved styling
331
+ plt.rcParams['font.size'] = 12
332
+
333
+ # Check if train and val lengths match
334
+ if len(train_np) != len(val_np):
335
+ print(f"Training and validation loss lengths don't match: {len(train_np)} vs {len(val_np)}")
336
+ if len(train_np) > len(val_np):
337
+ # Validation might be evaluated less frequently
338
+ # Create epoch indices for each
339
+ train_epochs = np.arange(len(train_np))
340
+ val_factor = len(train_np) / len(val_np)
341
+ val_epochs = np.arange(0, len(train_np), val_factor)[:len(val_np)]
342
+
343
+ plt.plot(train_epochs, train_np, 'b-', linewidth=2, label='Training Loss')
344
+ plt.plot(val_epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
345
+ else:
346
+ # This is unusual, but handle it anyway
347
+ plt.plot(train_np, 'b-', linewidth=2, label='Training Loss')
348
+ plt.plot(val_np[:len(train_np)], 'r-', linewidth=2, label='Validation Loss')
349
+ else:
350
+ # Standard case - equal length arrays
351
+ epochs = np.arange(len(train_np))
352
+ plt.plot(epochs, train_np, 'b-', linewidth=2, label='Training Loss')
353
+ plt.plot(epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
354
+
355
+ # Add shaded confidence region
356
+ if len(train_np) > 5: # Only if we have enough points
357
+ # Calculate moving average for smoother trend lines
358
+ window_size = min(5, len(train_np) // 5)
359
+ if window_size > 1:
360
+ avg_train = np.convolve(train_np, np.ones(window_size)/window_size, mode='valid')
361
+ avg_val = np.convolve(val_np, np.ones(window_size)/window_size, mode='valid')
362
+ avg_epochs = epochs[window_size-1:]
363
+ plt.plot(avg_epochs, avg_train, 'b--', linewidth=1, alpha=0.6)
364
+ plt.plot(avg_epochs, avg_val, 'r--', linewidth=1, alpha=0.6)
365
+
366
+ # Calculate improvement from start to end
367
+ if len(train_np) > 1:
368
+ train_improvement = ((train_np[0] - train_np[-1]) / train_np[0]) * 100
369
+ if len(val_np) > 1:
370
+ val_improvement = ((val_np[0] - val_np[-1]) / val_np[0]) * 100
371
+ plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement, Validation: {val_improvement:.1f}% improvement')
372
+ else:
373
+ plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement')
374
+ else:
375
+ plt.title('VAE Learning Curves')
376
+
377
+ # Add min/max annotations
378
+ if len(train_np) > 0:
379
+ min_train = np.min(train_np)
380
+ min_train_epoch = np.argmin(train_np)
381
+ plt.annotate(f'Min: {min_train:.4f}', xy=(min_train_epoch, min_train),
382
+ xytext=(min_train_epoch+5, min_train+0.05),
383
+ arrowprops=dict(facecolor='blue', shrink=0.05, alpha=0.5),
384
+ color='blue', fontsize=10)
385
+
386
+ if len(val_np) > 0:
387
+ min_val = np.min(val_np)
388
+ min_val_epoch = np.argmin(val_np)
389
+ plt.annotate(f'Min: {min_val:.4f}', xy=(min_val_epoch, min_val),
390
+ xytext=(min_val_epoch+5, min_val+0.05),
391
+ arrowprops=dict(facecolor='red', shrink=0.05, alpha=0.5),
392
+ color='red', fontsize=10)
393
+
394
+ # Styling
395
+ plt.xlabel('Epoch')
396
+ plt.ylabel('Loss')
397
+ plt.legend(loc='upper right')
398
+ plt.grid(True, alpha=0.3)
399
+
400
+ # Set reasonable y-axis limits
401
+ all_losses = np.concatenate([train_np, val_np])
402
+ y_min = max(0, np.min(all_losses) * 0.9) # Don't go below zero
403
+ y_max = np.percentile(all_losses, 95) * 1.1 # Exclude outliers
404
+ plt.ylim(y_min, y_max)
405
+
406
+ plt.tight_layout()
407
+ return fig
408
+
409
+ except Exception as e:
410
+ import traceback
411
+ print(f"Error in plot_learning_curves: {e}")
412
+ print(f"Traceback: {traceback.format_exc()}")
413
+
414
+ # Create a simple error figure
415
+ fig = plt.figure(figsize=(10, 6))
416
+ plt.text(0.5, 0.5, f"Learning curves error: {str(e)}",
417
+ ha='center', va='center', transform=plt.gca().transAxes,
418
+ fontsize=12, color='red')
419
+ plt.axis('off')
420
+ plt.tight_layout()
421
+ return fig