File size: 3,602 Bytes
71fbc82
 
 
 
 
 
 
 
b507484
 
 
71fbc82
 
 
 
 
 
e81f968
71fbc82
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
#!/usr/bin/env python
"""
Standalone script to visualize FC matrices using the VAE.
"""

import os
import sys
import numpy as np
# Configure matplotlib for headless environment
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt
from main import run_fc_analysis
from config import PREDICTION_CONFIG

def _write_dummy_demographics(path, n_subjects=30):
    """Write a placeholder demographics CSV with ``n_subjects`` rows to *path*.

    Columns are ``ID, age_at_stroke, sex, months_post_stroke, wab_score``.
    Every value is a deterministic function of the 1-based subject index; the
    table only exists so the analysis pipeline has demographics to load.
    """
    import csv

    with open(path, "w", newline="", encoding="utf-8") as f:
        # "\n" terminator keeps the output identical to the historical format
        # (csv.writer defaults to "\r\n").
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["ID", "age_at_stroke", "sex", "months_post_stroke", "wab_score"])
        for i in range(1, n_subjects + 1):
            writer.writerow([
                f"P{i:02d}",
                65 + i % 10,
                "MF"[i % 2],  # even index -> "M", odd -> "F"
                12 + i % 24,
                50 + i % 30,
            ])


def _report_reconstruction(results):
    """Print the first subject's reconstruction MSE and save both matrices.

    *results* is whatever ``run_fc_analysis(..., return_data=True)`` returned
    as its second value; it is read with ``.get('X')`` and
    ``.get('reconstructed_fc')``. Nothing happens if it is empty or either
    entry is missing or empty.
    """
    if not results:
        return

    X = results.get('X')
    reconstructed_fc = results.get('reconstructed_fc')
    if X is None or reconstructed_fc is None:
        return
    if len(X) == 0 or len(reconstructed_fc) == 0:
        # Indexing [0] below would raise IndexError on empty data.
        return

    from visualization import vector_to_matrix

    original = X[0]
    recon = reconstructed_fc[0]

    # Convert each array independently: checking only one of them would let a
    # 1-D vs 2-D mismatch broadcast into a meaningless MSE.
    if len(original.shape) == 1:
        original = vector_to_matrix(original)
    if len(recon.shape) == 1:
        recon = vector_to_matrix(recon)

    mse = np.mean((original - recon) ** 2)
    print(f"Reconstruction MSE: {mse:.6f}")

    np.save("original_fc.npy", original)
    np.save("reconstructed_fc.npy", recon)
    print("Saved matrices to original_fc.npy and reconstructed_fc.npy")


def main():
    """Train the FC VAE, save its visualization, and report reconstruction error.

    ``data_dir`` is used as a local directory when one exists at that path,
    otherwise it is passed on as a HuggingFace dataset id. Synthetic NIfTI/FC
    generation is switched on in ``PREDICTION_CONFIG``, a throwaway
    demographics CSV is written to a temporary file, and ``run_fc_analysis``
    is run. The figure is saved to ``fc_visualization.png``. Any exception
    from the analysis is printed with its traceback and the process exits
    with status 1. The temporary CSV is removed in every case.
    """
    import tempfile

    # Configuration
    data_dir = "SreekarB/OSFData1"  # HuggingFace dataset
    latent_dim = 16
    nepochs = 50
    batch_size = 4

    # Prefer a local copy when a directory exists at data_dir.
    use_hf_dataset = not os.path.isdir(data_dir)
    if use_hf_dataset:
        print(f"Using HuggingFace dataset: {data_dir}")
    else:
        print(f"Using local directory: {data_dir}")

    print("Running FC visualization with:")
    print(f"- Data source: {data_dir}")
    print(f"- Latent dimension: {latent_dim}")
    print(f"- Training epochs: {nepochs}")
    print(f"- Batch size: {batch_size}")
    print(f"- Using HuggingFace API: {use_hf_dataset}")

    demo_file = None
    try:
        # Update config to allow synthetic data
        PREDICTION_CONFIG['use_synthetic_nifti'] = True
        PREDICTION_CONFIG['use_synthetic_fc'] = True
        print("Enabled synthetic data generation")

        # Use a unique temp path rather than a fixed name in the CWD, so an
        # existing file is never clobbered; it is deleted in `finally`.
        fd, demo_file = tempfile.mkstemp(prefix="temp_demographics_", suffix=".csv")
        os.close(fd)
        _write_dummy_demographics(demo_file)
        print(f"Created temporary demographic file: {demo_file}")

        fig, results = run_fc_analysis(
            data_dir=data_dir,
            demographic_file=demo_file,
            latent_dim=latent_dim,
            nepochs=nepochs,
            bsize=batch_size,
            save_model=True,
            use_hf_dataset=use_hf_dataset,
            return_data=True
        )

        # Save the figure
        output_file = "fc_visualization.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved visualization to {output_file}")

        _report_reconstruction(results)

    except Exception as e:
        # Top-level boundary of the script: report and exit non-zero.
        print(f"Error during visualization: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    finally:
        # Runs on success, on exception, and on the sys.exit(1) path above.
        if demo_file is not None:
            try:
                os.remove(demo_file)
            except OSError as cleanup_err:
                print(f"Warning: could not remove {demo_file}: {cleanup_err}")

    print("Visualization complete!")


if __name__ == "__main__":
    main()