import gradio as gr
from main import run_fc_analysis
import os
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
import json
import pickle


def calculate_fc_accuracy(original_fc, reconstructed_fc):
    """
    Calculate accuracy metrics between original and reconstructed FC matrices.

    Parameters
    ----------
    original_fc, reconstructed_fc : array-like
        Functional-connectivity data; 2-D arrays are flattened before
        comparison so both inputs end up as 1-D vectors.

    Returns
    -------
    dict
        MSE, RMSE, R², Pearson correlation and cosine similarity as floats.
        Every value is NaN if the computation fails.
    """
    try:
        # Flatten 2-D matrices so sklearn/numpy metrics see 1-D vectors.
        if hasattr(original_fc, 'flatten') and len(original_fc.shape) > 1:
            original_flat = original_fc.flatten()
        else:
            original_flat = original_fc
        if hasattr(reconstructed_fc, 'flatten') and len(reconstructed_fc.shape) > 1:
            recon_flat = reconstructed_fc.flatten()
        else:
            recon_flat = reconstructed_fc

        # Mean Squared Error (lower is better)
        mse = mean_squared_error(original_flat, recon_flat)
        # Root Mean Squared Error (lower is better)
        rmse = np.sqrt(mse)
        # R² Score (higher is better, 1 is perfect)
        r2 = r2_score(original_flat, recon_flat)
        # Correlation between matrices (higher is better)
        corr = np.corrcoef(original_flat, recon_flat)[0, 1]

        # Cosine similarity (higher is better). numpy division by zero
        # emits a warning instead of raising, so guard the denominator
        # explicitly: an all-zero vector yields NaN rather than inf.
        denom = np.linalg.norm(original_flat) * np.linalg.norm(recon_flat)
        if denom > 0:
            norm_dot = np.dot(original_flat, recon_flat) / denom
        else:
            norm_dot = float('nan')

        return {
            "MSE": float(mse),
            "RMSE": float(rmse),
            "R²": float(r2),
            "Correlation": float(corr),
            "Cosine Similarity": float(norm_dot)
        }
    except Exception as e:
        print(f"Error calculating accuracy metrics: {e}")
        return {
            "MSE": float('nan'),
            "RMSE": float('nan'),
            "R²": float('nan'),
            "Correlation": float('nan'),
            "Cosine Similarity": float('nan')
        }


def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'):
    """
    Save latent representations and associated demographics to file.

    Writes two artifacts under ``results/``: a pickle (``file_path``) for
    fast Python reloading and a JSON twin (same stem, ``.json``) for
    language-agnostic access.

    Parameters
    ----------
    latents : np.ndarray or list
        Latent vectors produced by the VAE.
    demographics : dict
        Per-subject demographic arrays/lists keyed by field name.
    subjects : list, optional
        Subject identifiers; stored alongside the data when provided.
    file_path : str
        Pickle filename (relative to ``results/``).

    Returns
    -------
    str
        Path of the saved pickle file.
    """
    os.makedirs('results', exist_ok=True)

    # Pickle payload keeps the original (possibly numpy) objects intact.
    data = {
        'latents': latents,
        'demographics': demographics
    }
    if subjects is not None:
        data['subjects'] = subjects

    with open(os.path.join('results', file_path), 'wb') as f:
        pickle.dump(data, f)

    # JSON twin: numpy arrays are converted to plain lists so the file is
    # readable outside Python.
    json_data = {
        'latents': latents.tolist() if isinstance(latents, np.ndarray) else latents,
        'demographics': {k: v.tolist() if isinstance(v, np.ndarray) else v
                         for k, v in demographics.items()}
    }
    if subjects is not None:
        json_data['subjects'] = subjects

    with open(os.path.join('results', file_path.replace('.pkl', '.json')), 'w') as f:
        json.dump(json_data, f)

    return os.path.join('results', file_path)


def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
    """
    Run the full VAE analysis pipeline with accuracy metrics.

    Trains the VAE via ``run_fc_analysis``, computes reconstruction
    accuracy for up to five subjects, saves latents to disk, and returns
    a (matplotlib figure, status string) pair for the Gradio UI. On any
    training failure the figure contains the error text instead of plots.
    """
    print(f"Starting FC analysis with latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
    print(f"Using dataset: {data_source} (HuggingFace API: {use_hf_dataset})")

    try:
        fig, results = run_fc_analysis(
            data_dir=data_source,
            demographic_file=None,  # demographics come directly from the dataset
            latent_dim=latent_dim,
            nepochs=nepochs,
            bsize=bsize,
            save_model=True,
            use_hf_dataset=use_hf_dataset,
            return_data=True
        )
    except Exception as e:
        import traceback
        error_msg = f"Error during analysis: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        # Render the error on a figure so the Plot component still shows
        # something useful instead of staying blank.
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}",
                 ha='center', va='center', fontsize=12, color='red')
        plt.axis('off')
        return fig, f"Analysis failed with error: {str(e)}"

    if results:
        vae = results.get('vae')
        X = results.get('X')
        latents = results.get('latents')
        demographics = results.get('demographics')
        reconstructed_fc = results.get('reconstructed_fc')
        generated_fc = results.get('generated_fc')

        # Per-subject reconstruction accuracy (capped at 5 subjects to
        # keep the UI responsive).
        accuracy_metrics = {}
        if X is not None and reconstructed_fc is not None:
            for i in range(min(5, len(X))):
                original = X[i]
                recon = reconstructed_fc[i]
                try:
                    # 1-D data is an upper-triangle vector; rebuild the
                    # square matrix before comparing.
                    if len(original.shape) == 1:
                        from visualization import vector_to_matrix
                        print(f"Converting subject {i+1} vectors to matrices")
                        original = vector_to_matrix(original)
                        recon = vector_to_matrix(recon)
                    # Flatten for a consistent element-wise comparison.
                    if len(original.shape) > 1:
                        original_flat = original.flatten()
                        recon_flat = recon.flatten()
                    else:
                        original_flat = original
                        recon_flat = recon
                except Exception as e:
                    print(f"Error converting matrices for subject {i+1}: {e}")
                    continue  # skip this subject
                metrics = calculate_fc_accuracy(original_flat, recon_flat)
                accuracy_metrics[f"Subject_{i+1}"] = metrics

            # Only average when at least one subject succeeded; np.mean
            # over an empty list would emit a RuntimeWarning and report a
            # misleading all-NaN "Average" row.
            if accuracy_metrics:
                avg_metrics = {}
                for metric in ["MSE", "RMSE", "R²", "Correlation", "Cosine Similarity"]:
                    avg_metrics[metric] = np.mean(
                        [subject_metrics[metric]
                         for subject_metrics in accuracy_metrics.values()])
                accuracy_metrics["Average"] = avg_metrics

        # Persist latents so later analyses can reload them.
        if latents is not None and demographics is not None:
            latents_path = save_latents(latents, demographics,
                                        file_path=f'latents_dim{latent_dim}.pkl')
            print(f"Saved latents to {latents_path}")

        # Build the status message; update_accuracy_display() parses the
        # "Reconstruction Accuracy Metrics (Average):" marker below, so
        # keep that literal text stable.
        if accuracy_metrics:
            avg = accuracy_metrics["Average"]
            status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
                      f"Reconstruction Accuracy Metrics (Average):\n"
                      f"• MSE: {avg['MSE']:.6f}\n"
                      f"• RMSE: {avg['RMSE']:.6f}\n"
                      f"• R²: {avg['R²']:.6f}\n"
                      f"• Correlation: {avg['Correlation']:.6f}\n"
                      f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}\n\n"
                      f"Latent representations saved to results/latents_dim{latent_dim}.pkl")
        else:
            status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."
    else:
        status = "Analysis complete, but no results were returned for accuracy calculation."

    return fig, status


def run_demo():
    """Build and return the Gradio Blocks interface for the FC analysis app."""
    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as interface:
        with gr.Row():
            with gr.Column(scale=1):
                # Configuration inputs
                # NOTE(review): gr.Box was removed in Gradio 4.x — confirm
                # the pinned gradio version still provides it.
                with gr.Box():  # Switched to Box to avoid any Group issues
                    gr.Markdown("### Configuration")
                    data_source = gr.Textbox(value="SreekarB/OSFData",
                                             label="Data Source (HuggingFace dataset or directory)")
                    use_hf_checkbox = gr.Checkbox(value=True, label="Use HuggingFace Dataset API")
                    latent_dim = gr.Slider(minimum=4, maximum=64, value=16, step=4,
                                           label="Latent Dimension")
                    nepochs = gr.Slider(minimum=10, maximum=200, value=50, step=10,
                                        label="Training Epochs")
                    bsize = gr.Slider(minimum=2, maximum=16, value=4, step=1,
                                      label="Batch Size")
                    # Add a note about training time
                    gr.Markdown("**Note:** Lower values for latent dimension, epochs, and batch size will train faster.")
                    train_button = gr.Button("Train VAE & Analyze FC", variant="primary")

                # Status output area
                status_text = gr.Textbox(label="Status", lines=10, interactive=False)

            with gr.Column(scale=2):
                # Output plot
                output_plot = gr.Plot(label="FC Matrix Analysis")
                accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")

        # Link the training button to the analysis function
        train_button.click(
            fn=gradio_fc_analysis,
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
            outputs=[output_plot, status_text]
        )

        def update_accuracy_display(status_text):
            """Extract the metrics section from the status string for the Markdown box."""
            # Guard against None/empty status (e.g. the component's initial value).
            if status_text and "Accuracy Metrics" in status_text:
                # The marker below must match the one emitted by gradio_fc_analysis.
                parts = status_text.split("Reconstruction Accuracy Metrics (Average):")
                if len(parts) > 1:
                    metrics_text = parts[1].split("\n\n")[0]
                    return f"### Reconstruction Accuracy Metrics\n{metrics_text}"
            return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."

        # Update accuracy box when status changes
        status_text.change(
            fn=update_accuracy_display,
            inputs=[status_text],
            outputs=[accuracy_box]
        )

        # Add examples
        gr.Examples(
            examples=[
                ["SreekarB/OSFData", 32, 100, 8, True],
                ["SreekarB/OSFData", 16, 200, 16, True],
                ["SreekarB/OSFData", 64, 50, 4, True],
            ],
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
        )

    return interface


if __name__ == "__main__":
    demo = run_demo()
    demo.launch()