# AphasiaPred / app.py
# SreekarB's picture
# Upload 13 files
# 37a1b01 verified
import gradio as gr
from main import run_fc_analysis
import os
import numpy as np
from sklearn.metrics import mean_squared_error, r2_score
import json
import pickle
from rcf_prediction import train_predictor_from_latents
def calculate_fc_accuracy(original_fc, reconstructed_fc):
    """
    Calculate accuracy metrics between original and reconstructed FC matrices.

    Every metric is computed on the flattened inputs, so all five compare the
    two arrays element-by-element in the same way. This holds whether the
    inputs are full matrices or already-vectorised connectivity features.

    Args:
        original_fc: Array-like of reference FC values (any shape).
        reconstructed_fc: Array-like with the same shape as ``original_fc``.

    Returns:
        dict mapping "MSE", "RMSE", "R²", "Correlation" and
        "Cosine Similarity" to Python floats. "Correlation" is NaN when
        either input has zero variance. "Cosine Similarity" is NaN when
        either input is all zeros. "R²" follows scikit-learn's finite
        convention for a constant reference: 1.0 if the prediction is exact,
        otherwise 0.0.

    Raises:
        ValueError: if the inputs differ in shape or are empty.
    """
    original = np.asarray(original_fc, dtype=float)
    reconstructed = np.asarray(reconstructed_fc, dtype=float)
    if original.shape != reconstructed.shape:
        raise ValueError(
            f"Shape mismatch: original {original.shape} vs "
            f"reconstructed {reconstructed.shape}"
        )
    y_true = original.ravel()
    y_pred = reconstructed.ravel()
    if y_true.size == 0:
        raise ValueError("Cannot compute accuracy metrics on empty arrays")

    residual = y_true - y_pred

    # Mean Squared Error (lower is better)
    mse = float(np.mean(residual ** 2))
    # Root Mean Squared Error (lower is better)
    rmse = float(np.sqrt(mse))

    # R² (higher is better, 1 is perfect), over the flattened values.
    # sklearn's r2_score on a 2-D input would instead average per-column R²,
    # which is inconsistent with the flattened correlation/cosine below.
    ss_res = float(np.sum(residual ** 2))
    ss_tot = float(np.sum((y_true - y_true.mean()) ** 2))
    if ss_tot == 0.0:
        r2 = 1.0 if ss_res == 0.0 else 0.0
    else:
        r2 = 1.0 - ss_res / ss_tot

    # Pearson correlation (higher is better); undefined for constant vectors.
    if y_true.size > 1 and np.std(y_true) > 0 and np.std(y_pred) > 0:
        corr = float(np.corrcoef(y_true, y_pred)[0, 1])
    else:
        corr = float("nan")

    # Cosine similarity / normalised dot product (higher is better);
    # undefined when either vector has zero norm.
    denom = float(np.linalg.norm(y_true) * np.linalg.norm(y_pred))
    cosine = float(np.dot(y_true, y_pred) / denom) if denom > 0 else float("nan")

    return {
        "MSE": mse,
        "RMSE": rmse,
        "R²": float(r2),
        "Correlation": corr,
        "Cosine Similarity": cosine,
    }
def _to_jsonable(obj):
    """Recursively convert NumPy arrays/scalars (inside dicts, lists, tuples) into plain JSON-serialisable Python objects."""
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, dict):
        return {k: _to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_to_jsonable(v) for v in obj]
    return obj


def save_latents(latents, demographics, subjects=None, file_path='latents.pkl',
                 results_dir='results'):
    """
    Save latent representations and associated demographics to file.

    Two files are written:
      * a pickle at ``<results_dir>/<file_path>`` containing the objects
        exactly as passed in;
      * a JSON copy next to it, with the extension replaced by ``.json``.
        If the name already ends in ``.json``, ``.json`` is appended instead
        so the pickle is never overwritten.

    Args:
        latents: Latent representations (NumPy array or JSON-compatible data).
        demographics: Mapping of demographic name -> values (arrays allowed).
        subjects: Optional subject identifiers (list or array).
        file_path: File name (relative to ``results_dir``) for the pickle.
        results_dir: Output directory; created if missing. Defaults to
            ``'results'``.

    Returns:
        str: path of the written pickle file.
    """
    pickle_path = os.path.join(results_dir, file_path)
    os.makedirs(os.path.dirname(pickle_path) or '.', exist_ok=True)

    # Create a dictionary with latents and demographics
    data = {
        'latents': latents,
        'demographics': demographics
    }
    if subjects is not None:
        data['subjects'] = subjects

    # Save as pickle for easy loading in Python
    with open(pickle_path, 'wb') as f:
        pickle.dump(data, f)

    # Also save as JSON for more universal access. Convert NumPy arrays and
    # scalars (including subject arrays) so json can serialise them.
    json_data = {
        'latents': _to_jsonable(latents),
        'demographics': _to_jsonable(dict(demographics)),
    }
    if subjects is not None:
        json_data['subjects'] = _to_jsonable(subjects)

    json_path = os.path.splitext(pickle_path)[0] + '.json'
    if json_path == pickle_path:
        json_path = pickle_path + '.json'
    # Serialise first so a TypeError cannot leave a half-written JSON file.
    json_text = json.dumps(json_data)
    with open(json_path, 'w') as f:
        f.write(json_text)

    return pickle_path
# Metric keys produced by calculate_fc_accuracy, in display order.
_ACCURACY_METRIC_NAMES = ("MSE", "RMSE", "R²", "Correlation", "Cosine Similarity")


def _summarize_reconstruction(X, reconstructed_fc, max_subjects=5):
    """Compute per-subject accuracy metrics for the first few subjects plus their average ("Average" key)."""
    accuracy_metrics = {}
    if X is None or reconstructed_fc is None:
        return accuracy_metrics
    # Bound by both arrays so a short reconstruction cannot raise IndexError.
    n_subjects = min(max_subjects, len(X), len(reconstructed_fc))
    for i in range(n_subjects):
        accuracy_metrics[f"Subject_{i+1}"] = calculate_fc_accuracy(X[i], reconstructed_fc[i])
    if accuracy_metrics:
        # Built from the subject entries before "Average" is inserted.
        accuracy_metrics["Average"] = {
            name: float(np.mean([subject[name] for subject in accuracy_metrics.values()]))
            for name in _ACCURACY_METRIC_NAMES
        }
    return accuracy_metrics


def _train_wab_predictor(latents, demographics, outcome_measures, min_valid=5):
    """Best-effort training of the WAB-AQ regressor on latents; returns its results or None (errors are printed, not raised)."""
    if not outcome_measures or 'wab_aq' not in outcome_measures:
        return None
    if latents is None:
        print("Skipping WAB-AQ prediction: no latent representations available.")
        return None
    try:
        print("Training WAB-AQ prediction model from latent representations...")
        # dtype=float maps missing (None) scores to NaN so they can be masked out.
        wab_scores = np.asarray(outcome_measures['wab_aq'], dtype=float)
        # NOTE(review): assumes latents are array-like with one row per subject.
        latent_array = np.asarray(latents)
        if latent_array.shape[0] != wab_scores.shape[0]:
            print(f"Skipping WAB-AQ prediction: {latent_array.shape[0]} latent rows "
                  f"but {wab_scores.shape[0]} WAB-AQ scores.")
            return None

        # Filter out any NaN values
        valid = ~np.isnan(wab_scores)
        if np.sum(valid) <= min_valid:  # Only train with sufficient data
            print("Skipping WAB-AQ prediction: not enough valid WAB-AQ scores.")
            return None

        # Keep only per-subject demographic arrays; the boolean mask must match
        # their length exactly or numpy raises IndexError.
        filtered_demographics = {}
        for key, values in (demographics or {}).items():
            if isinstance(values, (list, np.ndarray)) and len(values) == len(valid):
                filtered_demographics[key] = np.asarray(values)[valid]

        # Train the prediction model with cross-validation
        predictor_results = train_predictor_from_latents(
            latent_array[valid],
            wab_scores[valid],
            filtered_demographics,
            cv=5,  # 5-fold cross-validation
            n_estimators=100,  # Number of trees in Random Forest
            prediction_type="regression"
        )
        print("WAB-AQ prediction model training complete!")
        return predictor_results
    except Exception as e:  # best-effort: a failed predictor must not abort the analysis
        print(f"Error training prediction model: {str(e)}")
        return None


def _build_status(latent_dim, accuracy_metrics, predictor_results, latents_path):
    """Assemble the user-facing status text from whichever results are available."""
    sections = []
    if accuracy_metrics:
        avg = accuracy_metrics["Average"]
        sections.append(f"Analysis complete! Model trained with {latent_dim} dimensions.")
        # This header text is parsed by the UI's accuracy panel; keep it stable.
        sections.append(
            "Reconstruction Accuracy Metrics (Average):\n"
            f"• MSE: {avg['MSE']:.6f}\n"
            f"• RMSE: {avg['RMSE']:.6f}\n"
            f"• R²: {avg['R²']:.6f}\n"
            f"• Correlation: {avg['Correlation']:.6f}\n"
            f"• Cosine Similarity: {avg['Cosine Similarity']:.6f}"
        )
    else:
        sections.append(
            "Analysis complete! VAE model has been trained and demographic relationships analyzed."
        )

    # Add prediction model results if available
    cv_results = (predictor_results or {}).get('cv_results') or {}
    mean_metrics = cv_results.get('mean_metrics') or {}
    if 'r2' in mean_metrics:
        sections.append(
            "WAB-AQ Prediction Model Performance:\n"
            f"• R²: {mean_metrics.get('r2', 0):.4f}\n"
            f"• RMSE: {mean_metrics.get('rmse', 0):.4f}"
        )

    if latents_path:
        sections.append(f"Latent representations saved to {latents_path}")
    return "\n\n".join(sections)


def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
    """
    Run the full VAE analysis pipeline and report accuracy metrics.

    Steps: run ``run_fc_analysis``, compute reconstruction metrics for up to
    five subjects, save latents (best effort), optionally train a WAB-AQ
    predictor, and build a status message.

    Args:
        data_source: HuggingFace dataset ID or local directory.
        latent_dim: Latent dimensionality for the VAE.
        nepochs: Number of training epochs.
        bsize: Batch size.
        use_hf_dataset: Whether ``data_source`` is a HuggingFace dataset.

    Returns:
        tuple: (figure returned by ``run_fc_analysis``, status string).
    """
    # Run the original analysis
    fig, results = run_fc_analysis(
        data_dir=data_source,
        demographic_file=None,  # demographics are read from the dataset itself
        latent_dim=latent_dim,
        nepochs=nepochs,
        bsize=bsize,
        save_model=True,
        use_hf_dataset=use_hf_dataset,
        # Requests the intermediate arrays used below
        # (requires main.run_fc_analysis to support this flag).
        return_data=True
    )
    if not results:
        return fig, "Analysis complete, but no results were returned for accuracy calculation."

    X = results.get('X')
    latents = results.get('latents')
    demographics = results.get('demographics')
    reconstructed_fc = results.get('reconstructed_fc')
    outcome_measures = results.get('outcome_measures', None)

    # Calculate accuracy metrics
    accuracy_metrics = _summarize_reconstruction(X, reconstructed_fc)

    # Save latent representations if available (best effort: a failed save
    # should not hide the training results from the user)
    latents_path = None
    if latents is not None and demographics is not None:
        try:
            latents_path = save_latents(latents, demographics,
                                        file_path=f'latents_dim{latent_dim}.pkl')
            print(f"Saved latents to {latents_path}")
        except (OSError, TypeError, ValueError) as e:
            print(f"Error saving latents: {e}")

    # Train a predictor model if we have outcome measures
    predictor_results = _train_wab_predictor(latents, demographics, outcome_measures)

    status = _build_status(latent_dim, accuracy_metrics, predictor_results, latents_path)
    return fig, status
def create_interface():
    """
    Build the Gradio Blocks UI for the VAE functional-connectivity analysis.

    The "Start Training" button runs ``gradio_fc_analysis``. Whenever the
    status box changes, the averaged reconstruction metrics it contains are
    mirrored into a Markdown panel.

    Returns:
        gr.Blocks: the assembled (not yet launched) interface.
    """
    with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
        gr.Markdown("""
        # Aphasia fMRI to FC Analysis using VAE

        This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.

        ## Dataset Information

        By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:

        - ID: Subject identifier
        - wab_aq: Aphasia severity score
        - age: Age of the subject
        - mpo: Months post onset
        - education: Years of education
        - gender: Subject gender
        - handedness: Subject handedness (ignored in the analysis)
        """)
        with gr.Row():
            with gr.Column(scale=1):
                # Configuration parameters
                data_source = gr.Textbox(
                    label="Data Source (HF Dataset ID or Local Directory)",
                    value="SreekarB/OSFData"
                )
                latent_dim = gr.Slider(
                    minimum=8, maximum=64, step=8,
                    label="Latent Dimensions", value=32
                )
                nepochs = gr.Slider(
                    minimum=100, maximum=5000, step=100,
                    label="Number of Epochs", value=200  # Reduced for faster demos
                )
                bsize = gr.Slider(
                    minimum=8, maximum=64, step=8,
                    label="Batch Size", value=16
                )
                use_hf_dataset = gr.Checkbox(
                    label="Use HuggingFace Dataset", value=True
                )
                # Training button
                train_button = gr.Button("Start Training", variant="primary")
                status_text = gr.Textbox(label="Status", value="Ready to start training")
            with gr.Column(scale=2):
                # Output plot
                output_plot = gr.Plot(label="Analysis Results")
                accuracy_box = gr.Markdown("### Accuracy Metrics\nRun analysis to see reconstruction accuracy metrics here")

        # Link the training button to the analysis function
        train_button.click(
            fn=gradio_fc_analysis,
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
            outputs=[output_plot, status_text]
        )

        def update_accuracy_display(status_message):
            """Extract the averaged reconstruction metrics from the status text as a Markdown list."""
            marker = "Reconstruction Accuracy Metrics (Average):"
            if status_message and marker in status_message:
                # The metrics section ends at the next blank line.
                metrics_text = status_message.split(marker, 1)[1].split("\n\n")[0]
                # gr.Markdown treats single newlines as spaces, so the plain
                # "• name: value" lines would collapse onto one line; emit a
                # real Markdown bullet list instead.
                bullet_lines = [
                    "- " + line.strip().lstrip("•").strip()
                    for line in metrics_text.splitlines()
                    if line.strip()
                ]
                if bullet_lines:
                    return "### Reconstruction Accuracy Metrics\n" + "\n".join(bullet_lines)
            return "### Accuracy Metrics\nNo metrics available yet. Run analysis to generate metrics."

        # Update accuracy box when status changes
        status_text.change(
            fn=update_accuracy_display,
            inputs=[status_text],
            outputs=[accuracy_box]
        )

        # Add examples
        gr.Examples(
            examples=[
                ["SreekarB/OSFData", 32, 200, 16, True],  # Fewer epochs for faster demo
            ],
            inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
        )

        # Add explanation of the workflow. The blank line before "Note:" keeps
        # it from being rendered as a continuation of list item 6.
        gr.Markdown("""
        ## How this works

        1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
        2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
        3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
        4. **Predictive Modeling**: The system trains a Random Forest regressor on latent features to predict WAB-AQ scores (aphasia severity)
        5. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
        6. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations

        Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
        """)
    return iface
if __name__ == "__main__":
    # Script entry point: build the UI, then serve it (share=True asks Gradio
    # for a public share link in addition to the local server).
    app = create_interface()
    app.launch(share=True)