SreekarB committed on
Commit
71fbc82
·
verified ·
1 Parent(s): 9641510

Upload 36 files

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ fc_visualization.png filter=lfs diff=lfs merge=lfs -text
__pycache__/config.cpython-311.pyc ADDED
Binary file (779 Bytes). View file
 
__pycache__/data_preprocessing.cpython-311.pyc ADDED
Binary file (21.3 kB). View file
 
__pycache__/main.cpython-311.pyc ADDED
Binary file (20 kB). View file
 
__pycache__/rcf_prediction.cpython-311.pyc ADDED
Binary file (16.6 kB). View file
 
__pycache__/utils.cpython-311.pyc ADDED
Binary file (11.7 kB). View file
 
__pycache__/vae_model.cpython-311.pyc ADDED
Binary file (11.9 kB). View file
 
__pycache__/visualization.cpython-311.pyc ADDED
Binary file (6.91 kB). View file
 
app.py CHANGED
@@ -1868,7 +1868,7 @@ def create_interface():
1868
  with gr.Column(scale=1):
1869
  fmri_file = gr.File(label="Patient fMRI Data (NIfTI file)")
1870
  with gr.Column(scale=1):
1871
- with gr.Group("Patient Demographics"):
1872
  age = gr.Number(label="Age at Stroke", value=60)
1873
  sex = gr.Dropdown(choices=["M", "F"], label="Sex", value="M")
1874
  months = gr.Number(label="Months Post Stroke", value=12)
 
1868
  with gr.Column(scale=1):
1869
  fmri_file = gr.File(label="Patient fMRI Data (NIfTI file)")
1870
  with gr.Column(scale=1):
1871
+ with gr.Group(label="Patient Demographics"):
1872
  age = gr.Number(label="Age at Stroke", value=60)
1873
  sex = gr.Dropdown(choices=["M", "F"], label="Sex", value="M")
1874
  months = gr.Number(label="Months Post Stroke", value=12)
app_fixed.py CHANGED
@@ -195,7 +195,7 @@ def run_demo():
195
  with gr.Row():
196
  with gr.Column(scale=1):
197
  # Configuration inputs
198
- with gr.Group():
199
  gr.Markdown("### Configuration")
200
  data_source = gr.Textbox(value="SreekarB/OSFData", label="Data Source (HuggingFace dataset or directory)")
201
  use_hf_checkbox = gr.Checkbox(value=True, label="Use HuggingFace Dataset API")
@@ -212,8 +212,8 @@ def run_demo():
212
  status_text = gr.Textbox(label="Status", lines=10, interactive=False)
213
 
214
  with gr.Column(scale=2):
215
- # Output plot
216
- output_plot = gr.Plot(label="FC Matrix Analysis", height=400)
217
  accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")
218
 
219
  # Link the training button to the analysis function
 
195
  with gr.Row():
196
  with gr.Column(scale=1):
197
  # Configuration inputs
198
+ with gr.Box(): # Switched to Box to avoid any Group issues
199
  gr.Markdown("### Configuration")
200
  data_source = gr.Textbox(value="SreekarB/OSFData", label="Data Source (HuggingFace dataset or directory)")
201
  use_hf_checkbox = gr.Checkbox(value=True, label="Use HuggingFace Dataset API")
 
212
  status_text = gr.Textbox(label="Status", lines=10, interactive=False)
213
 
214
  with gr.Column(scale=2):
215
+ # Output plot
216
+ output_plot = gr.Plot(label="FC Matrix Analysis")
217
  accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")
218
 
219
  # Link the training button to the analysis function
direct_fc_visualization.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ """
3
+ Direct FC Matrix Visualization Script.
4
+
5
+ This script creates and visualizes FC matrices directly, without relying on fMRI data.
6
+ """
7
+
8
+ import os
9
+ import numpy as np
10
+ import matplotlib.pyplot as plt
11
+ from visualization import vector_to_matrix
12
+
13
def create_synthetic_fc_matrices(n_subjects=10, n_rois=264, seed=42):
    """
    Create synthetic FC matrices for visualization.

    Args:
        n_subjects: Number of synthetic subjects
        n_rois: Number of regions of interest
        seed: Random seed for reproducibility

    Returns:
        dict: Dictionary with original FC matrices (upper-triangular vectors),
            latent features, reconstructions, and one generated FC vector
    """
    # Use a dedicated generator so the `seed` argument actually controls all
    # randomness. The previous version called np.random.seed(i) inside the
    # per-subject loop, which silently ignored `seed` and reseeded the
    # process-global RNG as a side effect.
    rng = np.random.RandomState(seed)

    # Number of elements in the upper-triangular part (excluding the
    # diagonal) of an n_rois x n_rois symmetric matrix: n(n-1)/2.
    n_triu = n_rois * (n_rois - 1) // 2

    print(f"Creating {n_subjects} synthetic FC matrices with {n_rois} ROIs each")

    # Original FC matrices as upper-triangular vectors with values in the
    # typical functional-connectivity range [-0.8, 0.8].
    original_fc_vectors = [rng.rand(n_triu) * 1.6 - 0.8 for _ in range(n_subjects)]

    # Simulated latent features (much lower dimensional than the FC vectors).
    latent_dim = 16
    latent_features = rng.randn(n_subjects, latent_dim)

    # Simulated reconstructions: original plus Gaussian noise, clipped back
    # into a realistic correlation range.
    reconstructed_fc_vectors = [
        np.clip(vec + rng.randn(n_triu) * 0.1, -0.99, 0.99)
        for vec in original_fc_vectors
    ]

    # Simulate a newly generated FC vector from the same distribution.
    generated_fc_vector = rng.rand(n_triu) * 1.6 - 0.8

    return {
        'original_vectors': original_fc_vectors,
        'reconstructed_vectors': reconstructed_fc_vectors,
        'generated_vector': generated_fc_vector,
        'latent_features': latent_features
    }
64
+
65
def visualize_fc_matrices(fc_data, subject_idx=0):
    """
    Plot original, reconstructed, and generated FC matrices side by side.

    Args:
        fc_data: Dictionary with FC data (vectors in upper-triangular form)
        subject_idx: Subject index to visualize

    Returns:
        fig: Matplotlib figure with three heatmap panels
    """
    # Build (title, matrix) pairs for the three panels, converting each
    # upper-triangular vector back to a square matrix.
    panels = [
        ('Original FC', vector_to_matrix(fc_data['original_vectors'][subject_idx])),
        ('Reconstructed FC', vector_to_matrix(fc_data['reconstructed_vectors'][subject_idx])),
        ('Generated FC', vector_to_matrix(fc_data['generated_vector'])),
    ]

    fig, axes = plt.subplots(1, 3, figsize=(15, 5))

    # Shared color scale across panels so they are directly comparable.
    for ax, (title, matrix) in zip(axes, panels):
        image = ax.imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_title(title)
        plt.colorbar(image, ax=ax)

    plt.tight_layout()
    return fig
108
+
109
def calculate_metrics(original, reconstructed):
    """
    Calculate reconstruction metrics between two FC matrices.

    Args:
        original: Original FC matrix
        reconstructed: Reconstructed FC matrix

    Returns:
        dict: Dictionary of metrics (MSE, RMSE, R², Pearson correlation)
    """
    # Flatten matrices so every metric is computed element-wise.
    orig_flat = np.asarray(original, dtype=float).flatten()
    recon_flat = np.asarray(reconstructed, dtype=float).flatten()

    # Plain NumPy implementations — avoids importing scikit-learn just for
    # mean_squared_error / r2_score; values are numerically identical.
    mse = float(np.mean((orig_flat - recon_flat) ** 2))
    rmse = float(np.sqrt(mse))

    # Coefficient of determination: 1 - SS_res / SS_tot (the r2_score
    # definition for 1-D inputs). Guard against a constant target, where
    # SS_tot is zero and the ratio is undefined.
    ss_res = float(np.sum((orig_flat - recon_flat) ** 2))
    ss_tot = float(np.sum((orig_flat - orig_flat.mean()) ** 2))
    r2 = 1.0 - ss_res / ss_tot if ss_tot > 0 else 0.0

    corr = float(np.corrcoef(orig_flat, recon_flat)[0, 1])

    return {
        'MSE': mse,
        'RMSE': rmse,
        'R²': r2,
        'Correlation': corr
    }
138
+
139
def main():
    """Run the visualization script"""
    print("Creating direct FC matrix visualization without fMRI data")

    # Build the synthetic dataset and render the three-panel figure.
    fc_data = create_synthetic_fc_matrices(n_subjects=10)
    figure = visualize_fc_matrices(fc_data)

    # Save the figure to disk.
    output_file = "fc_visualization.png"
    figure.savefig(output_file, dpi=300, bbox_inches='tight')
    print(f"Saved visualization to {output_file}")

    # Persist the first subject's matrices (plus the generated one) so they
    # can be inspected offline.
    matrices = {
        'original_fc.npy': vector_to_matrix(fc_data['original_vectors'][0]),
        'reconstructed_fc.npy': vector_to_matrix(fc_data['reconstructed_vectors'][0]),
        'generated_fc.npy': vector_to_matrix(fc_data['generated_vector']),
    }
    for path, matrix in matrices.items():
        np.save(path, matrix)
    print("Saved matrices to NPY files")

    # Report reconstruction quality for the first subject.
    metrics = calculate_metrics(
        matrices['original_fc.npy'], matrices['reconstructed_fc.npy']
    )
    print("\nFC Reconstruction Metrics:")
    for name, value in metrics.items():
        print(f"  {name}: {value:.6f}")

    print("\nVisualization complete!")

if __name__ == "__main__":
    main()
fc_visualization.png ADDED

Git LFS Details

  • SHA256: aac44d5e169a980d7a584ee3c315d401123dfe0d0fae230c5544584229b8994c
  • Pointer size: 132 Bytes
  • Size of remote file: 1.06 MB
fix_group.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""
Simple script to fix the gr.Group error in app.py

Rewrites positional gr.Group("title") calls to the keyword form
gr.Group(label="title").
"""
# NOTE: the previous shebang was `#\!/usr/bin/env python` — the escaped
# bang made it an invalid interpreter line; fixed here.

import re

# Path to the app.py file (hard-coded for the HF Space container layout).
app_path = "/home/user/app/app.py"

# Read the file
with open(app_path, "r") as f:
    content = f.read()

# Fix the gr.Group initialization: gr.Group("X") -> gr.Group(label="X")
fixed_content = re.sub(
    r'gr\.Group\("([^"]+)"\)',
    r'gr.Group(label="\1")',
    content
)

# Write the fixed content back
with open(app_path, "w") as f:
    f.write(fixed_content)

print("Fixed the gr.Group initialization in app.py")
print("You can now run: python app.py")
generated_fc.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c6c87c8e143da0dccdc7c5653e237c36b13644b2ba6eb2596e518dac95d664b1
3
+ size 557696
original_fc.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a30451cf7faabae45932e550fee6746bdd6e8c5d887f5559ae4af7921507196e
3
+ size 557696
reconstructed_fc.npy ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:10f10843c4ef0be793b0737240e3a34e7c5263fc36b38a22286a66d794d2e684
3
+ size 557696
temp_demographics.csv ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ID,age_at_stroke,sex,months_post_stroke,wab_score
2
+ P01,66,F,13,51
3
+ P02,67,M,14,52
4
+ P03,68,F,15,53
5
+ P04,69,M,16,54
6
+ P05,70,F,17,55
7
+ P06,71,M,18,56
8
+ P07,72,F,19,57
9
+ P08,73,M,20,58
10
+ P09,74,F,21,59
11
+ P10,65,M,22,60
12
+ P11,66,F,23,61
13
+ P12,67,M,24,62
14
+ P13,68,F,25,63
15
+ P14,69,M,26,64
16
+ P15,70,F,27,65
17
+ P16,71,M,28,66
18
+ P17,72,F,29,67
19
+ P18,73,M,30,68
20
+ P19,74,F,31,69
21
+ P20,65,M,32,70
22
+ P21,66,F,33,71
23
+ P22,67,M,34,72
24
+ P23,68,F,35,73
25
+ P24,69,M,12,74
26
+ P25,70,F,13,75
27
+ P26,71,M,14,76
28
+ P27,72,F,15,77
29
+ P28,73,M,16,78
30
+ P29,74,F,17,79
31
+ P30,65,M,18,50
visualization.py CHANGED
@@ -36,10 +36,34 @@ def vector_to_matrix(vector):
36
  # Print diagnostic info
37
  print(f"Converting vector to matrix. Vector shape: {vector.shape}, length: {len(vector)}")
38
 
39
- # Calculate matrix size from vector length
40
- n = int(np.sqrt(2 * len(vector) + 0.25) + 0.5)
 
 
 
 
 
 
 
 
41
  print(f"Calculated matrix size: {n}x{n}")
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  # Create empty matrix
44
  matrix = np.zeros((n, n))
45
 
@@ -71,14 +95,43 @@ def vector_to_matrix(vector):
71
  print(f"Vector stats: min={np.min(vector)}, max={np.max(vector)}, mean={np.mean(vector)}")
72
  print(f"Traceback: {traceback.format_exc()}")
73
 
74
- # Fallback - try to reshape if it's a square matrix in flat form
75
  if np.sqrt(len(vector)) == int(np.sqrt(len(vector))):
76
  n = int(np.sqrt(len(vector)))
77
  print(f"Trying fallback reshape to {n}x{n}")
78
  return vector.reshape(n, n)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
  else:
80
- # If all else fails, create a dummy matrix
81
- print("Creating fallback matrix")
82
  n = 264 # Standard size for brain atlas
83
  matrix = np.zeros((n, n))
84
  np.fill_diagonal(matrix, 1.0)
 
36
  # Print diagnostic info
37
  print(f"Converting vector to matrix. Vector shape: {vector.shape}, length: {len(vector)}")
38
 
39
+ # Handle FC vectors specifically - for a 264x264 FC matrix, we expect 34716 elements
40
+ # This is the most common case in this application
41
+ if len(vector) == 34716:
42
+ print("Detected standard FC vector with 34716 elements (264x264 matrix)")
43
+ n = 264
44
+ else:
45
+ # For other sized vectors, calculate matrix size from vector length
46
+ # For a matrix of size n×n, the number of elements in the upper triangular part (excl. diagonal) is n(n-1)/2
47
+ n = int(np.sqrt(2 * len(vector) + 0.25) + 0.5)
48
+
49
  print(f"Calculated matrix size: {n}x{n}")
50
 
51
+ # Validate calculation
52
+ expected_elements = int(n * (n-1) / 2)
53
+ if expected_elements != len(vector):
54
+ print(f"WARNING: Vector length {len(vector)} doesn't match expected length {expected_elements} for {n}x{n} matrix")
55
+
56
+ # If the vector length is very close to expected, we can pad or truncate
57
+ if abs(expected_elements - len(vector)) < n:
58
+ if len(vector) < expected_elements:
59
+ print(f"Padding vector with {expected_elements - len(vector)} zeros")
60
+ vector = np.pad(vector, (0, expected_elements - len(vector)))
61
+ else:
62
+ print(f"Truncating vector to {expected_elements} elements")
63
+ vector = vector[:expected_elements]
64
+ else:
65
+ raise ValueError(f"Vector length {len(vector)} incompatible with calculated matrix size {n}x{n}")
66
+
67
  # Create empty matrix
68
  matrix = np.zeros((n, n))
69
 
 
95
  print(f"Vector stats: min={np.min(vector)}, max={np.max(vector)}, mean={np.mean(vector)}")
96
  print(f"Traceback: {traceback.format_exc()}")
97
 
98
+ # Fallback 1 - check if it's already a matrix that was flattened
99
  if np.sqrt(len(vector)) == int(np.sqrt(len(vector))):
100
  n = int(np.sqrt(len(vector)))
101
  print(f"Trying fallback reshape to {n}x{n}")
102
  return vector.reshape(n, n)
103
+
104
+ # Fallback 2 - try standard FC matrix size
105
+ elif len(vector) > 30000 and len(vector) < 40000: # Close to 34716
106
+ print(f"Vector length {len(vector)} is close to 34716, trying 264x264 matrix")
107
+ n = 264
108
+ matrix = np.zeros((n, n))
109
+ np.fill_diagonal(matrix, 1.0)
110
+
111
+ # Try to fill as much as possible
112
+ triu_indices = np.triu_indices_from(matrix, k=1)
113
+ max_idx = min(len(vector), len(triu_indices[0]))
114
+
115
+ # Convert from Fisher z-transform if needed
116
+ if np.any(np.abs(vector[:max_idx]) > 1):
117
+ values = np.tanh(vector[:max_idx])
118
+ else:
119
+ values = vector[:max_idx]
120
+
121
+ # Fill the upper triangle with as many values as we can
122
+ for i in range(max_idx):
123
+ matrix[triu_indices[0][i], triu_indices[1][i]] = values[i]
124
+
125
+ # Make symmetric
126
+ matrix = matrix + matrix.T
127
+ np.fill_diagonal(matrix, 1.0)
128
+
129
+ print(f"Created partial matrix with shape {matrix.shape}")
130
+ return matrix
131
+
132
+ # Fallback 3 - create a dummy identity matrix as last resort
133
  else:
134
+ print("Creating fallback identity matrix")
 
135
  n = 264 # Standard size for brain atlas
136
  matrix = np.zeros((n, n))
137
  np.fill_diagonal(matrix, 1.0)
visualize_fc.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python
"""
Standalone script to visualize FC matrices using the VAE.

Runs the FC analysis pipeline (run_fc_analysis from main.py) against either
a local data directory or a HuggingFace dataset, saves the resulting figure
to fc_visualization.png, and reports a reconstruction MSE when the analysis
returns its data.
"""

import os
import sys
import numpy as np
import matplotlib.pyplot as plt
from main import run_fc_analysis
from config import PREDICTION_CONFIG

def main():
    """Entry point: configure, run the analysis, save figure and matrices."""
    # Configuration
    data_dir = "SreekarB/OSFData" # HuggingFace dataset
    latent_dim = 16
    nepochs = 50
    batch_size = 4
    use_hf_dataset = True

    # Check if using local data: a local directory with this name takes
    # precedence over the HuggingFace dataset API.
    if os.path.exists(data_dir) and os.path.isdir(data_dir):
        print(f"Using local directory: {data_dir}")
        use_hf_dataset = False
    else:
        print(f"Using HuggingFace dataset: {data_dir}")

    print(f"Running FC visualization with:")
    print(f"- Data source: {data_dir}")
    print(f"- Latent dimension: {latent_dim}")
    print(f"- Training epochs: {nepochs}")
    print(f"- Batch size: {batch_size}")
    print(f"- Using HuggingFace API: {use_hf_dataset}")

    # Run analysis
    try:
        # Update config to allow synthetic data. NOTE(review): this mutates
        # the shared PREDICTION_CONFIG dict imported from config.py, so it
        # affects anything else in this process that reads it.
        PREDICTION_CONFIG['use_synthetic_nifti'] = True
        PREDICTION_CONFIG['use_synthetic_fc'] = True
        print("Enabled synthetic data generation")

        # Create a dummy demographic file if needed.
        # NOTE(review): unconditionally overwrites temp_demographics.csv in
        # the working directory.
        demo_file = "temp_demographics.csv"
        with open(demo_file, "w") as f:
            f.write("ID,age_at_stroke,sex,months_post_stroke,wab_score\n")
            # Write some dummy data
            for i in range(1, 31): # 30 subjects
                f.write(f"P{i:02d},{65+i%10},{['M','F'][i%2]},{12+i%24},{50+i%30}\n")

        print(f"Created temporary demographic file: {demo_file}")

        # run_fc_analysis is defined in main.py; return_data=True is assumed
        # to make it return a (figure, results-dict) pair — confirm against
        # main.py.
        fig, results = run_fc_analysis(
            data_dir=data_dir,
            demographic_file=demo_file,
            latent_dim=latent_dim,
            nepochs=nepochs,
            bsize=batch_size,
            save_model=True,
            use_hf_dataset=use_hf_dataset,
            return_data=True
        )

        # Save the figure
        output_file = "fc_visualization.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        print(f"Saved visualization to {output_file}")

        # If results are available, calculate some metrics
        if results:
            X = results.get('X')
            reconstructed_fc = results.get('reconstructed_fc')

            if X is not None and reconstructed_fc is not None:
                # Calculate MSE between original and reconstructed
                # (first subject only).
                original = X[0]
                recon = reconstructed_fc[0]

                # Convert to matrices if needed — a 1-D entry is presumed to
                # be the upper-triangular vector form that vector_to_matrix
                # rebuilds into a square matrix.
                from visualization import vector_to_matrix
                if len(original.shape) == 1:
                    original = vector_to_matrix(original)
                    recon = vector_to_matrix(recon)

                # Calculate MSE
                mse = np.mean((original - recon) ** 2)
                print(f"Reconstruction MSE: {mse:.6f}")

                # Save the matrices
                np.save("original_fc.npy", original)
                np.save("reconstructed_fc.npy", recon)
                print("Saved matrices to original_fc.npy and reconstructed_fc.npy")

    except Exception as e:
        # Top-level boundary: report and exit non-zero so callers see failure.
        print(f"Error during visualization: {e}")
        import traceback
        traceback.print_exc()
        sys.exit(1)

    print("Visualization complete!")

if __name__ == "__main__":
    main()