File size: 3,311 Bytes
ef677f1
 
 
dbe81c1
 
 
 
 
 
3417890
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbe81c1
 
3417890
dbe81c1
 
3417890
dbe81c1
 
 
 
67303f6
b32645b
67303f6
 
dbe81c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import matplotlib.pyplot as plt
import numpy as np

def plot_fc_matrices(original, reconstructed, generated, *, vmin=-1, vmax=1, cmap='RdBu_r'):
    """Plot original, reconstructed and generated FC matrices side by side.

    Each input may be a 2D matrix, which is plotted as-is, or a 1D vector
    holding the strictly upper-triangular (off-diagonal) entries of a
    symmetric matrix in row-major order. Row-major order is the order
    produced by ``matrix[np.triu_indices(m, k=1)]``. A 1D vector is expanded
    to a full symmetric ``m x m`` matrix with a zero diagonal before plotting.

    Args:
        original: Original FC matrix, or its off-diagonal upper-triangle vector.
        reconstructed: Reconstructed FC matrix or vector.
        generated: Generated FC matrix or vector.
        vmin: Lower limit of the colour scale shared by all panels (default -1).
        vmax: Upper limit of the colour scale shared by all panels (default 1).
        cmap: Matplotlib colormap name (default 'RdBu_r').

    Returns:
        The matplotlib Figure with three image panels, each with a colour bar.

    Raises:
        ValueError: If a 1D input's length is not of the form m*(m-1)/2, so it
            cannot be the off-diagonal upper triangle of a square matrix.
    """

    def vector_to_matrix(vector):
        """Expand an off-diagonal upper-triangle vector into a symmetric matrix.

        Inputs that are not 1D are returned unchanged (as an ndarray).
        """
        vector = np.asarray(vector)
        if vector.ndim != 1:
            return vector

        n = vector.size
        # n = m*(m-1)/2 off-diagonal entries  =>  m = (1 + sqrt(1 + 8n)) / 2.
        # (-1 + sqrt(1 + 8n)) / 2 is the formula for a triangle *including*
        # the diagonal. Here it would give m-1 and silently drop data.
        size = int(round((1 + np.sqrt(1 + 8 * n)) / 2))
        if size * (size - 1) // 2 != n:
            raise ValueError(
                f"vector of length {n} is not the off-diagonal upper triangle "
                f"of a square matrix (expected m*(m-1)/2 elements)"
            )

        matrix = np.zeros((size, size))
        # triu_indices(k=1) walks rows then columns above the diagonal. This
        # matches the row-major layout described in the docstring.
        rows, cols = np.triu_indices(size, k=1)
        matrix[rows, cols] = vector
        matrix[cols, rows] = vector  # mirror to make the matrix symmetric
        return matrix

    # Convert first so a malformed input raises before a figure is created
    # (otherwise pyplot would keep an orphaned, never-returned figure).
    panels = (
        (vector_to_matrix(original), 'Original FC'),
        (vector_to_matrix(reconstructed), 'Reconstructed FC'),
        (vector_to_matrix(generated), 'Generated FC'),
    )

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))
    for ax, (matrix, title) in zip(axes, panels):
        im = ax.imshow(matrix, cmap=cmap, vmin=vmin, vmax=vmax)
        ax.set_title(title)
        fig.colorbar(im, ax=ax)

    fig.tight_layout()
    return fig

def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke, prediction_std=None):
    """Plot the predicted path from the current score to the predicted score.

    Draws the current score at x=0 and the predicted score at
    x=months_post_stroke, joined by a dashed line. When ``prediction_std`` is
    given, an error bar of +/- 1.96 * prediction_std is drawn around the
    predicted score.

    Args:
        current_score: Score at the start of the trajectory (plotted at x=0).
        predicted_score: Predicted score at ``months_post_stroke``.
        months_post_stroke: x position of the predicted score, in months.
        prediction_std: Optional standard deviation of the prediction; it is
            presumably non-negative (matplotlib rejects negative error sizes).

    Returns:
        The matplotlib Figure containing the trajectory plot.
    """
    fig, ax = plt.subplots(figsize=(10, 6))

    # Current and predicted points
    ax.scatter([0], [current_score], label='Current Status', color='blue', s=100)
    ax.scatter([months_post_stroke], [predicted_score],
               label='Predicted Outcome', color='red', s=100)

    # Straight-line trajectory between the two points
    ax.plot([0, months_post_stroke], [current_score, predicted_score],
            'g--', label='Predicted Trajectory')

    if prediction_std is not None:
        # fill_between over a single x value has zero width and renders
        # nothing, so draw the interval as an error bar at the predicted point.
        # 1.96 sigma is the two-sided 95% interval, assuming approximately
        # normal prediction error.
        half_width = 1.96 * prediction_std
        ax.errorbar([months_post_stroke], [predicted_score], yerr=half_width,
                    fmt='none', ecolor='red', elinewidth=2, capsize=8,
                    label='95% Prediction Interval')

    ax.set_xlabel('Months Post Treatment')
    ax.set_ylabel('WAB Score')
    ax.set_title('Predicted Treatment Trajectory')
    ax.legend()
    ax.grid(True)

    return fig

def plot_learning_curves(train_losses, val_losses):
    """Draw the VAE training and validation loss curves on one set of axes.

    Args:
        train_losses: Sequence of training losses, one value per epoch index.
        val_losses: Sequence of validation losses, one value per epoch index.

    Returns:
        The matplotlib Figure containing both curves.
    """
    figure, axis = plt.subplots(figsize=(10, 6))

    curves = (
        (train_losses, 'Training Loss'),
        (val_losses, 'Validation Loss'),
    )
    for losses, curve_label in curves:
        axis.plot(losses, label=curve_label)

    axis.set_xlabel('Epoch')
    axis.set_ylabel('Loss')
    axis.set_title('VAE Learning Curves')
    axis.legend()
    axis.grid(True)

    return figure