Spaces:
Sleeping
Sleeping
File size: 3,311 Bytes
ef677f1 dbe81c1 3417890 dbe81c1 3417890 dbe81c1 3417890 dbe81c1 67303f6 b32645b 67303f6 dbe81c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
import matplotlib.pyplot as plt
import numpy as np
def _vector_to_matrix(vector):
    """Expand a strict upper-triangular vector into a symmetric square matrix.

    Inputs that are not 1-D are returned unchanged, so callers may pass either
    an already-square matrix or its vectorised upper triangle.

    A 1-D vector of length n is read as the entries strictly above the
    diagonal, in row-major order (row 0: columns 1..m-1, row 1: columns
    2..m-1, ...). Such a vector belongs to an m x m matrix with
    n = m * (m - 1) / 2, i.e. m = (1 + sqrt(1 + 8 * n)) / 2. The diagonal of
    the returned matrix is zero.

    Raises:
        ValueError: if a 1-D input's length is not m * (m - 1) / 2 for any
            integer m.
    """
    if np.ndim(vector) != 1:
        return vector
    values = np.asarray(vector)
    n = values.shape[0]
    # Diagonal excluded -> n = m(m-1)/2. Round, then verify, so that float
    # error in sqrt cannot silently pick the wrong size.
    size = int(round((1 + np.sqrt(1 + 8 * n)) / 2))
    if size * (size - 1) // 2 != n:
        raise ValueError(
            f"Vector of length {n} is not the strict upper triangle "
            f"of a square matrix"
        )
    matrix = np.zeros((size, size))
    # triu_indices(k=1) enumerates (i, j) with i < j row by row, the same
    # order as a nested `for i: for j in range(i + 1, size)` fill.
    rows, cols = np.triu_indices(size, k=1)
    matrix[rows, cols] = values
    # Mirror the upper triangle into the lower one; the diagonal stays zero.
    return matrix + matrix.T


def plot_fc_matrices(original, reconstructed, generated):
    """Plot original, reconstructed and generated FC matrices side by side.

    Each argument may be a square matrix or a 1-D vector holding its strict
    upper triangle; 1-D inputs are expanded by ``_vector_to_matrix``. All
    three panels use the 'RdBu_r' colormap with a fixed range of [-1, 1] and
    get their own colorbar.

    Returns:
        matplotlib.figure.Figure: figure with three image panels titled
        'Original FC', 'Reconstructed FC' and 'Generated FC'.

    Raises:
        ValueError: if a 1-D input has a non-triangular length.
    """
    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    vmin, vmax = -1, 1
    panels = (
        (original, 'Original FC'),
        (reconstructed, 'Reconstructed FC'),
        (generated, 'Generated FC'),
    )
    for ax, (data, title) in zip(axes, panels):
        im = ax.imshow(_vector_to_matrix(data), cmap='RdBu_r', vmin=vmin, vmax=vmax)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)
    fig.tight_layout()
    return fig
def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke, prediction_std=None):
    """Plot a straight-line trajectory from the current to the predicted score.

    Args:
        current_score: score drawn at x = 0 ('Current Status', blue).
        predicted_score: score drawn at x = months_post_stroke
            ('Predicted Outcome', red).
        months_post_stroke: x position of the predicted outcome.
        prediction_std: optional, non-negative standard deviation of the
            prediction. When given, a +/- 2 * prediction_std error bar is drawn
            at the predicted point. This is roughly a 95% interval if the
            prediction error is approximately normal.

    Returns:
        matplotlib.figure.Figure: the newly created figure.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    # Plot current and predicted points
    ax.scatter([0], [current_score], label='Current Status', color='blue', s=100)
    ax.scatter([months_post_stroke], [predicted_score],
               label='Predicted Outcome', color='red', s=100)
    # Plot trajectory
    ax.plot([0, months_post_stroke], [current_score, predicted_score],
            'g--', label='Predicted Trajectory')
    if prediction_std is not None:
        # fill_between over a single x value yields a zero-width polygon and
        # shows essentially nothing, so draw the interval as a vertical error
        # bar at the predicted point instead.
        ax.errorbar([months_post_stroke], [predicted_score],
                    yerr=[2 * prediction_std], fmt='none', ecolor='red',
                    elinewidth=2, capsize=8,
                    label='95% Prediction Interval')
    ax.set_xlabel('Months Post Treatment')
    ax.set_ylabel('WAB Score')
    ax.set_title('Predicted Treatment Trajectory')
    ax.legend()
    ax.grid(True)
    return fig
def plot_learning_curves(train_losses, val_losses):
    """Draw the VAE's training and validation loss curves on one set of axes.

    Each sequence is plotted against its 0-based index, which is labelled as
    the epoch. The training curve is drawn first and the validation curve
    second, so they take the first two colours of the colour cycle.

    Returns:
        matplotlib.figure.Figure: the newly created 10x6 inch figure.
    """
    fig, ax = plt.subplots(figsize=(10, 6))
    series = (
        ('Training Loss', train_losses),
        ('Validation Loss', val_losses),
    )
    for curve_label, losses in series:
        ax.plot(losses, label=curve_label)
    ax.set(xlabel='Epoch', ylabel='Loss', title='VAE Learning Curves')
    ax.legend()
    ax.grid(True)
    return fig
|