Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from main import run_fc_analysis | |
| import os | |
| import numpy as np | |
| from sklearn.metrics import mean_squared_error, r2_score | |
| import json | |
| import pickle | |
def calculate_fc_accuracy(original_fc, reconstructed_fc):
    """Compute reconstruction-quality metrics between two FC matrices.

    Parameters
    ----------
    original_fc : array-like
        Ground-truth functional-connectivity matrix (2-D) or flat vector.
    reconstructed_fc : array-like
        VAE reconstruction of the same matrix, same total size.

    Returns
    -------
    dict
        Keys "MSE", "RMSE", "R²", "Correlation", "Cosine Similarity",
        all plain Python floats.  All values are NaN if computation fails.
    """
    try:
        # ravel() handles both matrix and vector inputs uniformly,
        # replacing the previous hasattr/shape branching.
        original_flat = np.asarray(original_fc, dtype=float).ravel()
        recon_flat = np.asarray(reconstructed_fc, dtype=float).ravel()

        # Mean squared / root mean squared error (lower is better).
        mse = np.mean((original_flat - recon_flat) ** 2)
        rmse = np.sqrt(mse)

        # R² score (higher is better, 1 is perfect).  Mirrors sklearn's
        # r2_score, including the constant-target edge case.
        ss_res = np.sum((original_flat - recon_flat) ** 2)
        ss_tot = np.sum((original_flat - np.mean(original_flat)) ** 2)
        if ss_tot > 0:
            r2 = 1.0 - ss_res / ss_tot
        else:
            r2 = 1.0 if ss_res == 0 else 0.0

        # Pearson correlation between matrices (higher is better).
        # NaN when either input is constant, by corrcoef's definition.
        corr = np.corrcoef(original_flat, recon_flat)[0, 1]

        # Cosine similarity (higher is better).  Guard the zero-norm case,
        # which previously produced a silent divide-by-zero warning.
        denom = np.linalg.norm(original_flat) * np.linalg.norm(recon_flat)
        if denom > 0:
            norm_dot = np.dot(original_flat, recon_flat) / denom
        else:
            norm_dot = float('nan')

        return {
            "MSE": float(mse),
            "RMSE": float(rmse),
            "R²": float(r2),
            "Correlation": float(corr),
            "Cosine Similarity": float(norm_dot),
        }
    except Exception as e:
        # Best-effort fallback: callers average these dicts, so always
        # return the full key set rather than raising.
        print(f"Error calculating accuracy metrics: {e}")
        nan = float('nan')
        return {
            "MSE": nan,
            "RMSE": nan,
            "R²": nan,
            "Correlation": nan,
            "Cosine Similarity": nan,
        }
def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
    """Persist latent representations and demographics under ./results.

    Writes two files: a pickle (exact Python objects) and a JSON sibling
    (numpy arrays converted to lists for universal access).

    Parameters
    ----------
    latents : np.ndarray or list
        Latent representations from the VAE.
    demographics : dict
        Demographic variables keyed by name; values may be ndarrays or lists.
    subjects : sequence, optional
        Subject identifiers to store alongside the latents.
    file_path : str
        File name for the pickle inside the 'results' directory.

    Returns
    -------
    str
        Path of the written pickle file.
    """
    os.makedirs('results', exist_ok=True)

    def _jsonable(v):
        # numpy arrays are not JSON-serializable; everything else passes through.
        return v.tolist() if isinstance(v, np.ndarray) else v

    data = {
        'latents': latents,
        'demographics': demographics,
    }
    if subjects is not None:
        data['subjects'] = subjects

    # Pickle keeps the exact objects for easy loading in Python.
    pkl_path = os.path.join('results', file_path)
    with open(pkl_path, 'wb') as f:
        pickle.dump(data, f)

    json_data = {
        'latents': _jsonable(latents),
        'demographics': {k: _jsonable(v) for k, v in demographics.items()},
    }
    if subjects is not None:
        # BUG FIX: subjects was stored unconverted, so an ndarray of
        # subject IDs crashed json.dump with a TypeError.
        json_data['subjects'] = _jsonable(subjects)

    # splitext is robust for any extension (str.replace silently did
    # nothing for names not containing '.pkl').
    json_name = os.path.splitext(file_path)[0] + '.json'
    with open(os.path.join('results', json_name), 'w') as f:
        json.dump(json_data, f)

    return pkl_path
def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
    """Run the full VAE analysis pipeline with accuracy metrics.

    Parameters mirror the Gradio inputs: the dataset name/path, VAE latent
    dimension, training epochs, batch size, and whether to load via the
    HuggingFace datasets API.

    Returns
    -------
    (matplotlib.figure.Figure, str)
        The analysis figure and a human-readable status message — on
        failure, an error figure and an error message (never raises).
    """
    print(f"Starting FC analysis with latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
    print(f"Using dataset: {data_source} (HuggingFace API: {use_hf_dataset})")

    try:
        fig, results = run_fc_analysis(
            data_dir=data_source,
            demographic_file=None,  # demographics come directly from the dataset
            latent_dim=latent_dim,
            nepochs=nepochs,
            bsize=bsize,
            save_model=True,
            use_hf_dataset=use_hf_dataset,
            return_data=True
        )
    except Exception as e:
        import traceback
        error_msg = f"Error during analysis: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        # Render the error into a figure so the Plot output stays usable.
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}",
                 ha='center', va='center', fontsize=12, color='red')
        plt.axis('off')
        return fig, f"Analysis failed with error: {str(e)}"

    if not results:
        return fig, "Analysis complete, but no results were returned for accuracy calculation."

    X = results.get('X')
    latents = results.get('latents')
    demographics = results.get('demographics')
    reconstructed_fc = results.get('reconstructed_fc')

    # Per-subject reconstruction accuracy for up to 5 samples.
    accuracy_metrics = {}
    if X is not None and reconstructed_fc is not None:
        for i in range(min(5, len(X))):
            original = X[i]
            recon = reconstructed_fc[i]
            try:
                # Convert flat FC vectors back to square matrices if needed.
                if len(original.shape) == 1:
                    from visualization import vector_to_matrix
                    print(f"Converting subject {i+1} vectors to matrices")
                    original = vector_to_matrix(original)
                    recon = vector_to_matrix(recon)
            except Exception as e:
                print(f"Error converting matrices for subject {i+1}: {e}")
                continue  # Skip this subject
            # calculate_fc_accuracy flattens internally, so matrices are fine.
            accuracy_metrics[f"Subject_{i+1}"] = calculate_fc_accuracy(original, recon)

        # BUG FIX: only average when at least one subject succeeded;
        # np.mean([]) emits a RuntimeWarning and would report an
        # all-NaN "Average" block to the user.
        if accuracy_metrics:
            avg_metrics = {
                metric: float(np.mean([m[metric] for m in accuracy_metrics.values()]))
                for metric in ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]
            }
            accuracy_metrics["Average"] = avg_metrics

    # Persist latent representations if the pipeline produced them.
    if latents is not None and demographics is not None:
        latents_path = save_latents(latents, demographics, file_path=f'latents_dim{latent_dim}.pkl')
        print(f"Saved latents to {latents_path}")

    if accuracy_metrics:
        avg = accuracy_metrics["Average"]
        status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
                  f"Reconstruction Accuracy Metrics (Average):\n"
                  f"• MSE: {avg['MSE']:.6f}\n"
                  f"• RMSE: {avg['RMSE']:.6f}\n"
                  f"• R²: {avg['R²']:.6f}\n"
                  f"• Correlation: {avg['Correlation']:.6f}\n"
                  f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n"
                  f"Latent representations saved to results/latents_dim{latent_dim}.pkl")
    else:
        status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."
    return fig, status
def run_demo():
    """Build and return the Gradio Blocks UI for the VAE FC-analysis demo."""

    def _metrics_markdown(status):
        """Extract the averaged-metrics section of the status text for display."""
        if "Accuracy Metrics" in status:
            parts = status.split("Reconstruction Accuracy Metrics (Average):")
            if len(parts) > 1:
                section = parts[1].split("\n\n")[0]
                return f"### Reconstruction Accuracy Metrics\n{section}"
        return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."

    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as interface:
        with gr.Row():
            # Left column: configuration controls and status readout.
            with gr.Column(scale=1):
                with gr.Box():  # Box rather than Group to avoid Group issues
                    gr.Markdown("### Configuration")
                    data_source = gr.Textbox(
                        value="SreekarB/OSFData",
                        label="Data Source (HuggingFace dataset or directory)",
                    )
                    use_hf_checkbox = gr.Checkbox(
                        value=True, label="Use HuggingFace Dataset API")
                    latent_dim = gr.Slider(
                        minimum=4, maximum=64, value=16, step=4,
                        label="Latent Dimension")
                    nepochs = gr.Slider(
                        minimum=10, maximum=200, value=50, step=10,
                        label="Training Epochs")
                    bsize = gr.Slider(
                        minimum=2, maximum=16, value=4, step=1,
                        label="Batch Size")
                gr.Markdown("**Note:** Lower values for latent dimension, epochs, and batch size will train faster.")
                train_button = gr.Button("Train VAE & Analyze FC", variant="primary")
                status_text = gr.Textbox(label="Status", lines=10, interactive=False)

            # Right column: analysis plot and accuracy summary.
            with gr.Column(scale=2):
                output_plot = gr.Plot(label="FC Matrix Analysis")
                accuracy_box = gr.Markdown(
                    "### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")

        # Kick off training/analysis from the button.
        train_button.click(
            fn=gradio_fc_analysis,
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
            outputs=[output_plot, status_text],
        )

        # Mirror any new metrics from the status text into the markdown box.
        status_text.change(
            fn=_metrics_markdown,
            inputs=[status_text],
            outputs=[accuracy_box],
        )

        gr.Examples(
            examples=[
                ["SreekarB/OSFData", 32, 100, 8, True],
                ["SreekarB/OSFData", 16, 200, 16, True],
                ["SreekarB/OSFData", 64, 50, 4, True],
            ],
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
        )

    return interface
# Script entry point: build the Gradio interface and serve it locally.
if __name__ == "__main__":
    demo = run_demo()
    demo.launch()