| import os
|
| import sys
|
| import gradio as gr
|
| import numpy as np
|
| import matplotlib.pyplot as plt
|
| import torch
|
| import pickle
|
| import pandas as pd
|
| import time
|
| import warnings
|
# Suppress every Python warning for the remainder of the process.
warnings.filterwarnings('ignore')

# Make modules that sit next to this script, and those under ./pip/src,
# importable. They are appended, so they rank after paths already on sys.path.
sys.path.extend([
    os.path.dirname(os.path.abspath(__file__)),
    os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pip', 'src'),
])

# The presence of SPACE_ID in the environment is treated as "running on
# Hugging Face Spaces"; several paths and launch options below depend on it.
IS_SPACE = 'SPACE_ID' in os.environ

print(f"Running in {'Hugging Face Spaces' if IS_SPACE else 'local environment'}")
|
|
|
|
|
# The adapter is optional at import time: if it cannot be imported the app
# still starts, but any handler that uses these names will fail with NameError.
try:
    from osf_demovae_adapter import (
        ATLAS_REGIONS,
        VAE,
        download_model,
        generate_custom_fc,
        get_connectivity_visualization,
        load_and_process_data,
        mat2vec,
        predict_aphasia_recovery,
        to_cuda,
        to_numpy,
        to_torch,
        train_demovae_model,
        vec2mat,
    )
except ImportError as import_error:
    print(f"Error importing osf_demovae_adapter modules: {import_error} - make sure path is correct")
else:
    print("Successfully imported osf_demovae_adapter modules")
|
|
|
|
|
|
|
# Checkpoint location: /tmp/osf_models on Spaces (presumably because the app
# directory is not writable there -- confirm), otherwise beside this script.
MODEL_DIR = "/tmp/osf_models" if IS_SPACE else os.path.dirname(os.path.abspath(__file__))
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, 'osf_demovae_model.pt')

# LATENT_DIM is passed to DemoVAE when loading; INPUT_DIM and DEMO_DIM are not
# read anywhere in this file as shown.
LATENT_DIM = 30
INPUT_DIM = 1000
DEMO_DIM = 5

# GPU only when torch reports CUDA and we are not on Spaces.
USE_CUDA = torch.cuda.is_available() and not IS_SPACE

# Model state shared by the UI handlers; populated by load_model() / train_model().
model = demovae_model = prediction_model = None
model_loaded = False
|
|
|
|
|
def get_aphasia_severity_category(wab_score):
    """Map a WAB Aphasia Quotient (AQ) score to a severity label.

    Each band includes its lower bound:
    ``>= 93.8`` within normal limits, ``[75, 93.8)`` mild, ``[50, 75)`` moderate,
    ``[25, 50)`` severe, and anything lower very severe. A NaN score fails
    every comparison and therefore also maps to "Very severe aphasia".
    """
    bands = (
        (93.8, "No aphasia (within normal limits)"),
        (75, "Mild aphasia"),
        (50, "Moderate aphasia"),
        (25, "Severe aphasia"),
    )
    # Bands are ordered from highest to lowest cutoff; the first match wins.
    for lower_bound, label in bands:
        if wab_score >= lower_bound:
            return label
    return "Very severe aphasia"
|
|
|
def load_model():
    """Load the DemoVAE and aphasia-prediction models into the module globals.

    Lookup order:
      1. ``demovae_and_prediction_models.pkl`` in ``MODEL_DIR``, fetched with
         ``download_model('combined')`` if it is not on disk.
      2. Otherwise the DemoVAE checkpoint at ``MODEL_PATH`` (fetched with
         ``download_model('demovae')`` if absent), plus
         ``aphasia_prediction_model.pkl`` (fetched with
         ``download_model('prediction')`` if absent). A missing prediction
         model is tolerated: ``prediction_model`` is set to None and score
         prediction is unavailable.

    Side effects:
        Sets the globals ``model``, ``demovae_model``, ``prediction_model`` and
        ``model_loaded``.

    Returns:
        bool: True when a DemoVAE model was loaded, False otherwise. Errors are
        printed, not raised.
    """
    global model, model_loaded, prediction_model, demovae_model

    try:
        combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')
        if not os.path.exists(combined_model_path):
            try:
                print("Combined model file not found. Attempting to download...")
                combined_model_path = download_model('combined')
            except Exception as e:
                print(f"Could not download combined model: {e}")
                combined_model_path = None

        if combined_model_path and os.path.exists(combined_model_path):
            # NOTE(review): pickle.load can execute arbitrary code; only load
            # checkpoints written by train_model() or fetched from a trusted source.
            with open(combined_model_path, 'rb') as f:
                models_dict = pickle.load(f)
            demovae_model = models_dict['demovae']
            prediction_model = models_dict['prediction']
            model = demovae_model.vae
            print("DemoVAE and prediction models loaded successfully from", combined_model_path)
        else:
            print("Combined model file not available. Trying to load or download models separately...")

            # Use a local name for the checkpoint path. Assigning to MODEL_PATH
            # anywhere in this function would make it local for the whole body,
            # so the first read would raise UnboundLocalError.
            demovae_path = MODEL_PATH
            if not os.path.exists(demovae_path):
                try:
                    print("DemoVAE model not found. Attempting to download...")
                    demovae_path = download_model('demovae')
                except Exception as e:
                    print(f"Could not download DemoVAE model: {e}")
                    return False

            # Imported here because only this fallback path constructs a DemoVAE directly.
            from demovae.sklearn import DemoVAE
            demovae_model = DemoVAE(latent_dim=LATENT_DIM, use_cuda=USE_CUDA)
            demovae_model.load(demovae_path)
            model = demovae_model.vae
            print("DemoVAE model loaded successfully from", demovae_path)

            pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
            if not os.path.exists(pred_model_path):
                try:
                    print("Prediction model not found. Attempting to download...")
                    pred_model_path = download_model('prediction')
                except Exception as e:
                    print(f"Could not download prediction model: {e}")
                    print("Warning: Aphasia score prediction will not be available.")
                    prediction_model = None
                    model_loaded = True
                    return True

            # Guard against download_model returning an empty/None path.
            if pred_model_path and os.path.exists(pred_model_path):
                # NOTE(review): same pickle trust caveat as above.
                with open(pred_model_path, 'rb') as f:
                    prediction_model = pickle.load(f)
                print("Prediction model loaded successfully from", pred_model_path)
            else:
                print("Warning: Prediction model not found. Aphasia score prediction will not be available.")
                prediction_model = None

        model_loaded = True
        return True
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        model_loaded = False
        return False
|
|
|
def train_model(progress=gr.Progress()):
    """Train the two-stage pipeline and return a status log for the UI.

    Stage 1 trains a DemoVAE on the processed OSF data and saves it to
    ``MODEL_PATH``. Stage 2 fits the aphasia-score predictor with
    ``predict_aphasia_recovery`` on the VAE latents. The predictor is pickled to
    ``aphasia_prediction_model.pkl``, and both models are pickled together to
    ``demovae_and_prediction_models.pkl`` in ``MODEL_DIR``.

    Args:
        progress: Gradio progress tracker (``gr.Progress``). It is called with a
            completion fraction and a ``desc`` label.

    Side effects:
        Updates the globals ``model``, ``demovae_model``, ``prediction_model``
        and ``model_loaded``.

    Returns:
        str: Newline-joined status messages. If any step raises, the log ends
        with the error message instead of the exception propagating, so the
        steps that did complete remain visible.
    """
    global model, model_loaded, prediction_model, demovae_model

    status_messages = []
    try:
        progress(0.1, desc="Processing OSF data from HuggingFace...")
        status_messages.append("Step 1: Loading and processing OSF data from HuggingFace...")
        X_fc, X_demo, y_wab, _y_improvement, _final_df = load_and_process_data(
            quick_test=False
        )
        status_messages.append(f"✓ Data processed: {len(X_fc)} samples with {X_fc.shape[1]} FC features and {X_demo.shape[1]} demographic features")

        progress(0.3, desc="Training DemoVAE model (first stage)...")
        status_messages.append("\nStep 2: Training DemoVAE model (first stage of pipeline)...")
        status_messages.append("This model will learn latent representations of brain connectivity patterns")
        demovae_model, _z_train, z_test, _X_fc_test, X_demo_test, y_test = train_demovae_model(
            X_fc, X_demo, y_wab, save_model=True, model_path=MODEL_PATH
        )
        model = demovae_model.vae
        model_loaded = True
        status_messages.append(f"✓ DemoVAE trained successfully: {demovae_model.latent_dim} latent dimensions")

        progress(0.7, desc="Training Random Forest model (second stage)...")
        status_messages.append("\nStep 3: Training Random Forest model (second stage of pipeline)...")
        status_messages.append("This model will predict aphasia scores from latent brain connectivity patterns and demographics")

        print("\n===== STARTING SECOND STAGE: RANDOM FOREST TRAINING =====")
        print("The first stage (VAE) extracted latent representations of brain connectivity")
        print("Now training Random Forest to predict aphasia scores from these representations")

        # NOTE(review): the predictor is fitted on the VAE's held-out split
        # (z_test / X_demo_test / y_test) and the train-split latents are unused.
        # Confirm this is intended and that predict_aphasia_recovery does its own
        # train/evaluation split before trusting the reported RMSE/R².
        pred_model, _y_pred, rmse_val, r2 = predict_aphasia_recovery(z_test, X_demo_test, y_test)
        status_messages.append("✓ Random Forest trained successfully")
        status_messages.append(f" - Prediction accuracy: RMSE = {rmse_val:.2f}, R² = {r2:.2f}")

        status_messages.append("\nStep 4: Saving trained models...")
        prediction_model = pred_model
        pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
        with open(pred_model_path, 'wb') as f:
            pickle.dump(pred_model, f)
        status_messages.append(f"✓ Saved Random Forest model to {pred_model_path}")

        combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')
        with open(combined_model_path, 'wb') as f:
            pickle.dump({
                'demovae': demovae_model,
                'prediction': pred_model,
                'latent_dim': demovae_model.latent_dim,
            }, f)
        status_messages.append(f"✓ Saved combined models to {combined_model_path}")

        progress(1.0, desc="Model training complete!")
        status_messages.append("\n✅ MODEL TRAINING COMPLETE!")
        status_messages.append("You can now use the model to predict aphasia scores and visualize functional connectivity")
    except Exception as e:
        # UI boundary: log the full traceback to the console and surface the
        # error in the status box rather than discarding the partial log.
        import traceback
        traceback.print_exc()
        status_messages.append(f"\n❌ MODEL TRAINING FAILED: {e}")

    return "\n".join(status_messages)
|
|
|
def analyze_fc_regions(matrix, region_names, top_n=5):
    """Rank brain regions by mean off-diagonal functional connectivity.

    Only the leading ``len(region_names)`` rows and columns of *matrix* are
    used. Each region's score is the mean of its connections to every other
    region; the diagonal is excluded.

    Args:
        matrix: 2-D FC matrix (array-like), at least ``len(region_names)`` on
            each side.
        region_names: Labels for the leading rows/columns of *matrix*.
        top_n: Number of regions to return in each list; ``0`` gives empty lists.

    Returns:
        ``(top_positive, top_negative)``: lists of ``(name, mean_connectivity)``
        pairs. The first holds the highest-scoring regions in descending order,
        the second the lowest-scoring regions in ascending order.
    """
    n_regions = len(region_names)
    if n_regions == 0:
        return [], []

    sub = np.asarray(matrix)[:n_regions, :n_regions]
    off_diagonal = ~np.eye(n_regions, dtype=bool)
    # Boolean-mask indexing flattens row-major, so each row contributes
    # exactly n_regions - 1 off-diagonal values.
    avg_connectivity = sub[off_diagonal].reshape(n_regions, n_regions - 1).mean(axis=1)

    order = np.argsort(avg_connectivity)
    # Slice the reversed order rather than using order[-top_n:], which would
    # select every region when top_n == 0.
    pos_indices = order[::-1][:top_n]
    neg_indices = order[:top_n]

    top_positive = [(region_names[i], avg_connectivity[i]) for i in pos_indices]
    top_negative = [(region_names[i], avg_connectivity[i]) for i in neg_indices]
    return top_positive, top_negative
|
|
|
def _fc_region_names(n_regions):
    """Return n_regions labels: ATLAS_REGIONS first, padded with "Region <k>" if the atlas is shorter."""
    names = list(ATLAS_REGIONS[:n_regions])
    names.extend(f"Region {k + 1}" for k in range(len(names), n_regions))
    return names


def _render_fc_image(custom_fc_mat, viz_path, age, gender_label, score):
    """Write a PNG of custom_fc_mat to a timestamped file and return its path.

    Copies the adapter-rendered image at viz_path when it exists. Otherwise it
    renders with get_connectivity_visualization, and falls back to a plain
    matplotlib heatmap if that raises NameError/AttributeError.
    """
    import shutil

    viz_dir = "/tmp/fc_visualizations" if IS_SPACE else os.path.dirname(os.path.abspath(__file__))
    os.makedirs(viz_dir, exist_ok=True)
    temp_img_path = os.path.join(viz_dir, f"temp_fc_matrix_{time.strftime('%Y%m%d_%H%M%S')}.png")

    if viz_path and os.path.exists(viz_path):
        shutil.copy(viz_path, temp_img_path)
        return temp_img_path

    try:
        get_connectivity_visualization(
            custom_fc_mat,
            subject_id=f"Patient: Age {age}, Gender {gender_label}, Aphasia Score: {score:.1f}",
            output_path=temp_img_path,
        )
    except (NameError, AttributeError):
        fig, ax = plt.subplots(figsize=(10, 8))
        try:
            image = ax.imshow(custom_fc_mat, cmap='RdBu_r', vmin=-1, vmax=1)
            fig.colorbar(image, ax=ax, label='Correlation')
            ax.set_title(f'FC Matrix: Age {age}, Gender {gender_label}, Aphasia Score: {score:.1f}')
            fig.savefig(temp_img_path)
        finally:
            # Release the figure even if saving fails, so figures do not accumulate.
            plt.close(fig)
    return temp_img_path


def _top_fc_connections(custom_fc_mat, region_names, limit=100):
    """Return the `limit` strongest upper-triangle connections, ordered by |value|, as a DataFrame."""
    fc = np.asarray(custom_fc_mat)
    # triu_indices(k=1) yields every (i, j) pair with i < j in row-major order.
    rows, cols = np.triu_indices(fc.shape[0], k=1)
    df = pd.DataFrame({
        "Region 1": [region_names[i] for i in rows],
        "Region 2": [region_names[j] for j in cols],
        "Connectivity": [round(float(v), 2) for v in fc[rows, cols]],
    })
    return df.sort_values(by="Connectivity", key=abs, ascending=False).head(limit)


def generate_fc_visualization(age, mpo, education, gender, handedness,
                              aphasia_severity, lesion_size,
                              use_custom_score=False, custom_score=None):
    """Generate an FC matrix for the given demographics and summarize it.

    This is the handler for the "Generate Functional Connectivity" button, whose
    outputs are ``[fc_image, fc_summary, fc_data]``. Every return path
    therefore yields a 3-tuple.

    Args:
        age, mpo, education: Numeric demographics passed to generate_custom_fc.
        gender: "Male" or anything else (treated as female).
        handedness: "Right" or anything else (treated as left).
        aphasia_severity: Fallback WAB AQ, shown when there is neither a custom
            score nor a model prediction.
        lesion_size: Only echoed in the summary.
        use_custom_score: When true, skip the model prediction and use custom_score.
        custom_score: User-supplied WAB AQ (used only with use_custom_score).

    Returns:
        ``(image_path, summary_markdown, top_connections_df)``, or
        ``(None, error_message, None)`` if no model could be loaded.
    """
    global model_loaded, model, demovae_model, prediction_model

    # load_model() itself tries the combined checkpoint and download_model(),
    # so attempt it even when MODEL_PATH is not on disk.
    if not model_loaded and not load_model():
        return None, "Failed to load model. Please train the model first.", None

    gender_label = 'M' if gender == "Male" else 'F'
    demo_values = {
        'age': age,
        'mpo': mpo,
        'education': education,
        'gender': 'male' if gender == "Male" else 'female',
        'handedness': 'right' if handedness == "Right" else 'left',
    }

    predicted_aphasia_score = None
    aphasia_score_source = "default"
    if use_custom_score and custom_score is not None:
        predicted_aphasia_score = custom_score
        aphasia_score_source = "custom"

    # Do not ask the adapter for a prediction while the user is overriding it.
    predictor = None if use_custom_score else prediction_model
    try:
        custom_fc_mat, gen_predicted_score, viz_path = generate_custom_fc(
            demo_values, demovae_model, predictor, visualize=True
        )
    except (ValueError, TypeError) as e:
        # Fallback for an older generate_custom_fc that returns only
        # (fc, score) and takes no ``visualize`` argument.
        print(f"Warning: Using older generate_custom_fc signature: {e}")
        custom_fc_mat, gen_predicted_score = generate_custom_fc(demo_values, demovae_model, predictor)
        viz_path = None

    if not use_custom_score and gen_predicted_score is not None:
        predicted_aphasia_score = gen_predicted_score
        aphasia_score_source = "predicted"
    elif predicted_aphasia_score is None:
        predicted_aphasia_score = aphasia_severity

    temp_img_path = _render_fc_image(custom_fc_mat, viz_path, age, gender_label, predicted_aphasia_score)

    region_names = _fc_region_names(custom_fc_mat.shape[0])
    top_positive, top_negative = analyze_fc_regions(custom_fc_mat, region_names)
    severity_category = get_aphasia_severity_category(predicted_aphasia_score)
    source_label = {"predicted": "Model Prediction", "custom": "Custom Value"}.get(aphasia_score_source, "Default")

    summary_lines = [
        f"### Aphasia Score: {predicted_aphasia_score:.1f}/100",
        # Bulleted so they render on separate lines; Markdown merges
        # consecutive plain lines into a single paragraph.
        f"- Category: {severity_category}",
        f"- Source: {source_label}",
        "",
        "### Demographic Information",
        f"- Age: {age} years",
        f"- Months Post Onset: {mpo}",
        f"- Education: {education} years",
        f"- Gender: {gender}",
        f"- Handedness: {handedness}",
        f"- Lesion Size: {lesion_size}%",
        "",
        "### Brain Connectivity Analysis",
        "Top connected brain regions:",
    ]
    summary_lines += [f"- {region}: {value:.2f}" for region, value in top_positive]
    summary_lines += ["", "Least connected brain regions:"]
    summary_lines += [f"- {region}: {value:.2f}" for region, value in top_negative]
    summary = "\n".join(summary_lines) + "\n"

    df = _top_fc_connections(custom_fc_mat, region_names)
    return temp_img_path, summary, df
|
|
|
|
|
# Load a previously trained model at import time if a checkpoint is on disk.
# Either the DemoVAE checkpoint or the combined pickle is enough, because
# load_model() reads the combined file first.
if os.path.exists(MODEL_PATH) or os.path.exists(os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')):
    print("Model file found. Loading model...")
    load_model()
else:
    print("No model found. Please train the model first.")
|
|
|
|
|
# Gradio UI: a prediction/visualization tab, a training tab, and usage notes.
with gr.Blocks(title="Aphasia Prediction with FC Visualization") as demo:
    gr.Markdown("# Aphasia Prediction with Functional Connectivity Visualization")
    gr.Markdown("This app predicts aphasia scores based on patient demographics and displays functional connectivity patterns in the brain.")

    with gr.Tab("Predict & Visualize"):
        with gr.Row():
            # Left column: patient inputs.
            with gr.Column(scale=1):
                gr.Markdown("### Patient Demographics")
                age = gr.Slider(minimum=20, maximum=90, value=60, step=1, label="Age (years)")
                mpo = gr.Slider(minimum=1, maximum=36, value=6, step=1, label="Months Post Onset")
                education = gr.Slider(minimum=8, maximum=22, value=16, step=1, label="Education (years)")
                gender = gr.Radio(["Male", "Female"], value="Male", label="Gender")
                handedness = gr.Radio(["Right", "Left"], value="Right", label="Handedness")

                gr.Markdown("### Aphasia Information")
                # Fallback score, shown when there is neither a custom score nor a model prediction.
                aphasia_severity = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Aphasia Severity (WAB AQ)")
                # Only echoed in the summary; it is not passed to the FC generator.
                lesion_size = gr.Slider(minimum=0, maximum=100, value=20, step=1, label="Lesion Size (%)")

                use_custom_score = gr.Checkbox(label="Override with custom score", value=False)
                custom_score = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Custom WAB AQ Score",
                                         visible=False)

                # Show the custom-score slider only while the override box is ticked.
                use_custom_score.change(lambda x: gr.update(visible=x), inputs=[use_custom_score], outputs=[custom_score])

                generate_btn = gr.Button("Generate Functional Connectivity", variant="primary")

            # Right column: generated outputs.
            with gr.Column(scale=2):
                with gr.Row():
                    fc_image = gr.Image(label="Functional Connectivity Matrix", show_download_button=True)
                    fc_summary = gr.Markdown(label="Analysis Summary")

                fc_data = gr.DataFrame(label="Top FC Connections")

        # Output order must match generate_fc_visualization's
        # (image path, summary markdown, dataframe) return tuple.
        generate_btn.click(
            generate_fc_visualization,
            inputs=[age, mpo, education, gender, handedness,
                    aphasia_severity, lesion_size,
                    use_custom_score, custom_score],
            outputs=[fc_image, fc_summary, fc_data]
        )

    with gr.Tab("Train Model"):
        gr.Markdown("### Train or Retrain the Model")
        gr.Markdown("""
This tab allows you to train the two-stage model:
1. First stage: DemoVAE model learns brain connectivity patterns
2. Second stage: Random Forest predicts aphasia scores

Note: This will download data from HuggingFace 'SreekarB/OSFData' and use the full dataset for training.
""")

        train_btn = gr.Button("Train Model", variant="primary")
        train_output = gr.Textbox(label="Training Status", lines=20)

        # train_model returns a single newline-joined status string.
        train_btn.click(train_model, inputs=[], outputs=[train_output])

    gr.Markdown("## How to use")
    gr.Markdown("""
1. Set the patient's demographic information and aphasia details
2. Click "Generate Functional Connectivity" to see the visualization and prediction
3. Optionally, override the model's prediction with your own custom score
4. If the model is not trained, go to the "Train Model" tab to train it first

The heatmap shows correlations between brain regions. Red indicates positive correlations (regions that activate together),
white indicates neutral correlations, and blue indicates negative correlations (regions with opposing activation patterns).
""")
|
|
|
if __name__ == "__main__":
    # On Spaces, bind to all interfaces (0.0.0.0) with sharing disabled;
    # locally, use Gradio's default launch settings.
    launch_options = dict(server_name="0.0.0.0", share=False) if IS_SPACE else {}
    demo.launch(**launch_options)