"""Plotting helpers for FC matrices, treatment trajectories and VAE learning curves."""
import math

import matplotlib.pyplot as plt
import numpy as np


def _vector_to_matrix(data):
    """Return *data* as a 2-D array, expanding a 1-D strict upper triangle.

    A 1-D input is interpreted as the strict upper triangle (diagonal
    excluded) of a square matrix, in row-major order -- the order produced by
    ``np.triu_indices(m, k=1)`` and by the original nested ``i < j`` loop.
    It is mirrored into the lower triangle; the diagonal is left at zero.
    Inputs that are not 1-D are returned as ``np.asarray(data)`` unchanged.

    Args:
        data: array-like, either already a matrix or a vector of length
            m*(m-1)/2.

    Returns:
        np.ndarray: the (symmetric, for 1-D input) matrix.

    Raises:
        ValueError: if a 1-D input's length is not m*(m-1)/2 for any integer m.
    """
    arr = np.asarray(data)
    if arr.ndim != 1:
        return arr

    n = arr.shape[0]
    # A strict upper triangle of an m x m matrix holds m*(m-1)/2 entries,
    # so m = (1 + sqrt(1 + 8n)) / 2. isqrt keeps this exact for any n,
    # avoiding float truncation from np.sqrt + int().
    size = (1 + math.isqrt(1 + 8 * n)) // 2
    if size * (size - 1) // 2 != n:
        raise ValueError(
            f"vector of length {n} is not the strict upper triangle of a "
            "square matrix (expected m*(m-1)/2 entries)"
        )

    matrix = np.zeros((size, size))
    rows, cols = np.triu_indices(size, k=1)
    matrix[rows, cols] = arr
    # Mirror the upper triangle; the diagonal stays 0 since it was never set.
    return matrix + matrix.T


def plot_fc_matrices(original, reconstructed, generated, *,
                     vmin=-1, vmax=1, cmap='RdBu_r'):
    """Plot original, reconstructed and generated FC matrices side by side.

    Each input may be a 2-D matrix or a 1-D strict-upper-triangle vector,
    which is expanded to a symmetric matrix with a zero diagonal.

    Args:
        original: original FC matrix or vector.
        reconstructed: reconstructed FC matrix or vector.
        generated: generated FC matrix or vector.
        vmin: lower bound of the shared color scale (default -1).
        vmax: upper bound of the shared color scale (default 1).
        cmap: matplotlib colormap name (default 'RdBu_r').

    Returns:
        matplotlib.figure.Figure: a figure with three panels, each with its
        own colorbar.

    Raises:
        ValueError: if a 1-D input does not have a triangular length.
    """
    # Convert (and validate) everything before creating the figure so a bad
    # input does not leave an orphaned open figure in pyplot's registry.
    panels = [
        (_vector_to_matrix(original), 'Original FC'),
        (_vector_to_matrix(reconstructed), 'Reconstructed FC'),
        (_vector_to_matrix(generated), 'Generated FC'),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (matrix, title) in zip(axes, panels):
        im = ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)

    fig.tight_layout()
    return fig


def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke,
                              prediction_std=None):
    """Plot a straight-line predicted trajectory from now to a future time point.

    Args:
        current_score: score at x = 0.
        predicted_score: predicted score at x = months_post_stroke.
        months_post_stroke: x position of the predicted outcome.
        prediction_std: optional standard deviation of the prediction; when
            given, a +/-2 std interval is drawn at the predicted point.

    Returns:
        matplotlib.figure.Figure: the trajectory figure.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Current and predicted points.
    ax.scatter([0], [current_score], label='Current Status', color='blue', s=100)
    ax.scatter([months_post_stroke], [predicted_score],
               label='Predicted Outcome', color='red', s=100)

    # Linear trajectory between them.
    ax.plot([0, months_post_stroke], [current_score, predicted_score], 'g--',
            label='Predicted Trajectory')

    if prediction_std is not None:
        # fill_between over a single x value spans zero width and renders
        # nothing, so draw the band as a thick translucent error bar instead.
        # +/-2 std ~ 95% only under an (assumed) normal prediction error.
        ax.errorbar([months_post_stroke], [predicted_score],
                    yerr=[2 * prediction_std], fmt='none', ecolor='red',
                    elinewidth=10, alpha=0.3,
                    label='95% Prediction Interval')

    ax.set_xlabel('Months Post Treatment')
    ax.set_ylabel('WAB Score')
    ax.set_title('Predicted Treatment Trajectory')
    ax.legend()
    ax.grid(True)
    return fig


def plot_learning_curves(train_losses, val_losses=None):
    """Plot VAE training (and optionally validation) loss per epoch.

    Args:
        train_losses: sequence of training losses, one per epoch.
        val_losses: optional sequence of validation losses; omitted from the
            plot when None.

    Returns:
        matplotlib.figure.Figure: the learning-curve figure.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    ax.plot(train_losses, label='Training Loss')
    if val_losses is not None:
        ax.plot(val_losses, label='Validation Loss')
    ax.set_xlabel('Epoch')
    ax.set_ylabel('Loss')
    ax.set_title('VAE Learning Curves')
    ax.legend()
    ax.grid(True)
    return fig