# Scraped page header (not part of the module): file size 10,843 bytes;
# the blame gutter referenced commits 9641510 and 71fbc82.
import gradio as gr
from main import run_fc_analysis
import os
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
import json
import pickle

def calculate_fc_accuracy(original_fc, reconstructed_fc):
    """
    Calculate accuracy metrics between original and reconstructed FC matrices.

    Both inputs are coerced to NumPy arrays and flattened to 1-D before any
    metric is computed, so matrices, vectors and nested lists are all compared
    element by element.

    Args:
        original_fc: Array-like holding the reference values.
        reconstructed_fc: Array-like holding the reconstruction. It must
            contain the same number of elements as ``original_fc``.

    Returns:
        dict: Plain Python floats under the keys "MSE", "RMSE", "R²",
        "Correlation" and "Cosine Similarity". Every value is NaN when the
        metrics cannot be computed; the error is printed rather than raised.
    """
    metric_names = ("MSE", "RMSE", "R²", "Correlation", "Cosine Similarity")

    try:
        # Coerce first: plain (nested) lists have no .flatten/.shape, and
        # passing 2-D data on would make the correlation and R² below
        # operate per row/column instead of over all elements.
        original_flat = np.asarray(original_fc).ravel()
        recon_flat = np.asarray(reconstructed_fc).ravel()

        # Mean Squared Error (lower is better)
        mse = mean_squared_error(original_flat, recon_flat)

        # Root Mean Squared Error (lower is better)
        rmse = np.sqrt(mse)

        # R² Score (higher is better, 1 is perfect)
        r2 = r2_score(original_flat, recon_flat)

        # Correlation between matrices (higher is better)
        corr = np.corrcoef(original_flat, recon_flat)[0, 1]

        # Cosine similarity (higher is better). It is undefined when either
        # vector has zero norm, so report NaN instead of dividing by zero.
        norm_product = np.linalg.norm(original_flat) * np.linalg.norm(recon_flat)
        if norm_product > 0:
            norm_dot = np.dot(original_flat, recon_flat) / norm_product
        else:
            norm_dot = float('nan')

        values = (mse, rmse, r2, corr, norm_dot)
        return {name: float(value) for name, value in zip(metric_names, values)}

    except Exception as e:
        # Best-effort by design: the caller averages these dicts and shows
        # them in the UI, so a failure degrades to NaN rather than aborting.
        print(f"Error calculating accuracy metrics: {e}")
        return dict.fromkeys(metric_names, float('nan'))

def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
    """
    Save latent representations and associated demographics to file.

    Two files are written under the ``results`` directory: a pickle at
    ``results/<file_path>`` and a JSON copy next to it whose extension is
    replaced by ``.json``. The JSON name is always different from the pickle
    name, so the second write can never clobber the first.

    Args:
        latents: Latent representations (NumPy array or JSON-friendly value).
        demographics: Mapping of demographic name to values; NumPy arrays and
            NumPy scalars are converted to builtin types for the JSON copy.
        subjects: Optional subject identifiers, stored under ``'subjects'``
            when given.
        file_path: Pickle file name relative to ``results``.

    Returns:
        str: Path of the pickle file that was written.

    Raises:
        OSError: If a directory or file cannot be created or written.
        TypeError: If the data contains a value JSON cannot encode even after
            the NumPy conversion.
    """
    def _to_jsonable(value):
        # json.dump fallback for objects it cannot encode natively.
        if isinstance(value, np.ndarray):
            return value.tolist()
        if isinstance(value, np.generic):
            return value.item()
        raise TypeError(
            f"Object of type {type(value).__name__} is not JSON serializable")

    os.makedirs('results', exist_ok=True)
    pickle_path = os.path.join('results', file_path)
    # file_path may itself contain sub-directories.
    os.makedirs(os.path.dirname(pickle_path) or '.', exist_ok=True)

    # Create a dictionary with latents and demographics
    data = {
        'latents': latents,
        'demographics': demographics
    }
    if subjects is not None:
        data['subjects'] = subjects

    # Save as pickle for easy loading in Python
    with open(pickle_path, 'wb') as f:
        pickle.dump(data, f)

    # Derive the JSON name from the extension rather than a substring
    # replace: a name without '.pkl' would otherwise map onto itself and the
    # JSON dump would overwrite the pickle written above.
    stem, ext = os.path.splitext(file_path)
    if ext.lower() == '.json':
        json_name = f"{file_path}.json"
    else:
        json_name = f"{stem}.json"

    # Also save as JSON for more universal access
    with open(os.path.join('results', json_name), 'w') as f:
        json.dump(data, f, default=_to_jsonable)

    return pickle_path

def _error_figure(message):
    """Build a blank matplotlib figure that shows ``message`` in red."""
    import matplotlib.pyplot as plt

    fig = plt.figure(figsize=(10, 6))
    plt.text(0.5, 0.5, f"Error: {message}",
             ha='center', va='center', fontsize=12, color='red')
    plt.axis('off')
    return fig


def _reconstruction_metrics(X, reconstructed_fc, max_subjects=5):
    """Score the first ``max_subjects`` reconstructions and average them.

    Returns a dict with one ``"Subject_<n>"`` entry per scored subject plus
    an ``"Average"`` entry, or an empty dict when no subject could be scored.
    """
    metric_names = ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]
    per_subject = {}

    # Bound by both sequences so a shorter reconstruction set cannot raise
    # IndexError part-way through.
    n_subjects = min(max_subjects, len(X), len(reconstructed_fc))
    for i in range(n_subjects):
        original = X[i]
        recon = reconstructed_fc[i]

        try:
            # Check if we need to convert from vector to matrix
            if len(original.shape) == 1:
                from visualization import vector_to_matrix
                print(f"Converting subject {i+1} vectors to matrices")
                original = vector_to_matrix(original)
                recon = vector_to_matrix(recon)

            # Flatten matrices for consistent comparison if they're not already flat
            if len(original.shape) > 1:
                original_flat = original.flatten()
                recon_flat = recon.flatten()
            else:
                original_flat = original
                recon_flat = recon
        except Exception as e:
            print(f"Error converting matrices for subject {i+1}: {e}")
            continue  # Skip this subject

        per_subject[f"Subject_{i+1}"] = calculate_fc_accuracy(original_flat, recon_flat)

    if not per_subject:
        # Nothing scored: averaging an empty list would only yield NaNs.
        return {}

    scored = list(per_subject.values())
    per_subject["Average"] = {
        name: np.mean([subject_metrics[name] for subject_metrics in scored])
        for name in metric_names
    }
    return per_subject


def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
    """Run the full VAE analysis pipeline with accuracy metrics.

    Args:
        data_source: Dataset identifier or directory, forwarded to
            ``run_fc_analysis`` as ``data_dir``.
        latent_dim: Latent dimension of the VAE.
        nepochs: Number of training epochs.
        bsize: Training batch size.
        use_hf_dataset: Forwarded to ``run_fc_analysis`` unchanged.

    Returns:
        tuple: ``(fig, status)``. ``fig`` is the figure returned by
        ``run_fc_analysis``, or an error figure when the analysis raised.
        ``status`` is a human-readable summary; when reconstruction metrics
        are available it contains a
        "Reconstruction Accuracy Metrics (Average):" section.
    """
    try:
        # Slider values can arrive as floats (e.g. 16.0); normalise them so
        # they are usable as sizes and give clean file names.
        latent_dim, nepochs, bsize = int(latent_dim), int(nepochs), int(bsize)

        # Add some initial status
        print(f"Starting FC analysis with latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
        print(f"Using dataset: {data_source} (HuggingFace API: {use_hf_dataset})")

        # Run the original analysis
        fig, results = run_fc_analysis(
            data_dir=data_source,
            demographic_file=None,  # We're now getting demographics directly from the dataset
            latent_dim=latent_dim,
            nepochs=nepochs,
            bsize=bsize,
            save_model=True,
            use_hf_dataset=use_hf_dataset,
            return_data=True
        )
    except Exception as e:
        # UI boundary: report the failure in the interface instead of raising.
        import traceback
        error_msg = f"Error during analysis: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return _error_figure(str(e)), f"Analysis failed with error: {str(e)}"

    if not results:
        return fig, "Analysis complete, but no results were returned for accuracy calculation."

    X = results.get('X')
    latents = results.get('latents')
    demographics = results.get('demographics')
    reconstructed_fc = results.get('reconstructed_fc')

    # Calculate accuracy metrics
    accuracy_metrics = {}
    if X is not None and reconstructed_fc is not None:
        accuracy_metrics = _reconstruction_metrics(X, reconstructed_fc)

    # Save latent representations if available
    latents_path = None
    if latents is not None and demographics is not None:
        try:
            latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
            print(f"Saved latents to {latents_path}")
        except Exception as e:
            # A failed export must not discard the trained model's figure
            # and metrics; log it and reflect it in the status instead.
            print(f"Could not save latents: {e}")

    # Prepare status message with accuracy metrics
    if accuracy_metrics:
        avg = accuracy_metrics["Average"]
        if latents_path is not None:
            latents_note = f"Latent representations saved to {latents_path}"
        else:
            latents_note = "Latent representations were not saved."
        status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
                  f"Reconstruction Accuracy Metrics (Average):\n"
                  f"• MSE: {avg['MSE']:.6f}\n"
                  f"• RMSE: {avg['RMSE']:.6f}\n"
                  f"• R²: {avg['R²']:.6f}\n"
                  f"• Correlation: {avg['Correlation']:.6f}\n"
                  f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n"
                  f"{latents_note}")
    else:
        status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."

    return fig, status

def run_demo():
    """Build the Gradio Blocks interface for the FC/VAE analysis.

    The left column holds the configuration controls, the train button and a
    status box; the right column holds the output plot and a Markdown panel
    that mirrors the accuracy metrics found in the status text.

    Returns:
        The ``gr.Blocks`` app; call ``.launch()`` on it to serve the UI.
    """
    # Prefer gr.Box as before, but fall back to gr.Group when the installed
    # Gradio does not provide it.
    # NOTE(review): gr.Box appears to be gone in newer Gradio releases -
    # confirm against the pinned version.
    container = getattr(gr, "Box", None) or gr.Group

    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as interface:
        with gr.Row():
            with gr.Column(scale=1):
                # Configuration inputs
                with container():
                    gr.Markdown("### Configuration")
                    data_source = gr.Textbox(value="SreekarB/OSFData", label="Data Source (HuggingFace dataset or directory)")
                    use_hf_checkbox = gr.Checkbox(value=True, label="Use HuggingFace Dataset API")
                    latent_dim = gr.Slider(minimum=4, maximum=64, value=16, step=4, label="Latent Dimension")
                    nepochs = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Training Epochs")
                    bsize = gr.Slider(minimum=2, maximum=16, value=4, step=1, label="Batch Size")

                    # Add a note about training time
                    gr.Markdown("**Note:** Lower values for latent dimension, epochs, and batch size will train faster.")

                    train_button = gr.Button("Train VAE & Analyze FC", variant="primary")

                # Status output area
                status_text = gr.Textbox(label="Status", lines=10, interactive=False)

            with gr.Column(scale=2):
                # Output plot
                output_plot = gr.Plot(label="FC Matrix Analysis")
                accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")

        # Link the training button to the analysis function
        train_button.click(
            fn=gradio_fc_analysis,
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
            outputs=[output_plot, status_text]
        )

        def update_accuracy_display(status_message):
            """Turn the metrics section of the status text into Markdown."""
            marker = "Reconstruction Accuracy Metrics (Average):"
            if status_message and marker in status_message:
                # The section runs from the marker to the next blank line.
                metrics_section = status_message.split(marker, 1)[1].split("\n\n")[0]
                entries = [line.lstrip("• ").strip()
                           for line in metrics_section.splitlines() if line.strip()]
                if entries:
                    # Emit a real Markdown list: single newlines between the
                    # "•" lines would otherwise be folded into one paragraph.
                    bullet_list = "\n".join(f"- {entry}" for entry in entries)
                    return f"### Reconstruction Accuracy Metrics\n\n{bullet_list}"
            return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."

        # Update accuracy box when status changes
        status_text.change(
            fn=update_accuracy_display,
            inputs=[status_text],
            outputs=[accuracy_box]
        )

        # Add examples
        gr.Examples(
            examples=[
                ["SreekarB/OSFData", 32, 100, 8, True],
                ["SreekarB/OSFData", 16, 200, 16, True],
                ["SreekarB/OSFData", 64, 50, 4, True],
            ],
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
        )

    return interface

if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start serving it with
    # launch()'s default settings.
    demo = run_demo()
    demo.launch()