File size: 12,419 Bytes
ef677f1
1c47445
b32645b
1c47445
 
dbe81c1
 
37a1b01
dbe81c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37a1b01
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37a1b01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1c47445
 
 
 
 
 
 
 
 
 
37a1b01
 
 
 
 
 
 
 
 
 
 
 
 
 
1c47445
 
 
 
 
 
c4a8601
b32645b
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b32645b
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbe81c1
1c47445
 
 
dbe81c1
1c47445
 
 
 
 
 
 
 
 
 
dbe81c1
 
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b32645b
 
 
 
 
1c47445
b32645b
1c47445
b32645b
 
1c47445
b32645b
1c47445
b32645b
1c47445
 
 
37a1b01
 
 
dbe81c1
1c47445
b32645b
 
1c47445
ef677f1
 
1c47445
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
import gradio as gr
from main import run_fc_analysis
import os
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
import json
import pickle
from rcf_prediction import train_predictor_from_latents

def calculate_fc_accuracy(original_fc, reconstructed_fc):
    """
    Score how closely a reconstructed FC array matches the original one.

    Both arguments must expose ``.flatten()`` (e.g. NumPy arrays) and be
    accepted by sklearn's ``mean_squared_error`` / ``r2_score``.

    Returns a dict of plain Python floats keyed by "MSE", "RMSE", "R²",
    "Correlation" and "Cosine Similarity".
    """
    # Error-based scores: lower MSE / RMSE means a closer reconstruction
    squared_error = mean_squared_error(original_fc, reconstructed_fc)
    root_error = np.sqrt(squared_error)

    # Coefficient of determination: 1 is a perfect reconstruction
    determination = r2_score(original_fc, reconstructed_fc)

    # The remaining scores treat each input as one long vector
    flat_original = original_fc.flatten()
    flat_reconstructed = reconstructed_fc.flatten()

    # Pearson correlation is the off-diagonal entry of the 2x2 matrix
    pearson = np.corrcoef(flat_original, flat_reconstructed)[0, 1]

    # Normalised dot product of the two flattened vectors
    inner_product = np.dot(flat_original, flat_reconstructed)
    norm_product = np.linalg.norm(flat_original) * np.linalg.norm(flat_reconstructed)
    cosine = inner_product / norm_product

    scores = {
        "MSE": squared_error,
        "RMSE": root_error,
        "R²": determination,
        "Correlation": pearson,
        "Cosine Similarity": cosine,
    }
    # Cast NumPy scalars to built-in floats
    return {label: float(value) for label, value in scores.items()}

def _json_default(value):
    """Turn NumPy arrays and NumPy scalars into plain Python values for json.dump."""
    if isinstance(value, (np.ndarray, np.generic)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
    """
    Save latent representations and associated demographics to file.

    Two files are written under the ``results`` directory: a pickle at
    ``results/<file_path>`` and a JSON copy next to it that uses the same
    base name with a ``.json`` extension.

    Args:
        latents: Latent representations; stored as-is in the pickle and
            converted to nested lists for JSON when it is a NumPy array.
        demographics: Mapping of demographic name to values.
        subjects: Optional subject identifiers; only stored when not None.
        file_path: File name (relative to ``results``) for the pickle.

    Returns:
        The path of the pickle file that was written.

    Raises:
        TypeError: If the data contains a non-NumPy object that ``json``
            cannot serialise.
    """
    pickle_path = os.path.join('results', file_path)
    # Also covers a file_path that contains sub-directories
    os.makedirs(os.path.dirname(pickle_path), exist_ok=True)

    # Create a dictionary with latents and demographics
    data = {
        'latents': latents,
        'demographics': demographics
    }

    if subjects is not None:
        data['subjects'] = subjects

    # Save as pickle for easy loading in Python
    with open(pickle_path, 'wb') as f:
        pickle.dump(data, f)

    # Derive the JSON name from the extension rather than a substring
    # replace, so a file_path without ".pkl" cannot overwrite the pickle.
    json_path = os.path.splitext(pickle_path)[0] + '.json'
    if json_path == pickle_path:
        json_path = pickle_path + '.json'

    # Also save as JSON for more universal access. The default hook converts
    # NumPy arrays/scalars anywhere in the structure (latents, demographic
    # values, subjects) instead of failing on them.
    with open(json_path, 'w') as f:
        json.dump(data, f, default=_json_default)

    return pickle_path

def _reconstruction_metrics(X, reconstructed_fc, max_subjects=5):
    """Return per-subject and averaged reconstruction metrics, or {} if none can be computed."""
    if X is None or reconstructed_fc is None:
        return {}

    # zip stops at the shorter input, so a short reconstruction array cannot
    # cause an IndexError.
    per_subject = {
        f"Subject_{idx}": calculate_fc_accuracy(original, rebuilt)
        for idx, (original, rebuilt) in enumerate(
            zip(X[:max_subjects], reconstructed_fc[:max_subjects]), start=1)
    }
    if not per_subject:
        # Nothing to average; avoids NaN from np.mean over an empty list
        return {}

    # Average metrics across subjects (computed before "Average" is added)
    subject_scores = list(per_subject.values())
    average = {
        metric: float(np.mean([scores[metric] for scores in subject_scores]))
        for metric in ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]
    }
    per_subject["Average"] = average
    return per_subject


def _train_wab_predictor(latents, demographics, outcome_measures):
    """Train the WAB-AQ predictor on subjects with a valid score; return its results or None."""
    if outcome_measures is None or 'wab_aq' not in outcome_measures:
        return None

    # Best-effort step: a failure here is reported but must not abort the
    # rest of the analysis, hence the broad except.
    try:
        print("Training WAB-AQ prediction model from latent representations...")
        # dtype=float maps None entries to NaN so np.isnan works on them
        wab_scores = np.asarray(outcome_measures['wab_aq'], dtype=float)
        # Filter out any NaN values
        valid_indices = ~np.isnan(wab_scores)
        if np.sum(valid_indices) <= 5:  # Only train with sufficient data
            print("Skipping WAB-AQ prediction model: more than 5 valid scores are required.")
            return None

        filtered_latents = latents[valid_indices]
        filtered_wab = wab_scores[valid_indices]

        # Extract demographic features for the model. A boolean mask can only
        # index an array of exactly the same length, so other lengths are skipped.
        filtered_demographics = {}
        for key, values in demographics.items():
            if isinstance(values, (list, np.ndarray)) and len(values) == len(valid_indices):
                filtered_demographics[key] = np.array(values)[valid_indices]

        # Train the prediction model with cross-validation
        predictor_results = train_predictor_from_latents(
            filtered_latents,
            filtered_wab,
            filtered_demographics,
            cv=5,  # 5-fold cross-validation
            n_estimators=100,  # Number of trees in Random Forest
            prediction_type="regression"
        )
        print("WAB-AQ prediction model training complete!")
        return predictor_results
    except Exception as e:
        print(f"Error training prediction model: {str(e)}")
        return None


def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
    """
    Run the full VAE analysis pipeline with accuracy metrics.

    Calls ``run_fc_analysis``, computes reconstruction metrics for up to five
    subjects, saves the latents via ``save_latents`` and, when WAB-AQ scores
    are present in the returned outcome measures, trains a predictor on the
    latents.

    Args:
        data_source: Passed to ``run_fc_analysis`` as ``data_dir``.
        latent_dim: Latent dimensionality; also used in the saved file name.
        nepochs: Number of training epochs.
        bsize: Batch size.
        use_hf_dataset: Forwarded to ``run_fc_analysis``.

    Returns:
        Tuple ``(fig, status)``: the figure returned by ``run_fc_analysis``
        and a human-readable status string.
    """
    # Run the original analysis
    fig, results = run_fc_analysis(
        data_dir=data_source,
        demographic_file=None,  # We're now getting demographics directly from the dataset
        latent_dim=latent_dim,
        nepochs=nepochs,
        bsize=bsize,
        save_model=True,
        use_hf_dataset=use_hf_dataset,
        return_data=True  # New parameter to return data, will need to update main.py
    )

    if not results:
        return fig, "Analysis complete, but no results were returned for accuracy calculation."

    X = results.get('X')
    latents = results.get('latents')
    demographics = results.get('demographics')
    reconstructed_fc = results.get('reconstructed_fc')
    outcome_measures = results.get('outcome_measures', None)

    # Calculate accuracy metrics
    accuracy_metrics = _reconstruction_metrics(X, reconstructed_fc)

    # Initialised up front so the status section below can always read them,
    # even when there are no latents/demographics to save.
    predictor_results = None
    latents_saved = False

    # Save latent representations if available
    if latents is not None and demographics is not None:
        latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
        latents_saved = True
        print(f"Saved latents to {latents_path}")

        # Train a predictor model if we have outcome measures
        predictor_results = _train_wab_predictor(latents, demographics, outcome_measures)

    if not accuracy_metrics:
        return fig, "Analysis complete! VAE model has been trained and demographic relationships analyzed."

    # Prepare status message with accuracy metrics
    avg = accuracy_metrics["Average"]
    status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
             f"Reconstruction Accuracy Metrics (Average):\n"
             f"• MSE: {avg['MSE']:.6f}\n"
             f"• RMSE: {avg['RMSE']:.6f}\n"
             f"• R²: {avg['R²']:.6f}\n"
             f"• Correlation: {avg['Correlation']:.6f}\n"
             f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n")

    # Add prediction model results if available
    if predictor_results is not None:
        cv_results = predictor_results.get('cv_results', {})
        mean_metrics = cv_results.get('mean_metrics', {})
        if mean_metrics and 'r2' in mean_metrics:
            prediction_r2 = mean_metrics.get('r2', 0)
            prediction_rmse = mean_metrics.get('rmse', 0)
            status += (f"WAB-AQ Prediction Model Performance:\n"
                      f"• R²: {prediction_r2:.4f}\n"
                      f"• RMSE: {prediction_rmse:.4f}\n\n")

    # Only claim the file exists when it was actually written
    if latents_saved:
        status += f"Latent representations saved to results/latents_dim{latent_dim}.pkl"

    return fig, status

def create_interface():
    """
    Build the Gradio Blocks UI for the FC/VAE analysis demo.

    The left column holds the configuration inputs, the training button and a
    status textbox; the right column holds the result plot and a Markdown box
    for the reconstruction metrics. Clicking the button runs
    ``gradio_fc_analysis``; a change of the status textbox refreshes the
    metrics box.

    Returns:
        The constructed ``gr.Blocks`` object (not yet launched).
    """
    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
        gr.Markdown("""
        # Aphasia fMRI to FC Analysis using VAE
        
        This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
        
        ## Dataset Information
        By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
        - ID: Subject identifier
        - wab_aq: Aphasia severity score
        - age: Age of the subject
        - mpo: Months post onset
        - education: Years of education
        - gender: Subject gender
        - handedness: Subject handedness (ignored in the analysis)
        """)
        
        with gr.Row():
            with gr.Column(scale=1):
                # Configuration parameters
                data_source = gr.Textbox(
                    label="Data Source (HF Dataset ID or Local Directory)", 
                    value="SreekarB/OSFData"
                )
                latent_dim = gr.Slider(
                    minimum=8, maximum=64, step=8, 
                    label="Latent Dimensions", value=32
                )
                nepochs = gr.Slider(
                    minimum=100, maximum=5000, step=100, 
                    label="Number of Epochs", value=200  # Reduced for faster demos
                )
                bsize = gr.Slider(
                    minimum=8, maximum=64, step=8, 
                    label="Batch Size", value=16
                )
                use_hf_dataset = gr.Checkbox(
                    label="Use HuggingFace Dataset", value=True
                )
                
                # Training button
                train_button = gr.Button("Start Training", variant="primary")
                status_text = gr.Textbox(label="Status", value="Ready to start training")
                
            with gr.Column(scale=2):
                # Output plot
                output_plot = gr.Plot(label="Analysis Results")
                accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")
        
        # Link the training button to the analysis function
        train_button.click(
            fn=gradio_fc_analysis,
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
            outputs=[output_plot, status_text]
        )
        
        # Custom function to update the accuracy box
        def update_accuracy_display(status_value):
            """Extract the averaged metrics from the status text as a Markdown list."""
            marker = "Reconstruction Accuracy Metrics (Average):"
            if status_value and marker in status_value:
                # The metrics section ends at the first blank line after the marker
                metrics_text = status_value.split(marker, 1)[1].split("\n\n")[0]
                # Markdown folds single newlines into one paragraph, so each
                # "•" line is re-emitted as a real list item.
                rows = [row.strip().lstrip("•").strip() for row in metrics_text.splitlines()]
                bullet_list = "\n".join(f"- {row}" for row in rows if row)
                return f"### Reconstruction Accuracy Metrics\n\n{bullet_list}"
            return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."
        
        # Update accuracy box when status changes
        status_text.change(
            fn=update_accuracy_display,
            inputs=[status_text],
            outputs=[accuracy_box]
        )
        
        # Add examples
        gr.Examples(
            examples=[
                ["SreekarB/OSFData", 32, 200, 16, True],  # Fewer epochs for faster demo
            ],
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
        )
        
        # Add explanation of the workflow
        gr.Markdown("""
        ## How this works
        
        1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
        2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
        3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
        4. **Predictive Modeling**: The system trains a Random Forest regressor on latent features to predict WAB-AQ scores (aphasia severity)
        5. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
        6. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
        
        Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
        """)
    
    return iface

if __name__ == "__main__":
    # Script entry point: build the Blocks UI and start the Gradio server.
    # share=True is forwarded to launch(); per Gradio's documentation this
    # requests a public share link in addition to the local URL.
    iface = create_interface()
    iface.launch(share=True)