Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from main import run_fc_analysis | |
| import os | |
| import numpy as np | |
| from sklearn.metrics import mean_squared_error, r2_score | |
| import json | |
| import pickle | |
| from rcf_prediction import train_predictor_from_latents | |
def calculate_fc_accuracy(original_fc, reconstructed_fc):
    """Compute agreement metrics between an original FC matrix and its reconstruction.

    Parameters
    ----------
    original_fc : np.ndarray
        Ground-truth functional connectivity matrix.
    reconstructed_fc : np.ndarray
        VAE-reconstructed functional connectivity matrix, same shape.

    Returns
    -------
    dict
        Plain-float metrics: "MSE", "RMSE", "R²" (higher is better),
        "Correlation" (Pearson, higher is better), "Cosine Similarity".
    """
    # Flatten once up front; the correlation and cosine metrics operate on vectors.
    flat_orig = original_fc.flatten()
    flat_recon = reconstructed_fc.flatten()

    # Error-based metrics (lower is better).
    mse = mean_squared_error(original_fc, reconstructed_fc)

    # Agreement-based metrics (higher is better).
    pearson = np.corrcoef(flat_orig, flat_recon)[0, 1]
    cosine = np.dot(flat_orig, flat_recon) / (
        np.linalg.norm(flat_orig) * np.linalg.norm(flat_recon))

    return {
        "MSE": float(mse),
        "RMSE": float(np.sqrt(mse)),
        "R²": float(r2_score(original_fc, reconstructed_fc)),
        "Correlation": float(pearson),
        "Cosine Similarity": float(cosine),
    }
def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
    """Save latent representations and demographics under results/ as pickle and JSON.

    Parameters
    ----------
    latents : np.ndarray or list
        Latent vectors, one row per subject.
    demographics : dict
        Mapping of demographic-variable name to values (list or ndarray).
    subjects : list, optional
        Subject identifiers; included in both outputs when given.
    file_path : str
        Base file name for the pickle; the JSON sidecar shares its stem.

    Returns
    -------
    str
        Path to the saved pickle file.
    """
    os.makedirs('results', exist_ok=True)

    data = {
        'latents': latents,
        'demographics': demographics,
    }
    if subjects is not None:
        data['subjects'] = subjects

    # Pickle for easy loading back into Python.
    pkl_path = os.path.join('results', file_path)
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)

    def _to_serializable(value):
        # numpy arrays are not JSON-serializable; fall back to plain lists.
        return value.tolist() if isinstance(value, np.ndarray) else value

    # JSON sidecar for more universal access.
    json_data = {
        'latents': _to_serializable(latents),
        'demographics': {k: _to_serializable(v) for k, v in demographics.items()},
    }
    if subjects is not None:
        # Convert here too: an ndarray of subject IDs would crash json.dump.
        json_data['subjects'] = _to_serializable(subjects)

    # Use splitext rather than str.replace: replace('.pkl', '.json') leaves the
    # name unchanged when it has a different extension, which would make the
    # JSON dump silently overwrite the pickle written above.
    json_name = os.path.splitext(file_path)[0] + '.json'
    with open(os.path.join('results', json_name), 'w') as f:
        json.dump(json_data, f)

    return pkl_path
def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
    """Run the full VAE analysis pipeline and summarize accuracy metrics.

    Parameters
    ----------
    data_source : str
        HuggingFace dataset ID or local data directory.
    latent_dim : int
        Dimensionality of the VAE latent space.
    nepochs : int
        Number of training epochs.
    bsize : int
        Training batch size.
    use_hf_dataset : bool
        Whether ``data_source`` names a HuggingFace dataset.

    Returns
    -------
    tuple
        ``(figure, status_string)`` consumed by the Gradio outputs.  The status
        string's exact wording is parsed downstream by the UI's accuracy box,
        so the format strings below must not change casually.
    """
    # Run the original analysis pipeline.
    fig, results = run_fc_analysis(
        data_dir=data_source,
        demographic_file=None,  # demographics come directly from the dataset
        latent_dim=latent_dim,
        nepochs=nepochs,
        bsize=bsize,
        save_model=True,
        use_hf_dataset=use_hf_dataset,
        return_data=True  # ask main.py to hand back arrays for metric computation
    )

    if not results:
        return fig, "Analysis complete, but no results were returned for accuracy calculation."

    X = results.get('X')
    latents = results.get('latents')
    demographics = results.get('demographics')
    reconstructed_fc = results.get('reconstructed_fc')
    outcome_measures = results.get('outcome_measures', None)

    # --- Reconstruction accuracy on a handful of subjects --------------------
    accuracy_metrics = {}
    if X is not None and reconstructed_fc is not None:
        for i in range(min(5, len(X))):  # up to 5 samples keeps this cheap
            accuracy_metrics[f"Subject_{i+1}"] = calculate_fc_accuracy(X[i], reconstructed_fc[i])
        # Average each metric across the per-subject dicts.
        avg_metrics = {
            metric: np.mean([m[metric] for m in accuracy_metrics.values()])
            for metric in ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]
        }
        accuracy_metrics["Average"] = avg_metrics

    # --- Persist latent representations if available --------------------------
    if latents is not None and demographics is not None:
        latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
        print(f"Saved latents to {latents_path}")

    # --- Optionally train a WAB-AQ predictor on the latents -------------------
    predictor_results = None
    if outcome_measures is not None and 'wab_aq' in outcome_measures:
        try:
            print("Training WAB-AQ prediction model from latent representations...")
            wab_scores = np.array(outcome_measures['wab_aq'])
            # Mask out subjects with a missing (NaN) score.
            valid_indices = ~np.isnan(wab_scores)
            if np.sum(valid_indices) > 5:  # only train with sufficient data
                filtered_latents = latents[valid_indices]
                filtered_wab = wab_scores[valid_indices]
                # Keep only demographic arrays aligned with the mask.  A
                # boolean mask requires the array length to EQUAL the mask
                # length; the previous '>=' admitted longer arrays, which
                # raises IndexError on boolean indexing (and was silently
                # swallowed by the except below).
                filtered_demographics = {}
                for key, values in demographics.items():
                    if isinstance(values, (list, np.ndarray)) and len(values) == len(valid_indices):
                        filtered_demographics[key] = np.array(values)[valid_indices]
                # Train the prediction model with cross-validation.
                predictor_results = train_predictor_from_latents(
                    filtered_latents,
                    filtered_wab,
                    filtered_demographics,
                    cv=5,  # 5-fold cross-validation
                    n_estimators=100,  # number of trees in the Random Forest
                    prediction_type="regression"
                )
                print("WAB-AQ prediction model training complete!")
        except Exception as e:
            # Best-effort: the main analysis still succeeded, so only log.
            print(f"Error training prediction model: {str(e)}")

    # --- Build the status message ---------------------------------------------
    if accuracy_metrics:
        avg = accuracy_metrics["Average"]
        status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
                  f"Reconstruction Accuracy Metrics (Average):\n"
                  f"• MSE: {avg['MSE']:.6f}\n"
                  f"• RMSE: {avg['RMSE']:.6f}\n"
                  f"• R²: {avg['R²']:.6f}\n"
                  f"• Correlation: {avg['Correlation']:.6f}\n"
                  f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n")
        # Append prediction-model results if a predictor was trained.
        if predictor_results is not None:
            cv_results = predictor_results.get('cv_results', {})
            mean_metrics = cv_results.get('mean_metrics', {})
            if mean_metrics and 'r2' in mean_metrics:
                prediction_r2 = mean_metrics.get('r2', 0)
                prediction_rmse = mean_metrics.get('rmse', 0)
                status += (f"WAB-AQ Prediction Model Performance:\n"
                           f"• R²: {prediction_r2:.4f}\n"
                           f"• RMSE: {prediction_rmse:.4f}\n\n")
        status += f"Latent representations saved to results/latents_dim{latent_dim}.pkl"
    else:
        status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."

    return fig, status
def create_interface():
    """Build and return the Gradio Blocks UI for the VAE FC-analysis demo."""
    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
        # Header / dataset description shown above the controls.
        gr.Markdown("""
# Aphasia fMRI to FC Analysis using VAE
This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
## Dataset Information
By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
- ID: Subject identifier
- wab_aq: Aphasia severity score
- age: Age of the subject
- mpo: Months post onset
- education: Years of education
- gender: Subject gender
- handedness: Subject handedness (ignored in the analysis)
""")
        with gr.Row():
            with gr.Column(scale=1):
                # Configuration parameters
                data_source = gr.Textbox(
                    label="Data Source (HF Dataset ID or Local Directory)",
                    value="SreekarB/OSFData"
                )
                latent_dim = gr.Slider(
                    minimum=8, maximum=64, step=8,
                    label="Latent Dimensions", value=32
                )
                nepochs = gr.Slider(
                    minimum=100, maximum=5000, step=100,
                    label="Number of Epochs", value=200  # Reduced for faster demos
                )
                bsize = gr.Slider(
                    minimum=8, maximum=64, step=8,
                    label="Batch Size", value=16
                )
                use_hf_dataset = gr.Checkbox(
                    label="Use HuggingFace Dataset", value=True
                )
                # Training button
                train_button = gr.Button("Start Training", variant="primary")
                status_text = gr.Textbox(label="Status", value="Ready to start training")
            with gr.Column(scale=2):
                # Output plot
                output_plot = gr.Plot(label="Analysis Results")
                accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")
        # Link the training button to the analysis function
        train_button.click(
            fn=gradio_fc_analysis,
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
            outputs=[output_plot, status_text]
        )

        # Custom function to update the accuracy box.
        # NOTE: this parses the exact status-string format produced by
        # gradio_fc_analysis ("Reconstruction Accuracy Metrics (Average):"),
        # so the two must stay in sync.
        def update_accuracy_display(status_text):
            if "Accuracy Metrics" in status_text:
                # Extract the accuracy metrics section
                parts = status_text.split("Reconstruction Accuracy Metrics (Average):")
                if len(parts) > 1:
                    # Metrics run up to the first blank line after the marker.
                    metrics_text = parts[1].split("\n\n")[0]
                    return f"### Reconstruction Accuracy Metrics\n{metrics_text}"
            return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."

        # Update accuracy box when status changes
        status_text.change(
            fn=update_accuracy_display,
            inputs=[status_text],
            outputs=[accuracy_box]
        )
        # Add examples
        gr.Examples(
            examples=[
                ["SreekarB/OSFData", 32, 200, 16, True],  # Fewer epochs for faster demo
            ],
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
        )
        # Add explanation of the workflow
        gr.Markdown("""
## How this works
1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
4. **Predictive Modeling**: The system trains a Random Forest regressor on latent features to predict WAB-AQ scores (aphasia severity)
5. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
6. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
""")
    return iface
if __name__ == "__main__":
    # Build the Gradio UI and launch it with a public share link.
    app = create_interface()
    app.launch(share=True)