# AphasiaPred / app_fixed.py
# Uploaded to Hugging Face by SreekarB ("Upload 36 files", commit 71fbc82, verified).
import gradio as gr
from main import run_fc_analysis
import os
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
import json
import pickle
def calculate_fc_accuracy(original_fc, reconstructed_fc):
    """
    Calculate accuracy metrics between original and reconstructed FC data.

    Both inputs are converted with ``np.asarray`` and flattened to 1-D before
    any metric is computed. Matrices, vectors and nested lists are therefore
    all compared element-wise.

    Args:
        original_fc: Original FC data (array-like of any shape).
        reconstructed_fc: Reconstructed FC data with the same number of
            elements as ``original_fc``.

    Returns:
        dict mapping "MSE", "RMSE", "R²", "Correlation" and
        "Cosine Similarity" to floats. If the metrics cannot be computed
        (e.g. mismatched sizes), the error is printed and every value is NaN.
    """
    metric_names = ("MSE", "RMSE", "R²", "Correlation", "Cosine Similarity")
    try:
        # Always flatten. On 2-D input, np.corrcoef treats each row as a
        # separate variable, so [0, 1] would correlate two rows of the
        # original. r2_score would also average per column instead of
        # scoring the whole matrix.
        original_flat = np.asarray(original_fc, dtype=float).ravel()
        recon_flat = np.asarray(reconstructed_fc, dtype=float).ravel()

        # Mean Squared Error (lower is better)
        mse = mean_squared_error(original_flat, recon_flat)
        # Root Mean Squared Error (lower is better)
        rmse = np.sqrt(mse)
        # R² score (higher is better, 1 is perfect)
        r2 = r2_score(original_flat, recon_flat)

        # Pearson correlation (higher is better). It is NaN when either input
        # is constant; silence the 0/0 warning and let NaN propagate.
        with np.errstate(divide="ignore", invalid="ignore"):
            corr = np.corrcoef(original_flat, recon_flat)[0, 1]

        # Cosine similarity (normalized dot product, higher is better). It is
        # undefined when either vector is all zeros.
        denom = np.linalg.norm(original_flat) * np.linalg.norm(recon_flat)
        cosine = (np.dot(original_flat, recon_flat) / denom
                  if denom > 0 else float("nan"))

        values = (mse, rmse, r2, corr, cosine)
        return {name: float(value) for name, value in zip(metric_names, values)}
    except Exception as e:
        print(f"Error calculating accuracy metrics: {e}")
        return {name: float("nan") for name in metric_names}
def save_latents(latents, demographics, subjects=None, file_path='latents.pkl',
                 output_dir='results'):
    """
    Save latent representations and associated demographics to disk.

    Two files are written inside ``output_dir``, which is created if missing:
    a pickle at ``file_path`` holding the objects unchanged, and a JSON copy
    whose name is ``file_path`` with its extension replaced by ``.json``.

    Args:
        latents: Latent representations (NumPy array or JSON-compatible data).
        demographics: Mapping of demographic field name -> values.
        subjects: Optional subject identifiers, stored under 'subjects'.
        file_path: Pickle file name, relative to ``output_dir``.
        output_dir: Directory for both files (default 'results').

    Returns:
        Path of the written pickle file.
    """
    os.makedirs(output_dir, exist_ok=True)

    data = {'latents': latents, 'demographics': demographics}
    if subjects is not None:
        data['subjects'] = subjects

    # Pickle keeps the original Python/NumPy objects for easy reloading.
    pickle_path = os.path.join(output_dir, file_path)
    with open(pickle_path, 'wb') as f:
        pickle.dump(data, f)

    # Derive the JSON name from the extension only. file_path.replace('.pkl', ...)
    # left names without '.pkl' unchanged, so the JSON overwrote the pickle.
    json_path = os.path.splitext(pickle_path)[0] + '.json'
    with open(json_path, 'w') as f:
        json.dump(_to_json_compatible(data), f)

    return pickle_path


def _to_json_compatible(value):
    """Recursively convert arrays, NumPy scalars and mappings into JSON-serializable data."""
    if hasattr(value, 'tolist'):
        # NumPy arrays and scalars (e.g. np.int64), pandas Series, tensors.
        return value.tolist()
    if isinstance(value, (list, tuple)):
        return [_to_json_compatible(v) for v in value]
    if hasattr(value, 'items'):
        # dicts and dict-like containers (e.g. a DataFrame iterates columns).
        return {k: _to_json_compatible(v) for k, v in value.items()}
    return value
def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
    """Run the full VAE analysis pipeline with accuracy metrics.

    Args:
        data_source: HuggingFace dataset id or local directory; forwarded to
            ``run_fc_analysis`` as ``data_dir``.
        latent_dim: VAE latent dimension.
        nepochs: Number of training epochs.
        bsize: Training batch size.
        use_hf_dataset: Whether ``data_source`` is loaded through the
            HuggingFace dataset API.

    Returns:
        Tuple ``(figure, status_message)`` for the Gradio plot and status
        box. If the pipeline raises, the figure displays the error text.
    """
    # Gradio sliders pass floats (e.g. 4.0). Cast so the pipeline gets integer
    # sizes and the saved file reads "latents_dim16.pkl", not "latents_dim16.0.pkl".
    latent_dim, nepochs, bsize = int(latent_dim), int(nepochs), int(bsize)

    print(f"Starting FC analysis with latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}")
    print(f"Using dataset: {data_source} (HuggingFace API: {use_hf_dataset})")
    try:
        fig, results = run_fc_analysis(
            data_dir=data_source,
            demographic_file=None,  # demographics come directly from the dataset
            latent_dim=latent_dim,
            nepochs=nepochs,
            bsize=bsize,
            save_model=True,
            use_hf_dataset=use_hf_dataset,
            return_data=True
        )
    except Exception as e:
        import traceback
        error_msg = f"Error during analysis: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        return _error_figure(e), f"Analysis failed with error: {str(e)}"

    if not results:
        return fig, "Analysis complete, but no results were returned for accuracy calculation."

    accuracy_metrics = _compute_accuracy_metrics(results.get('X'), results.get('reconstructed_fc'))

    # Save latent representations if available. A failure here is reported but
    # does not discard the finished analysis.
    latents_path = None
    latents = results.get('latents')
    demographics = results.get('demographics')
    if latents is not None and demographics is not None:
        try:
            latents_path = save_latents(latents, demographics,
                                        file_path=f'latents_dim{latent_dim}.pkl')
            print(f"Saved latents to {latents_path}")
        except Exception as e:
            print(f"Error saving latents: {e}")

    return fig, _format_status(latent_dim, accuracy_metrics, latents_path)


def _error_figure(error):
    """Build a blank matplotlib figure showing ``error`` in red text."""
    import matplotlib.pyplot as plt
    fig = plt.figure(figsize=(10, 6))
    plt.text(0.5, 0.5, f"Error: {str(error)}",
             ha='center', va='center', fontsize=12, color='red')
    plt.axis('off')
    return fig


def _compute_accuracy_metrics(X, reconstructed_fc, max_subjects=5):
    """Score reconstructions for up to ``max_subjects`` subjects and add an "Average" entry.

    Returns an empty dict when the inputs are missing or no subject could be scored.
    """
    if X is None or reconstructed_fc is None:
        return {}

    metric_names = ("MSE", "RMSE", "R²", "Correlation", "Cosine Similarity")
    accuracy_metrics = {}
    # Bound by both sequences so a shorter reconstruction cannot raise IndexError.
    for i in range(min(max_subjects, len(X), len(reconstructed_fc))):
        try:
            original = X[i]
            recon = reconstructed_fc[i]
            # Vectorized FC (upper-triangle style) is expanded to full matrices first.
            if np.ndim(original) == 1:
                from visualization import vector_to_matrix
                print(f"Converting subject {i+1} vectors to matrices")
                original = vector_to_matrix(original)
                recon = vector_to_matrix(recon)
            original_flat = np.ravel(original)
            recon_flat = np.ravel(recon)
        except Exception as e:
            print(f"Error converting matrices for subject {i+1}: {e}")
            continue  # Skip this subject
        accuracy_metrics[f"Subject_{i+1}"] = calculate_fc_accuracy(original_flat, recon_flat)

    # Averaging an empty list would yield NaN (plus a warning) and be reported
    # as if it were a real result.
    if not accuracy_metrics:
        return {}
    accuracy_metrics["Average"] = {
        metric: np.mean([subject[metric] for subject in accuracy_metrics.values()])
        for metric in metric_names
    }
    return accuracy_metrics


def _format_status(latent_dim, accuracy_metrics, latents_path):
    """Compose the status text shown in the UI.

    The "Reconstruction Accuracy Metrics (Average):" section is parsed by the
    accuracy display, so its layout must stay stable.
    """
    if accuracy_metrics:
        avg = accuracy_metrics["Average"]
        status = (f"Analysis complete! Model trained with {latent_dim} dimensions.\n\n"
                  f"Reconstruction Accuracy Metrics (Average):\n"
                  f"• MSE: {avg['MSE']:.6f}\n"
                  f"• RMSE: {avg['RMSE']:.6f}\n"
                  f"• R²: {avg['R²']:.6f}\n"
                  f"• Correlation: {avg['Correlation']:.6f}\n"
                  f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}")
    else:
        status = "Analysis complete! VAE model has been trained and demographic relationships analyzed."
    # Only claim a save location when the file was actually written.
    if latents_path:
        status += f"\n\nLatent representations saved to {latents_path}"
    return status
def run_demo():
    """Build the Gradio Blocks interface for the aphasia FC/VAE analysis.

    Returns:
        The ``gr.Blocks`` interface; call ``.launch()`` on it to serve the app.
    """
    # gr.Box was removed in Gradio 4 in favour of gr.Group. Keep Box where it
    # exists (the original choice) and fall back to Group otherwise.
    config_container = getattr(gr, "Box", None) or gr.Group

    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as interface:
        with gr.Row():
            with gr.Column(scale=1):
                # Configuration inputs
                with config_container():
                    gr.Markdown("### Configuration")
                    data_source = gr.Textbox(value="SreekarB/OSFData", label="Data Source (HuggingFace dataset or directory)")
                    use_hf_checkbox = gr.Checkbox(value=True, label="Use HuggingFace Dataset API")
                    latent_dim = gr.Slider(minimum=4, maximum=64, value=16, step=4, label="Latent Dimension")
                    nepochs = gr.Slider(minimum=10, maximum=200, value=50, step=10, label="Training Epochs")
                    bsize = gr.Slider(minimum=2, maximum=16, value=4, step=1, label="Batch Size")
                    # Note about training time
                    gr.Markdown("**Note:** Lower values for latent dimension, epochs, and batch size will train faster.")
                    train_button = gr.Button("Train VAE & Analyze FC", variant="primary")
                # Status output area
                status_text = gr.Textbox(label="Status", lines=10, interactive=False)
            with gr.Column(scale=2):
                # Output plot
                output_plot = gr.Plot(label="FC Matrix Analysis")
                accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")

        # Link the training button to the analysis function
        train_button.click(
            fn=gradio_fc_analysis,
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
            outputs=[output_plot, status_text]
        )

        def update_accuracy_display(status_text):
            """Extract the averaged-metrics section of the status text as Markdown."""
            # The status value can be None/empty before any run; guard the `in` test.
            if status_text and "Accuracy Metrics" in status_text:
                parts = status_text.split("Reconstruction Accuracy Metrics (Average):")
                if len(parts) > 1:
                    # The metrics section ends at the first blank line.
                    metrics_text = parts[1].split("\n\n")[0]
                    return f"### Reconstruction Accuracy Metrics\n{metrics_text}"
            return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."

        # Update the accuracy box whenever the status text changes
        status_text.change(
            fn=update_accuracy_display,
            inputs=[status_text],
            outputs=[accuracy_box]
        )

        # Example configurations
        gr.Examples(
            examples=[
                ["SreekarB/OSFData", 32, 100, 8, True],
                ["SreekarB/OSFData", 16, 200, 16, True],
                ["SreekarB/OSFData", 64, 50, 4, True],
            ],
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_checkbox],
        )
    return interface
if __name__ == "__main__":
    # Script entry point: build the Blocks interface and serve it with
    # Gradio's default launch() settings.
    demo = run_demo()
    demo.launch()