import gradio as gr from main import run_fc_analysis import os import numpy as np from sklearn.metrics import mean_squared_error, r2_score import json import pickle from rcf_prediction import train_predictor_from_latents def calculate_fc_accuracy(original_fc, reconstructed_fc): """ Calculate accuracy metrics between original and reconstructed FC matrices """ # Mean Squared Error (lower is better) mse = mean_squared_error(original_fc, reconstructed_fc) # Root Mean Squared Error (lower is better) rmse = np.sqrt(mse) # R² Score (higher is better, 1 is perfect) r2 = r2_score(original_fc, reconstructed_fc) # Correlation between matrices (higher is better) corr = np.corrcoef(original_fc.flatten(), reconstructed_fc.flatten())[0, 1] # Custom similarity score based on normalized dot product (higher is better) norm_dot = np.dot(original_fc.flatten(), reconstructed_fc.flatten()) / ( np.linalg.norm(original_fc.flatten()) * np.linalg.norm(reconstructed_fc.flatten())) return { "MSE": float(mse), "RMSE": float(rmse), "R²": float(r2), "Correlation": float(corr), "Cosine Similarity": float(norm_dot) } def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'): """ Save latent representations and associated demographics to file """ os.makedirs('results', exist_ok=True) # Create a dictionary with latents and demographics data = { 'latents': latents, 'demographics': demographics } if subjects is not None: data['subjects'] = subjects # Save as pickle for easy loading in Python with open(os.path.join('results', file_path), 'wb') as f: pickle.dump(data, f) # Also save as JSON for more universal access json_data = { 'latents': latents.tolist() if isinstance(latents, np.ndarray) else latents, 'demographics': {k: v.tolist() if isinstance(v, np.ndarray) else v for k, v in demographics.items()} } if subjects is not None: json_data['subjects'] = subjects with open(os.path.join('results', file_path.replace('.pkl', '.json')), 'w') as f: json.dump(json_data, f) return os.path.join('results', file_path) def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset): """Run the full VAE analysis pipeline with accuracy metrics""" # Run the original analysis fig, results = run_fc_analysis( data_dir=data_source, demographic_file=None, # We're now getting demographics directly from the dataset latent_dim=latent_dim, nepochs=nepochs, bsize=bsize, save_model=True, use_hf_dataset=use_hf_dataset, return_data=True # New parameter to return data, will need to update main.py ) if results: vae = results.get('vae') X = results.get('X') latents = results.get('latents') demographics = results.get('demographics') reconstructed_fc = results.get('reconstructed_fc') generated_fc = results.get('generated_fc') outcome_measures = results.get('outcome_measures', None) # Calculate accuracy metrics accuracy_metrics = {} if X is not None and reconstructed_fc is not None: for i in range(min(5, len(X))): # Calculate for up to 5 samples metrics = calculate_fc_accuracy(X[i], reconstructed_fc[i]) accuracy_metrics[f"Subject_{i+1}"] = metrics # Average metrics across subjects avg_metrics = {} for metric in ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]: avg_metrics[metric] = np.mean([subject_metrics[metric] for subject_metrics in accuracy_metrics.values()]) accuracy_metrics["Average"] = avg_metrics # Save latent representations if available if latents is not None and demographics is not None: latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl') print(f"Saved latents to {latents_path}") # Train a predictor model if we have outcome measures predictor_results = None if outcome_measures is not None and 'wab_aq' in outcome_measures: try: print("Training WAB-AQ prediction model from latent representations...") wab_scores = np.array(outcome_measures['wab_aq']) # Filter out any NaN values valid_indices = ~np.isnan(wab_scores) if np.sum(valid_indices) > 5: # Only train with sufficient data filtered_latents = latents[valid_indices] filtered_wab = wab_scores[valid_indices] # Extract demographic features for the model filtered_demographics = {} for key, values in demographics.items(): if isinstance(values, (list, np.ndarray)) and len(values) >= len(valid_indices): filtered_demographics[key] = np.array(values)[valid_indices] # Train the prediction model with cross-validation predictor_results = train_predictor_from_latents( filtered_latents, filtered_wab, filtered_demographics, cv=5, # 5-fold cross-validation n_estimators=100, # Number of trees in Random Forest prediction_type="regression" ) print("WAB-AQ prediction model training complete!") except Exception as e: print(f"Error training prediction model: {str(e)}") # Prepare status message with accuracy metrics if accuracy_metrics: avg = accuracy_metrics["Average"] status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n" f"Reconstruction Accuracy Metrics (Average):\n" f"• MSE: {avg['MSE']:.6f}\n" f"• RMSE: {avg['RMSE']:.6f}\n" f"• R²: {avg['R²']:.6f}\n" f"• Correlation: {avg['Correlation']:.6f}\n" f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n") # Add prediction model results if available if predictor_results is not None: cv_results = predictor_results.get('cv_results', {}) mean_metrics = cv_results.get('mean_metrics', {}) if mean_metrics and 'r2' in mean_metrics: prediction_r2 = mean_metrics.get('r2', 0) prediction_rmse = mean_metrics.get('rmse', 0) status += (f"WAB-AQ Prediction Model Performance:\n" f"• R²: {prediction_r2:.4f}\n" f"• RMSE: {prediction_rmse:.4f}\n\n") status += f"Latent representations saved to results/latents_dim{latent_dim}.pkl" else: status = "Analysis complete! VAE model has been trained and demographic relationships analyzed." else: status = "Analysis complete, but no results were returned for accuracy calculation." return fig, status def create_interface(): with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface: gr.Markdown(""" # Aphasia fMRI to FC Analysis using VAE This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables. ## Dataset Information By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables: - ID: Subject identifier - wab_aq: Aphasia severity score - age: Age of the subject - mpo: Months post onset - education: Years of education - gender: Subject gender - handedness: Subject handedness (ignored in the analysis) """) with gr.Row(): with gr.Column(scale=1): # Configuration parameters data_source = gr.Textbox( label="Data Source (HF Dataset ID or Local Directory)", value="SreekarB/OSFData" ) latent_dim = gr.Slider( minimum=8, maximum=64, step=8, label="Latent Dimensions", value=32 ) nepochs = gr.Slider( minimum=100, maximum=5000, step=100, label="Number of Epochs", value=200 # Reduced for faster demos ) bsize = gr.Slider( minimum=8, maximum=64, step=8, label="Batch Size", value=16 ) use_hf_dataset = gr.Checkbox( label="Use HuggingFace Dataset", value=True ) # Training button train_button = gr.Button("Start Training", variant="primary") status_text = gr.Textbox(label="Status", value="Ready to start training") with gr.Column(scale=2): # Output plot output_plot = gr.Plot(label="Analysis Results") accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here") # Link the training button to the analysis function train_button.click( fn=gradio_fc_analysis, inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset], outputs=[output_plot, status_text] ) # Custom function to update the accuracy box def update_accuracy_display(status_text): if "Accuracy Metrics" in status_text: # Extract the accuracy metrics section parts = status_text.split("Reconstruction Accuracy Metrics (Average):") if len(parts) > 1: metrics_text = parts[1].split("\n\n")[0] return f"### Reconstruction Accuracy Metrics\n{metrics_text}" return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics." # Update accuracy box when status changes status_text.change( fn=update_accuracy_display, inputs=[status_text], outputs=[accuracy_box] ) # Add examples gr.Examples( examples=[ ["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo ], inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset], ) # Add explanation of the workflow gr.Markdown(""" ## How this works 1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset 2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices 3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity 4. **Predictive Modeling**: The system trains a Random Forest regressor on latent features to predict WAB-AQ scores (aphasia severity) 5. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables 6. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information. """) return iface if __name__ == "__main__": iface = create_interface() iface.launch(share=True)