# VAE / app.py
# SreekarB's picture
# Upload 3 files
# 0d33b30 verified
import os
import sys
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle
import pandas as pd
import time
import warnings
# Silence every warning category for the whole process.
warnings.filterwarnings('ignore')

# Make this file's directory and its pip/src sub-directory importable.
_APP_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([_APP_DIR, os.path.join(_APP_DIR, 'pip', 'src')])

# The presence of the SPACE_ID environment variable is used as the
# "running in Hugging Face Spaces" signal.
IS_SPACE = 'SPACE_ID' in os.environ
_runtime_label = 'Hugging Face Spaces' if IS_SPACE else 'local environment'
print(f"Running in {_runtime_label}")
# Import VAE model and functions
# Several of these names (load_and_process_data, train_demovae_model,
# predict_aphasia_recovery, generate_custom_fc, ATLAS_REGIONS, download_model,
# get_connectivity_visualization) are used further down in this file.
try:
    from osf_demovae_adapter import (
        VAE, load_and_process_data, train_demovae_model, predict_aphasia_recovery,
        generate_custom_fc, to_torch, to_cuda, to_numpy, vec2mat, mat2vec, ATLAS_REGIONS,
        download_model, get_connectivity_visualization
    )
    print("Successfully imported osf_demovae_adapter modules")
except ImportError as e:
    # The failure is only reported; execution continues. Any name that did not
    # get bound will surface later as a NameError at its point of use
    # (generate_fc_visualization catches NameError around
    # get_connectivity_visualization for exactly that reason).
    print(f"Error importing osf_demovae_adapter modules: {e} - make sure path is correct")
# --- Model configuration ---
# In Hugging Face Spaces models are stored under /tmp; locally they live next
# to this file.
if IS_SPACE:
    MODEL_DIR = "/tmp/osf_models"
else:
    MODEL_DIR = os.path.dirname(os.path.abspath(__file__))
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, 'osf_demovae_model.pt')

LATENT_DIM, INPUT_DIM, DEMO_DIM = 30, 1000, 5

# CUDA is only used when it is available AND we are not in Spaces
# (disabled there for stability / limited resources).
USE_CUDA = torch.cuda.is_available() and not IS_SPACE

# --- Model state, filled in by load_model() / train_model() ---
model = demovae_model = prediction_model = None
model_loaded = False
# Helper function for aphasia severity interpretation
def get_aphasia_severity_category(wab_score):
    """Map a WAB AQ score onto the severity label shown in the summary.

    Bands (lower bound inclusive): >= 93.8 no aphasia, >= 75 mild,
    >= 50 moderate, >= 25 severe; anything else is very severe.
    """
    # Ordered from the highest cut-off down; the first band reached wins.
    severity_bands = (
        (93.8, "No aphasia (within normal limits)"),
        (75, "Mild aphasia"),
        (50, "Moderate aphasia"),
        (25, "Severe aphasia"),
    )
    for lower_bound, label in severity_bands:
        if wab_score >= lower_bound:
            return label
    return "Very severe aphasia"
def load_model():
    """Load the DemoVAE and prediction models into the module-level globals.

    Lookup order:
      1. The combined pickle ``demovae_and_prediction_models.pkl`` in
         ``MODEL_DIR`` (``download_model('combined')`` is tried if it is absent).
      2. Otherwise the standalone DemoVAE file at ``MODEL_PATH``
         (``download_model('demovae')`` if absent), followed by the optional
         standalone prediction pickle (``download_model('prediction')`` if
         absent).

    A missing prediction model is not fatal: ``prediction_model`` is set to
    ``None`` and loading still counts as successful.

    Returns:
        bool: True when the DemoVAE model was loaded, False otherwise. Every
        exception is caught, reported on stdout and converted into False.
    """
    global model, model_loaded, prediction_model, demovae_model
    try:
        # Try to load both models from the combined pickle file first
        combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')
        if not os.path.exists(combined_model_path):
            try:
                print("Combined model file not found. Attempting to download...")
                combined_model_path = download_model('combined')
            except Exception as e:
                print(f"Could not download combined model: {e}")
                combined_model_path = None

        if combined_model_path and os.path.exists(combined_model_path):
            # NOTE(security): pickle.load executes arbitrary code from the
            # file. Only load model files that come from a trusted source.
            with open(combined_model_path, 'rb') as f:
                models_dict = pickle.load(f)
            demovae_model = models_dict['demovae']
            prediction_model = models_dict['prediction']
            model = demovae_model.vae
            print("DemoVAE and prediction models loaded successfully from", combined_model_path)
        else:
            # Fall back to loading models separately
            print("Combined model file not available. Trying to load or download models separately...")

            # Use a local path variable. Assigning to MODEL_PATH here (without
            # a ``global`` declaration) would make it function-local and turn
            # the os.path.exists() check into an UnboundLocalError.
            demovae_path = MODEL_PATH
            if not os.path.exists(demovae_path):
                try:
                    print("DemoVAE model not found. Attempting to download...")
                    demovae_path = download_model('demovae')
                except Exception as e:
                    print(f"Could not download DemoVAE model: {e}")
                    return False

            # Imported lazily so the app can start without the demovae package
            # when the combined pickle is available.
            from demovae.sklearn import DemoVAE
            demovae_model = DemoVAE(latent_dim=LATENT_DIM, use_cuda=USE_CUDA)
            demovae_model.load(demovae_path)
            model = demovae_model.vae
            print("DemoVAE model loaded successfully from", demovae_path)

            # The prediction model is optional.
            pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
            if not os.path.exists(pred_model_path):
                try:
                    print("Prediction model not found. Attempting to download...")
                    pred_model_path = download_model('prediction')
                except Exception as e:
                    print(f"Could not download prediction model: {e}")
                    print("Warning: Aphasia score prediction will not be available.")
                    prediction_model = None
                    model_loaded = True
                    return True

            # Guard against a falsy path from download_model before testing it.
            if pred_model_path and os.path.exists(pred_model_path):
                # NOTE(security): see the pickle warning above.
                with open(pred_model_path, 'rb') as f:
                    prediction_model = pickle.load(f)
                print("Prediction model loaded successfully from", pred_model_path)
            else:
                print("Warning: Prediction model not found. Aphasia score prediction will not be available.")
                prediction_model = None

        model_loaded = True
        return True
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        model_loaded = False
        return False
def train_model(progress=gr.Progress()):
    """Train the DemoVAE and the score-prediction model, then persist both.

    Steps, each reported through ``progress`` and in the returned log:
      1. ``load_and_process_data(quick_test=False)`` fetches the full dataset.
      2. ``train_demovae_model`` trains the DemoVAE and saves it to ``MODEL_PATH``.
      3. ``predict_aphasia_recovery`` fits the prediction model.
      4. The prediction model and a combined dict of both models are pickled
         into ``MODEL_DIR``.

    Side effects: updates the module-level ``model``, ``demovae_model``,
    ``prediction_model`` and ``model_loaded``; writes two pickle files.

    Args:
        progress: Gradio progress tracker. The ``gr.Progress()`` default is the
            Gradio convention for having one injected into an event handler.

    Returns:
        str: All status messages joined by newlines.
    """
    global model, model_loaded, prediction_model, demovae_model
    status_messages = []

    # Process data from HuggingFace
    progress(0.1, desc="Processing OSF data from HuggingFace...")
    status_messages.append("Step 1: Loading and processing OSF data from HuggingFace...")
    X_fc, X_demo, y_wab, y_improvement, final_df = load_and_process_data(
        quick_test=False  # Use the full dataset, not just a sample
    )
    status_messages.append(f"✓ Data processed: {len(X_fc)} samples with {X_fc.shape[1]} FC features and {X_demo.shape[1]} demographic features")

    # Train VAE model
    progress(0.3, desc="Training DemoVAE model (first stage)...")
    status_messages.append("\nStep 2: Training DemoVAE model (first stage of pipeline)...")
    status_messages.append("This model will learn latent representations of brain connectivity patterns")
    demovae_model, z_train, z_test, X_fc_test, X_demo_test, y_test = train_demovae_model(
        X_fc, X_demo, y_wab, save_model=True, model_path=MODEL_PATH
    )

    # Update global model
    model = demovae_model.vae
    model_loaded = True
    status_messages.append(f"✓ DemoVAE trained successfully: {demovae_model.latent_dim} latent dimensions")

    # Train Random Forest prediction model for aphasia scores
    progress(0.7, desc="Training Random Forest model (second stage)...")
    status_messages.append("\nStep 3: Training Random Forest model (second stage of pipeline)...")
    status_messages.append("This model will predict aphasia scores from latent brain connectivity patterns and demographics")
    print("\n===== STARTING SECOND STAGE: RANDOM FOREST TRAINING =====")
    print("The first stage (VAE) extracted latent representations of brain connectivity")
    print("Now training Random Forest to predict aphasia scores from these representations")
    # NOTE(review): the predictor is fitted on the arrays the adapter returns
    # as z_test / X_demo_test / y_test - confirm that split is the intended one.
    pred_model, y_pred, rmse_val, r2 = predict_aphasia_recovery(z_test, X_demo_test, y_test)
    status_messages.append("✓ Random Forest trained successfully")
    status_messages.append(f" - Prediction accuracy: RMSE = {rmse_val:.2f}, R² = {r2:.2f}")

    # Save prediction model
    status_messages.append("\nStep 4: Saving trained models...")
    prediction_model = pred_model
    pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
    with open(pred_model_path, 'wb') as f:
        pickle.dump(pred_model, f)
    status_messages.append(f"✓ Saved Random Forest model to {pred_model_path}")

    # Save both models together; load_model() looks for this file first.
    combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')
    with open(combined_model_path, 'wb') as f:
        pickle.dump({
            'demovae': demovae_model,
            'prediction': pred_model,
            'latent_dim': demovae_model.latent_dim
        }, f)
    status_messages.append(f"✓ Saved combined models to {combined_model_path}")

    progress(1.0, desc="Model training complete!")
    status_messages.append("\n✅ MODEL TRAINING COMPLETE!")
    status_messages.append("You can now use the model to predict aphasia scores and visualize functional connectivity")
    return "\n".join(status_messages)
def analyze_fc_regions(matrix, region_names, top_n=5):
    """Rank brain regions by their mean connectivity to all other regions.

    Args:
        matrix: Square FC matrix (anything ``np.asarray`` accepts). Only the
            leading ``len(region_names)`` rows and columns are used.
        region_names: Sequence of labels; ``region_names[i]`` labels row ``i``.
        top_n: Number of regions to report at each end of the ranking.
            Values <= 0 yield two empty lists.

    Returns:
        tuple: ``(top_positive, top_negative)``. Each is a list of
        ``(region_name, mean_connectivity)`` pairs; ``top_positive`` is ordered
        from the highest mean down, ``top_negative`` from the lowest mean up.
        Self-connections (the diagonal) are excluded from the means. With a
        single region there are no other connections, so its mean is NaN.

    Raises:
        IndexError: If ``matrix`` has fewer rows or columns than there are
            region names.
    """
    n_regions = len(region_names)
    # Clamp so that top_n == 0 really means "none": a slice such as
    # ``order[-0:]`` would otherwise return every region.
    top_n = max(top_n, 0)
    if n_regions == 0 or top_n == 0:
        return [], []

    fc = np.asarray(matrix, dtype=float)
    if fc.ndim != 2 or fc.shape[0] < n_regions or fc.shape[1] < n_regions:
        raise IndexError(
            f"matrix of shape {fc.shape} is too small for {n_regions} regions"
        )
    fc = fc[:n_regions, :n_regions]

    if n_regions == 1:
        avg_connectivity = np.full(1, np.nan)
    else:
        # Drop the diagonal, leaving n_regions - 1 off-diagonal values per row.
        off_diagonal = fc[~np.eye(n_regions, dtype=bool)].reshape(n_regions, n_regions - 1)
        avg_connectivity = off_diagonal.mean(axis=1)

    ascending = np.argsort(avg_connectivity)
    pos_indices = ascending[::-1][:top_n]
    neg_indices = ascending[:top_n]
    top_positive = [(region_names[i], avg_connectivity[i]) for i in pos_indices]
    top_negative = [(region_names[i], avg_connectivity[i]) for i in neg_indices]
    return top_positive, top_negative
def _build_fc_connection_table(fc_mat, region_names, limit=100):
    """Return the ``limit`` strongest upper-triangle connections as a DataFrame."""
    columns = ["Region 1", "Region 2", "Connectivity"]
    fc_arr = np.asarray(fc_mat)
    # k=1 skips the diagonal; pairs come out in the same row-major order as a
    # nested ``for i ... for j ... if i < j`` loop.
    rows, cols = np.triu_indices(fc_arr.shape[0], k=1)
    table = pd.DataFrame(
        {
            "Region 1": [region_names[i] for i in rows],
            "Region 2": [region_names[j] for j in cols],
            "Connectivity": [round(float(v), 2) for v in fc_arr[rows, cols]],
        },
        columns=columns,
    )
    if table.empty:
        # Nothing to sort (e.g. a 1x1 matrix); still return the expected columns.
        return table
    # Sort by absolute connectivity value, strongest first.
    table = table.sort_values(by="Connectivity", key=abs, ascending=False)
    return table.head(limit)


def generate_fc_visualization(age, mpo, education, gender, handedness,
                              aphasia_severity, lesion_size,
                              use_custom_score=False, custom_score=None):
    """Generate an FC matrix for the given demographics and summarise it.

    Args:
        age, mpo, education: Numeric demographics passed to ``generate_custom_fc``.
        gender: ``"Male"`` maps to ``'male'``; anything else to ``'female'``.
        handedness: ``"Right"`` maps to ``'right'``; anything else to ``'left'``.
        aphasia_severity: Fallback score, used only when neither a custom
            score nor a model prediction is available.
        lesion_size: Only echoed in the summary text.
        use_custom_score: When true, the prediction model is not passed to
            ``generate_custom_fc`` and ``custom_score`` (if not None) is used.
        custom_score: User-supplied score.

    Returns:
        tuple: ``(image_path, summary_markdown, connections_dataframe)``.
        When no model can be loaded the tuple is ``(None, message, None)``, so
        the number of values always matches the three Gradio outputs wired to
        this handler.
    """
    global model_loaded, model, demovae_model, prediction_model

    # Check if model is loaded
    if not model_loaded:
        if not os.path.exists(MODEL_PATH):
            return None, "Model not found. Please train the model first.", None
        if not load_model():
            return None, "Failed to load model. Please train the model first.", None

    # Convert gender / handedness to the format expected by the model
    gender_val = 1 if gender == "Male" else 0
    handedness_val = 1 if handedness == "Right" else 0
    demo_values = {
        'age': age,
        'mpo': mpo,
        'education': education,
        'gender': 'male' if gender_val else 'female',
        'handedness': 'right' if handedness_val else 'left'
    }

    # Set predicted score to None unless we override it
    predicted_aphasia_score = None
    aphasia_score_source = "default"
    if use_custom_score and custom_score is not None:
        predicted_aphasia_score = custom_score
        aphasia_score_source = "custom"

    # Generate FC matrix using our adapter function
    predictor = prediction_model if not use_custom_score else None
    try:
        # Try the new function signature first (returns 3 values)
        custom_fc_mat, gen_predicted_score, viz_path = generate_custom_fc(
            demo_values,
            demovae_model,
            predictor,
            visualize=True
        )
    except (ValueError, TypeError) as e:
        # Fall back to older function signature (returns 2 values)
        print(f"Warning: Using older generate_custom_fc signature: {e}")
        custom_fc_mat, gen_predicted_score = generate_custom_fc(
            demo_values,
            demovae_model,
            predictor
        )
        viz_path = None

    if not use_custom_score and gen_predicted_score is not None:
        predicted_aphasia_score = gen_predicted_score
        aphasia_score_source = "predicted"
    elif predicted_aphasia_score is None:
        # Fall back to default value if nothing else was set
        predicted_aphasia_score = aphasia_severity

    # NOTE(review): the adapter's return type is not visible here; coerce to a
    # plain float so the ``:.1f`` format specs below also work if it hands back
    # a NumPy scalar or a one-element array.
    predicted_aphasia_score = float(np.ravel(predicted_aphasia_score)[0])

    # Use /tmp for visualization files in Spaces
    viz_dir = "/tmp/fc_visualizations" if IS_SPACE else os.path.dirname(os.path.abspath(__file__))
    os.makedirs(viz_dir, exist_ok=True)
    temp_img_path = os.path.join(viz_dir, f"temp_fc_matrix_{time.strftime('%Y%m%d_%H%M%S')}.png")

    if viz_path and os.path.exists(viz_path):
        # Use the already created visualization
        import shutil
        shutil.copy(viz_path, temp_img_path)
    else:
        try:
            # Use the new visualization function if available
            get_connectivity_visualization(
                custom_fc_mat,
                subject_id=f"Patient: Age {age}, Gender {'M' if gender_val else 'F'}, Aphasia Score: {predicted_aphasia_score:.1f}",
                output_path=temp_img_path
            )
        except (NameError, AttributeError):
            # Fall back to old style visualization
            plt.figure(figsize=(10, 8))
            plt.imshow(custom_fc_mat, cmap='RdBu_r', vmin=-1, vmax=1)
            plt.colorbar(label='Correlation')
            plt.title(f'FC Matrix: Age {age}, Gender {"M" if gender_val else "F"}, Aphasia Score: {predicted_aphasia_score:.1f}')
            plt.savefig(temp_img_path)
            plt.close()

    region_names = ATLAS_REGIONS[:custom_fc_mat.shape[0]]
    top_positive, top_negative = analyze_fc_regions(custom_fc_mat, region_names)

    # Create summary text with the analysis
    severity_category = get_aphasia_severity_category(predicted_aphasia_score)
    source_label = {
        "predicted": "Model Prediction",
        "custom": "Custom Value",
    }.get(aphasia_score_source, "Default")
    summary_lines = [
        f"### Aphasia Score: {predicted_aphasia_score:.1f}/100",
        f"Category: {severity_category}",
        f"Source: {source_label}",
        "### Demographic Information",
        f"- Age: {age} years",
        f"- Months Post Onset: {mpo}",
        f"- Education: {education} years",
        f"- Gender: {gender}",
        f"- Handedness: {handedness}",
        f"- Lesion Size: {lesion_size}%",
        "### Brain Connectivity Analysis",
        "Top connected brain regions:",
    ]
    summary_lines += [f"- {region}: {value:.2f}" for region, value in top_positive]
    summary_lines += ["", "Least connected brain regions:"]
    summary_lines += [f"- {region}: {value:.2f}" for region, value in top_negative]
    summary = "\n".join(summary_lines) + "\n"

    # Limit to top 100 connections for performance
    df = _build_fc_connection_table(custom_fc_mat, region_names, limit=100)
    return temp_img_path, summary, df
# At import time, load a previously saved model when its file is on disk;
# otherwise just tell the user that training is required.
if not os.path.exists(MODEL_PATH):
    print("No model found. Please train the model first.")
else:
    print("Model file found. Loading model...")
    load_model()
# Create Gradio interface
# NOTE(review): the nesting of the layout containers below (which components
# sit inside which Row/Column/Tab) was inferred from the statement order -
# confirm it against the rendered UI.
with gr.Blocks(title="Aphasia Prediction with FC Visualization") as demo:
    gr.Markdown("# Aphasia Prediction with Functional Connectivity Visualization")
    gr.Markdown("This app predicts aphasia scores based on patient demographics and displays functional connectivity patterns in the brain.")
    # Tab 1: demographic inputs on the left, generated outputs on the right.
    with gr.Tab("Predict & Visualize"):
        with gr.Row():
            with gr.Column(scale=1):
                gr.Markdown("### Patient Demographics")
                age = gr.Slider(minimum=20, maximum=90, value=60, step=1, label="Age (years)")
                mpo = gr.Slider(minimum=1, maximum=36, value=6, step=1, label="Months Post Onset")
                education = gr.Slider(minimum=8, maximum=22, value=16, step=1, label="Education (years)")
                gender = gr.Radio(["Male", "Female"], value="Male", label="Gender")
                handedness = gr.Radio(["Right", "Left"], value="Right", label="Handedness")
                gr.Markdown("### Aphasia Information")
                aphasia_severity = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Aphasia Severity (WAB AQ)")
                lesion_size = gr.Slider(minimum=0, maximum=100, value=20, step=1, label="Lesion Size (%)")
                use_custom_score = gr.Checkbox(label="Override with custom score", value=False)
                custom_score = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Custom WAB AQ Score",
                                         visible=False)
                # Make custom score visible only when checkbox is selected
                use_custom_score.change(lambda x: gr.update(visible=x), inputs=[use_custom_score], outputs=[custom_score])
                generate_btn = gr.Button("Generate Functional Connectivity", variant="primary")
            with gr.Column(scale=2):
                with gr.Row():
                    fc_image = gr.Image(label="Functional Connectivity Matrix", show_download_button=True)
                fc_summary = gr.Markdown(label="Analysis Summary")
                fc_data = gr.DataFrame(label="Top FC Connections")
        # Generate FC on button click
        # The handler feeds three outputs: image, summary markdown, data table.
        generate_btn.click(
            generate_fc_visualization,
            inputs=[age, mpo, education, gender, handedness,
                    aphasia_severity, lesion_size,
                    use_custom_score, custom_score],
            outputs=[fc_image, fc_summary, fc_data]
        )
    # Tab 2: (re)training; train_model's returned log is shown in the textbox.
    with gr.Tab("Train Model"):
        gr.Markdown("### Train or Retrain the Model")
        gr.Markdown("""
This tab allows you to train the two-stage model:
1. First stage: DemoVAE model learns brain connectivity patterns
2. Second stage: Random Forest predicts aphasia scores
Note: This will download data from HuggingFace 'SreekarB/OSFData' and use the full dataset for training.
""")
        train_btn = gr.Button("Train Model", variant="primary")
        train_output = gr.Textbox(label="Training Status", lines=20)
        train_btn.click(train_model, inputs=[], outputs=[train_output])
    # Usage notes.
    gr.Markdown("## How to use")
    gr.Markdown("""
1. Set the patient's demographic information and aphasia details
2. Click "Generate Functional Connectivity" to see the visualization and prediction
3. Optionally, override the model's prediction with your own custom score
4. If the model is not trained, go to the "Train Model" tab to train it first
The heatmap shows correlations between brain regions. Red indicates positive correlations (regions that activate together),
white indicates neutral correlations, and blue indicates negative correlations (regions with opposing activation patterns).
""")
if __name__ == "__main__":
    # In Hugging Face Spaces bind to all interfaces and disable the share
    # link; everywhere else launch with Gradio's defaults.
    launch_options = {"server_name": "0.0.0.0", "share": False} if IS_SPACE else {}
    demo.launch(**launch_options)