Upload 18 files
Browse files- README.md +62 -12
- app.py +451 -0
- osf_demovae_adapter.py +1369 -0
- pip/.DS_Store +0 -0
- pip/.gitignore +1 -0
- pip/.pypirc +6 -0
- pip/LICENSE +21 -0
- pip/README.md +7 -0
- pip/pyproject.toml +23 -0
- pip/requirements.txt +11 -0
- pip/src/.DS_Store +0 -0
- pip/src/demovae/__init__.py +1 -0
- pip/src/demovae/__pycache__/__init__.cpython-311.pyc +0 -0
- pip/src/demovae/__pycache__/model.cpython-311.pyc +0 -0
- pip/src/demovae/__pycache__/sklearn.cpython-311.pyc +0 -0
- pip/src/demovae/model.py +221 -0
- pip/src/demovae/sklearn.py +123 -0
- requirements.txt +12 -0
README.md
CHANGED
|
@@ -1,12 +1,62 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
-
app_file: app.py
|
| 9 |
-
pinned: false
|
| 10 |
-
---
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Aphasia Prediction with FC Visualization
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: indigo
|
| 5 |
+
colorTo: red
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 3.50.0
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Aphasia Prediction with VAE and FC Visualization
|
| 13 |
+
|
| 14 |
+
This application predicts aphasia scores based on patient demographics and visualizes functional connectivity (FC) patterns in the brain.
|
| 15 |
+
|
| 16 |
+
## Features
|
| 17 |
+
|
| 18 |
+
- Predict aphasia severity (WAB AQ score) based on patient demographic data
|
| 19 |
+
- Option to manually set aphasia scores
|
| 20 |
+
- Display functional connectivity heatmap and matrix values
|
| 21 |
+
- Interactive visualization of brain region connectivity
|
| 22 |
+
- Customizable demographic parameters
|
| 23 |
+
|
| 24 |
+
## Usage
|
| 25 |
+
|
| 26 |
+
1. First, if you haven't trained the model, go to the "Train Model" tab and click "Train Model"
|
| 27 |
+
2. When the model is ready, go to the "Predict & Visualize" tab
|
| 28 |
+
3. Adjust the demographic sliders for age, months post onset, education, gender, and handedness
|
| 29 |
+
4. Select an aphasia type and set initial severity and lesion size
|
| 30 |
+
5. Click "Generate Functional Connectivity" to see the predictions and visualization
|
| 31 |
+
6. Optionally override the model's prediction with a custom score
|
| 32 |
+
7. Explore the functional connectivity matrix visualization and detailed values
|
| 33 |
+
|
| 34 |
+
## Technical Details
|
| 35 |
+
|
| 36 |
+
The application uses:
|
| 37 |
+
- A Variational Autoencoder (VAE) from the DemoVAE package for learning latent representations of brain connectivity
|
| 38 |
+
- Random Forest regression to predict aphasia scores from latent features and demographics
|
| 39 |
+
- Gradio web interface for interactive visualization
|
| 40 |
+
- Analysis of key brain connectivity patterns and their relationship to aphasia
|
| 41 |
+
|
| 42 |
+
## Deployment
|
| 43 |
+
|
| 44 |
+
The application can be deployed using:
|
| 45 |
+
|
| 46 |
+
```bash
|
| 47 |
+
# Install requirements
|
| 48 |
+
pip install -r requirements.txt
|
| 49 |
+
|
| 50 |
+
# Run the Gradio app
|
| 51 |
+
python app.py
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Hugging Face Spaces Deployment
|
| 55 |
+
|
| 56 |
+
This app is designed to be deployed on Hugging Face Spaces:
|
| 57 |
+
|
| 58 |
+
1. Create a new Space and select Gradio as the SDK
|
| 59 |
+
2. Upload the files or connect to your GitHub repository
|
| 60 |
+
3. The app will automatically deploy and be available online
|
| 61 |
+
|
| 62 |
+
Note: The initial model training may take some time when you first run the application.
|
app.py
ADDED
|
@@ -0,0 +1,451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Standard imports for the Gradio aphasia-prediction app.
import os
import sys
import gradio as gr
import numpy as np
import matplotlib.pyplot as plt
import torch
import pickle
import pandas as pd
import time
import warnings
warnings.filterwarnings('ignore')  # Suppress warnings

# Add the current directory to Python path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

# Add PIP package to path
# (vendored `demovae` package lives under ./pip/src — see load_model's
# `from demovae.sklearn import DemoVAE` import)
sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pip', 'src'))

# Check if running in Hugging Face Spaces
# SPACE_ID is set by the Spaces runtime; used below to pick writable dirs
# and launch options.
IS_SPACE = os.environ.get('SPACE_ID') is not None
print(f"Running in {'Hugging Face Spaces' if IS_SPACE else 'local environment'}")

# Import VAE model and functions
try:
    from osf_demovae_adapter import (
        VAE, load_and_process_data, train_demovae_model, predict_aphasia_recovery,
        generate_custom_fc, to_torch, to_cuda, to_numpy, vec2mat, mat2vec, ATLAS_REGIONS,
        download_model, get_connectivity_visualization
    )
    print("Successfully imported osf_demovae_adapter modules")
except ImportError as e:
    # NOTE(review): on import failure the app continues with these names
    # undefined; later calls (e.g. generate_custom_fc) would then raise
    # NameError — confirm this best-effort startup is intentional.
    print(f"Error importing osf_demovae_adapter modules: {e} - make sure path is correct")

# Model configuration
# Use /tmp for model storage in Hugging Face Spaces (only /tmp is writable there)
MODEL_DIR = "/tmp/osf_models" if IS_SPACE else os.path.dirname(os.path.abspath(__file__))
os.makedirs(MODEL_DIR, exist_ok=True)
MODEL_PATH = os.path.join(MODEL_DIR, 'osf_demovae_model.pt')
LATENT_DIM = 30   # latent dimensionality passed to DemoVAE in load_model()
INPUT_DIM = 1000  # presumably the FC feature-vector length — not referenced in this file; TODO confirm
DEMO_DIM = 5      # presumably the demographic feature count — not referenced in this file; TODO confirm
# Be cautious with CUDA in Spaces as resources may be limited
USE_CUDA = torch.cuda.is_available() and not IS_SPACE  # Disable CUDA in Spaces for stability

# Initialize model during startup if needed
# Module-level state shared by load_model/train_model/generate_fc_visualization.
model = None             # raw VAE (demovae_model.vae)
demovae_model = None     # DemoVAE wrapper instance
prediction_model = None  # Random Forest aphasia-score predictor
model_loaded = False     # True once load_model/train_model has succeeded
|
| 51 |
+
# Helper function for aphasia severity interpretation
|
| 52 |
+
def get_aphasia_severity_category(wab_score):
|
| 53 |
+
"""Interpret WAB AQ score to determine aphasia severity category"""
|
| 54 |
+
if wab_score >= 93.8:
|
| 55 |
+
return "No aphasia (within normal limits)"
|
| 56 |
+
elif wab_score >= 75:
|
| 57 |
+
return "Mild aphasia"
|
| 58 |
+
elif wab_score >= 50:
|
| 59 |
+
return "Moderate aphasia"
|
| 60 |
+
elif wab_score >= 25:
|
| 61 |
+
return "Severe aphasia"
|
| 62 |
+
else:
|
| 63 |
+
return "Very severe aphasia"
|
| 64 |
+
|
def load_model():
    """Load the VAE and prediction models from disk, downloading if absent.

    Tries the combined pickle (DemoVAE + Random Forest) first, then falls
    back to loading or downloading the two models separately. On success,
    populates the module globals `model`, `demovae_model`,
    `prediction_model`, and `model_loaded`.

    Returns:
        bool: True if at least the DemoVAE model was loaded, else False.
    """
    global model, model_loaded, prediction_model, demovae_model
    try:
        # Try to load both models from the combined pickle file first
        combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')

        # Check if combined model exists, otherwise try to download it
        if not os.path.exists(combined_model_path):
            try:
                print("Combined model file not found. Attempting to download...")
                combined_model_path = download_model('combined')
            except Exception as e:
                print(f"Could not download combined model: {e}")
                combined_model_path = None

        # If we have a combined model file, load it
        if combined_model_path and os.path.exists(combined_model_path):
            with open(combined_model_path, 'rb') as f:
                models_dict = pickle.load(f)
            demovae_model = models_dict['demovae']
            prediction_model = models_dict['prediction']
            model = demovae_model.vae
            print("DemoVAE and prediction models loaded successfully from", combined_model_path)
        else:
            # Fall back to loading models separately
            print("Combined model file not available. Trying to load or download models separately...")

            # BUG FIX: the original rebound MODEL_PATH inside this function
            # without a `global` declaration, which made MODEL_PATH a local
            # variable for the whole function body and raised
            # UnboundLocalError on the os.path.exists() check below.
            # Use a dedicated local variable instead.
            demovae_path = MODEL_PATH

            # Check if DemoVAE model exists, otherwise try to download it
            if not os.path.exists(demovae_path):
                try:
                    print("DemoVAE model not found. Attempting to download...")
                    demovae_path = download_model('demovae')
                except Exception as e:
                    print(f"Could not download DemoVAE model: {e}")
                    return False

            # Create model instance and load DemoVAE model
            from demovae.sklearn import DemoVAE
            demovae_model = DemoVAE(latent_dim=LATENT_DIM, use_cuda=USE_CUDA)
            demovae_model.load(demovae_path)
            model = demovae_model.vae
            print("DemoVAE model loaded successfully from", demovae_path)

            # Check for prediction model
            pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
            if not os.path.exists(pred_model_path):
                try:
                    print("Prediction model not found. Attempting to download...")
                    pred_model_path = download_model('prediction')
                except Exception as e:
                    # The app can still generate FC matrices without the
                    # Random Forest predictor, so treat this as non-fatal.
                    print(f"Could not download prediction model: {e}")
                    print("Warning: Aphasia score prediction will not be available.")
                    prediction_model = None
                    model_loaded = True
                    return True

            # Load prediction model if available
            if os.path.exists(pred_model_path):
                with open(pred_model_path, 'rb') as f:
                    prediction_model = pickle.load(f)
                print("Prediction model loaded successfully from", pred_model_path)
            else:
                print("Warning: Prediction model not found. Aphasia score prediction will not be available.")
                prediction_model = None

        model_loaded = True
        return True
    except Exception as e:
        # Broad catch is deliberate: this is a top-level load boundary and
        # the caller only needs a boolean outcome.
        print(f"Error loading model: {str(e)}")
        model_loaded = False
        return False
| 138 |
+
def train_model(progress=gr.Progress()):
|
| 139 |
+
"""Train the model and update progress"""
|
| 140 |
+
global model, model_loaded, prediction_model, demovae_model
|
| 141 |
+
|
| 142 |
+
status_messages = []
|
| 143 |
+
|
| 144 |
+
# Process data from HuggingFace
|
| 145 |
+
progress(0.1, desc="Processing OSF data from HuggingFace...")
|
| 146 |
+
status_messages.append("Step 1: Loading and processing OSF data from HuggingFace...")
|
| 147 |
+
X_fc, X_demo, y_wab, y_improvement, final_df = load_and_process_data(
|
| 148 |
+
quick_test=False # Use the full dataset, not just a sample
|
| 149 |
+
)
|
| 150 |
+
status_messages.append(f"✓ Data processed: {len(X_fc)} samples with {X_fc.shape[1]} FC features and {X_demo.shape[1]} demographic features")
|
| 151 |
+
|
| 152 |
+
# Train VAE model
|
| 153 |
+
progress(0.3, desc="Training DemoVAE model (first stage)...")
|
| 154 |
+
status_messages.append("\nStep 2: Training DemoVAE model (first stage of pipeline)...")
|
| 155 |
+
status_messages.append("This model will learn latent representations of brain connectivity patterns")
|
| 156 |
+
demovae_model, z_train, z_test, X_fc_test, X_demo_test, y_test = train_demovae_model(
|
| 157 |
+
X_fc, X_demo, y_wab, save_model=True, model_path=MODEL_PATH
|
| 158 |
+
)
|
| 159 |
+
|
| 160 |
+
# Update global model
|
| 161 |
+
model = demovae_model.vae
|
| 162 |
+
model_loaded = True
|
| 163 |
+
status_messages.append(f"✓ DemoVAE trained successfully: {demovae_model.latent_dim} latent dimensions")
|
| 164 |
+
|
| 165 |
+
# Train Random Forest prediction model for aphasia scores
|
| 166 |
+
progress(0.7, desc="Training Random Forest model (second stage)...")
|
| 167 |
+
status_messages.append("\nStep 3: Training Random Forest model (second stage of pipeline)...")
|
| 168 |
+
status_messages.append("This model will predict aphasia scores from latent brain connectivity patterns and demographics")
|
| 169 |
+
|
| 170 |
+
print("\n===== STARTING SECOND STAGE: RANDOM FOREST TRAINING =====")
|
| 171 |
+
print("The first stage (VAE) extracted latent representations of brain connectivity")
|
| 172 |
+
print("Now training Random Forest to predict aphasia scores from these representations")
|
| 173 |
+
|
| 174 |
+
X_combined = np.hstack([z_test, X_demo_test])
|
| 175 |
+
pred_model, y_pred, rmse_val, r2 = predict_aphasia_recovery(z_test, X_demo_test, y_test)
|
| 176 |
+
|
| 177 |
+
status_messages.append(f"✓ Random Forest trained successfully")
|
| 178 |
+
status_messages.append(f" - Prediction accuracy: RMSE = {rmse_val:.2f}, R² = {r2:.2f}")
|
| 179 |
+
|
| 180 |
+
# Save prediction model
|
| 181 |
+
status_messages.append("\nStep 4: Saving trained models...")
|
| 182 |
+
prediction_model = pred_model
|
| 183 |
+
pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
|
| 184 |
+
with open(pred_model_path, 'wb') as f:
|
| 185 |
+
pickle.dump(pred_model, f)
|
| 186 |
+
status_messages.append(f"✓ Saved Random Forest model to {pred_model_path}")
|
| 187 |
+
|
| 188 |
+
# Save the trained models for future reference
|
| 189 |
+
combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')
|
| 190 |
+
with open(combined_model_path, 'wb') as f:
|
| 191 |
+
pickle.dump({
|
| 192 |
+
'demovae': demovae_model,
|
| 193 |
+
'prediction': pred_model,
|
| 194 |
+
'latent_dim': demovae_model.latent_dim
|
| 195 |
+
}, f)
|
| 196 |
+
status_messages.append(f"✓ Saved combined models to {combined_model_path}")
|
| 197 |
+
|
| 198 |
+
progress(1.0, desc="Model training complete!")
|
| 199 |
+
status_messages.append("\n✅ MODEL TRAINING COMPLETE!")
|
| 200 |
+
status_messages.append("You can now use the model to predict aphasia scores and visualize functional connectivity")
|
| 201 |
+
|
| 202 |
+
return "\n".join(status_messages)
|
| 203 |
+
|
def analyze_fc_regions(matrix, region_names, top_n=5):
    """Analyze top connected brain regions from a FC matrix.

    For each region, computes its mean connectivity to every other
    region (self-connections on the diagonal are excluded), then ranks
    the regions.

    Args:
        matrix: Square (n x n) functional-connectivity matrix (array-like).
        region_names: Names for the first n rows/columns; its length
            defines how many regions are analyzed.
        top_n: Number of regions to report at each extreme.

    Returns:
        tuple: (top_positive, top_negative), each a list of
        (region_name, mean_connectivity) pairs — highest means first for
        top_positive, lowest means first for top_negative.
    """
    n_regions = len(region_names)

    # Mean connectivity per region, excluding the diagonal. Vectorized
    # replacement for the original per-row Python loop: mask out the
    # diagonal, regroup each row's n-1 off-diagonal entries, average.
    mat = np.asarray(matrix)[:n_regions, :n_regions]
    off_diagonal = ~np.eye(n_regions, dtype=bool)
    avg_connectivity = mat[off_diagonal].reshape(n_regions, n_regions - 1).mean(axis=1)

    # Get top positive and negative connected regions
    pos_indices = np.argsort(avg_connectivity)[-top_n:][::-1]
    neg_indices = np.argsort(avg_connectivity)[:top_n]

    top_positive = [(region_names[i], avg_connectivity[i]) for i in pos_indices]
    top_negative = [(region_names[i], avg_connectivity[i]) for i in neg_indices]

    return top_positive, top_negative
|
def generate_fc_visualization(age, mpo, education, gender, handedness,
                              aphasia_severity, lesion_size,
                              use_custom_score=False, custom_score=None):
    """Generate FC visualization based on demographics and return results.

    Runs the DemoVAE to synthesize an FC matrix for the given patient
    demographics, optionally predicts (or accepts a custom) WAB AQ score,
    saves a heatmap image, and summarizes the most/least connected regions.

    Returns:
        tuple: (image_path, summary_markdown, connections_dataframe).
        On failure, (None, error_message, None) — always a 3-tuple,
        matching the three Gradio outputs this function is bound to.
    """
    global model_loaded, model, demovae_model, prediction_model

    # Check if model is loaded
    if not model_loaded:
        if os.path.exists(MODEL_PATH):
            # Try to load existing model
            if not load_model():
                # BUG FIX: the original returned a 2-tuple here while this
                # function feeds three Gradio outputs; pad with None so the
                # dataframe output also receives a value.
                return None, "Failed to load model. Please train the model first.", None
        else:
            return None, "Model not found. Please train the model first.", None

    # Convert gender/handedness to the numeric encoding expected by the model
    gender_val = 1 if gender == "Male" else 0
    handedness_val = 1 if handedness == "Right" else 0

    # Prepare demographics for the model
    demo_values = {
        'age': age,
        'mpo': mpo,
        'education': education,
        'gender': 'male' if gender_val else 'female',
        'handedness': 'right' if handedness_val else 'left'
    }

    # Set predicted score to None unless we override it
    predicted_aphasia_score = None
    aphasia_score_source = "default"

    if use_custom_score and custom_score is not None:
        # Use user-provided custom score
        predicted_aphasia_score = custom_score
        aphasia_score_source = "custom"

    # Generate FC matrix using our adapter function
    try:
        # Try the new function signature first (returns 3 values)
        custom_fc_mat, gen_predicted_score, viz_path = generate_custom_fc(
            demo_values,
            demovae_model,
            prediction_model if not use_custom_score else None,
            visualize=True
        )
    except (ValueError, TypeError) as e:
        # Fall back to older function signature (returns 2 values)
        print(f"Warning: Using older generate_custom_fc signature: {e}")
        custom_fc_mat, gen_predicted_score = generate_custom_fc(
            demo_values,
            demovae_model,
            prediction_model if not use_custom_score else None
        )
        viz_path = None

    # If we're using the generated prediction
    if not use_custom_score and gen_predicted_score is not None:
        predicted_aphasia_score = gen_predicted_score
        aphasia_score_source = "predicted"
    elif predicted_aphasia_score is None:
        # Fall back to default value if nothing else was set
        predicted_aphasia_score = aphasia_severity

    # Use /tmp for visualization files in Spaces (only /tmp is writable)
    viz_dir = "/tmp/fc_visualizations" if IS_SPACE else os.path.dirname(os.path.abspath(__file__))
    os.makedirs(viz_dir, exist_ok=True)
    temp_img_path = os.path.join(viz_dir, f"temp_fc_matrix_{time.strftime('%Y%m%d_%H%M%S')}.png")

    if viz_path and os.path.exists(viz_path):
        # Use the visualization already created by generate_custom_fc
        import shutil
        shutil.copy(viz_path, temp_img_path)
    else:
        # Generate FC heatmap with aphasia score in the title
        try:
            # Use the new visualization function if available
            get_connectivity_visualization(
                custom_fc_mat,
                subject_id=f"Patient: Age {age}, Gender {'M' if gender_val else 'F'}, Aphasia Score: {predicted_aphasia_score:.1f}",
                output_path=temp_img_path
            )
        except (NameError, AttributeError):
            # Fall back to old style visualization
            plt.figure(figsize=(10, 8))
            plt.imshow(custom_fc_mat, cmap='coolwarm', vmin=-1, vmax=1)
            plt.colorbar(label='Correlation')
            plt.title(f'FC Matrix: Age {age}, Gender {"M" if gender_val else "F"}, Aphasia Score: {predicted_aphasia_score:.1f}')
            plt.savefig(temp_img_path)
            plt.close()

    # Region labels for as many rows as the generated matrix has
    region_names = ATLAS_REGIONS[:custom_fc_mat.shape[0]]

    # Analyze FC regions (most/least connected)
    top_positive, top_negative = analyze_fc_regions(custom_fc_mat, region_names)

    # Create summary text with the analysis
    severity_category = get_aphasia_severity_category(predicted_aphasia_score)

    summary = f"""### Aphasia Score: {predicted_aphasia_score:.1f}/100
Category: {severity_category}
Source: {"Model Prediction" if aphasia_score_source == "predicted" else "Custom Value" if aphasia_score_source == "custom" else "Default"}

### Demographic Information
- Age: {age} years
- Months Post Onset: {mpo}
- Education: {education} years
- Gender: {gender}
- Handedness: {handedness}
- Lesion Size: {lesion_size}%

### Brain Connectivity Analysis
Top connected brain regions:
"""

    for region, value in top_positive:
        summary += f"- {region}: {value:.2f}\n"

    summary += "\nLeast connected brain regions:\n"

    for region, value in top_negative:
        summary += f"- {region}: {value:.2f}\n"

    # Create dataframe of pairwise FC values (upper triangle only, to
    # avoid listing each symmetric pair twice)
    df_data = []
    for i in range(custom_fc_mat.shape[0]):
        for j in range(custom_fc_mat.shape[0]):
            if i < j:
                df_data.append({
                    "Region 1": region_names[i],
                    "Region 2": region_names[j],
                    "Connectivity": round(float(custom_fc_mat[i, j]), 2)
                })

    # Sort by absolute connectivity value
    df = pd.DataFrame(df_data)
    df = df.sort_values(by="Connectivity", key=abs, ascending=False)

    # Limit to top 100 connections for performance
    df = df.head(100)

    return temp_img_path, summary, df
|
# Check if model exists and try to load it at startup so the "Predict"
# tab works immediately when a pretrained model is present.
if os.path.exists(MODEL_PATH):
    print("Model file found. Loading model...")
    load_model()
else:
    print("No model found. Please train the model first.")

# Create Gradio interface: two tabs — prediction/visualization and training.
with gr.Blocks(title="Aphasia Prediction with FC Visualization") as demo:
    gr.Markdown("# Aphasia Prediction with Functional Connectivity Visualization")
    gr.Markdown("This app predicts aphasia scores based on patient demographics and displays functional connectivity patterns in the brain.")

    with gr.Tab("Predict & Visualize"):
        with gr.Row():
            # Left column: all patient inputs
            with gr.Column(scale=1):
                gr.Markdown("### Patient Demographics")
                age = gr.Slider(minimum=20, maximum=90, value=60, step=1, label="Age (years)")
                mpo = gr.Slider(minimum=1, maximum=36, value=6, step=1, label="Months Post Onset")
                education = gr.Slider(minimum=8, maximum=22, value=16, step=1, label="Education (years)")
                gender = gr.Radio(["Male", "Female"], value="Male", label="Gender")
                handedness = gr.Radio(["Right", "Left"], value="Right", label="Handedness")

                gr.Markdown("### Aphasia Information")
                aphasia_severity = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Aphasia Severity (WAB AQ)")
                lesion_size = gr.Slider(minimum=0, maximum=100, value=20, step=1, label="Lesion Size (%)")

                use_custom_score = gr.Checkbox(label="Override with custom score", value=False)
                custom_score = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Custom WAB AQ Score",
                                         visible=False)

                # Make custom score visible only when checkbox is selected
                use_custom_score.change(lambda x: gr.update(visible=x), inputs=[use_custom_score], outputs=[custom_score])

                generate_btn = gr.Button("Generate Functional Connectivity", variant="primary")

            # Right column: visualization outputs
            with gr.Column(scale=2):
                with gr.Row():
                    fc_image = gr.Image(label="Functional Connectivity Matrix", show_download_button=True)
                    fc_summary = gr.Markdown(label="Analysis Summary")

                fc_data = gr.DataFrame(label="Top FC Connections")

        # Generate FC on button click — the three outputs here must match
        # the 3-tuple returned by generate_fc_visualization.
        generate_btn.click(
            generate_fc_visualization,
            inputs=[age, mpo, education, gender, handedness,
                    aphasia_severity, lesion_size,
                    use_custom_score, custom_score],
            outputs=[fc_image, fc_summary, fc_data]
        )

    with gr.Tab("Train Model"):
        gr.Markdown("### Train or Retrain the Model")
        gr.Markdown("""
This tab allows you to train the two-stage model:
1. First stage: DemoVAE model learns brain connectivity patterns
2. Second stage: Random Forest predicts aphasia scores

Note: This will download data from HuggingFace 'SreekarB/OSFData' and use the full dataset for training.
""")

        train_btn = gr.Button("Train Model", variant="primary")
        train_output = gr.Textbox(label="Training Status", lines=20)

        train_btn.click(train_model, inputs=[], outputs=[train_output])

    gr.Markdown("## How to use")
    gr.Markdown("""
1. Set the patient's demographic information and aphasia details
2. Click "Generate Functional Connectivity" to see the visualization and prediction
3. Optionally, override the model's prediction with your own custom score
4. If the model is not trained, go to the "Train Model" tab to train it first

The heatmap shows correlations between brain regions. Yellow indicates positive correlations (regions that activate together),
green indicates neutral correlations, and blue indicates negative correlations (regions with opposing activation patterns).
""")

if __name__ == "__main__":
    # Set up the optimal launch configuration for Hugging Face Spaces
    # (bind to all interfaces; Spaces provides its own public URL).
    if IS_SPACE:
        demo.launch(server_name="0.0.0.0", share=False)
    else:
        demo.launch()
osf_demovae_adapter.py
ADDED
|
@@ -0,0 +1,1369 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
OSF DemoVAE Adapter for HuggingFace Spaces
|
| 3 |
+
|
| 4 |
+
This script is optimized for running in HuggingFace Spaces.
|
| 5 |
+
It loads functional connectivity data from SreekarB/OSFData and
|
| 6 |
+
pretrained models from SreekarB/OSFModels.
|
| 7 |
+
"""
|
| 8 |
+
|
| 9 |
+
import os
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import nibabel as nib
|
| 13 |
+
import sys
|
| 14 |
+
import tempfile
|
| 15 |
+
import pickle
|
| 16 |
+
import json
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
# Import HuggingFace libraries
|
| 21 |
+
from datasets import load_dataset
|
| 22 |
+
from huggingface_hub import hf_hub_download, HfApi, list_repo_files
|
| 23 |
+
|
| 24 |
+
# Import PyTorch if available (needed for DemoVAE)
|
| 25 |
+
try:
|
| 26 |
+
import torch
|
| 27 |
+
TORCH_AVAILABLE = True
|
| 28 |
+
except ImportError:
|
| 29 |
+
TORCH_AVAILABLE = False
|
| 30 |
+
print("WARNING: PyTorch not available. Some functionality may be limited.")
|
| 31 |
+
|
| 32 |
+
# Add PIP package to path - in Spaces this will be the correct path
|
| 33 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'pip', 'src'))
|
| 34 |
+
|
| 35 |
+
# Only try to import DemoVAE if PyTorch is available
|
| 36 |
+
if TORCH_AVAILABLE:
|
| 37 |
+
try:
|
| 38 |
+
from demovae.model import to_torch, to_cuda, to_numpy, VAE
|
| 39 |
+
from demovae.sklearn import DemoVAE
|
| 40 |
+
DEMOVAE_AVAILABLE = True
|
| 41 |
+
except ImportError:
|
| 42 |
+
DEMOVAE_AVAILABLE = False
|
| 43 |
+
print("WARNING: DemoVAE package not found. Only data loading will be available.")
|
| 44 |
+
else:
|
| 45 |
+
DEMOVAE_AVAILABLE = False
|
| 46 |
+
print("WARNING: PyTorch not available. DemoVAE functionality disabled.")
|
| 47 |
+
|
| 48 |
+
# Constants for HuggingFace Spaces
|
| 49 |
+
DATA_REPO = "SreekarB/OSFData" # HuggingFace dataset repository
|
| 50 |
+
MODEL_REPO = "SreekarB/OSFModels" # HuggingFace model repository
|
| 51 |
+
IS_SPACE = os.environ.get('SPACE_ID') is not None # Check if running in HF Spaces
|
| 52 |
+
|
| 53 |
+
# Define paths for cached data - in Spaces, use /tmp for temporary storage
|
| 54 |
+
CACHE_DIR = "/tmp/osf_data" if IS_SPACE else "./cache"
|
| 55 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 56 |
+
|
| 57 |
+
# Model URLs and filenames
|
| 58 |
+
MODEL_FILES = {
|
| 59 |
+
"demovae": "osf_demovae_model.pt",
|
| 60 |
+
"prediction": "aphasia_prediction_model.pkl",
|
| 61 |
+
"combined": "demovae_and_prediction_models.pkl"
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
# Atlas region labels for the FC matrix
|
| 65 |
+
# These are the standard AAL (Automated Anatomical Labeling) atlas brain regions
|
| 66 |
+
# We use 45 regions typically used in functional connectivity matrices
|
| 67 |
+
ATLAS_REGIONS = [
|
| 68 |
+
"Frontal_Sup_L", "Frontal_Sup_R", "Frontal_Mid_L", "Frontal_Mid_R",
|
| 69 |
+
"Frontal_Inf_Oper_L", "Frontal_Inf_Oper_R", "Frontal_Inf_Tri_L", "Frontal_Inf_Tri_R",
|
| 70 |
+
"Frontal_Inf_Orb_L", "Frontal_Inf_Orb_R", "Rolandic_Oper_L", "Rolandic_Oper_R",
|
| 71 |
+
"SMA_L", "SMA_R", "Olfactory_L", "Olfactory_R",
|
| 72 |
+
"Frontal_Med_Orb_L", "Frontal_Med_Orb_R", "Rectus_L", "Rectus_R",
|
| 73 |
+
"Insula_L", "Insula_R", "Cingulum_Ant_L", "Cingulum_Ant_R",
|
| 74 |
+
"Cingulum_Mid_L", "Cingulum_Mid_R", "Cingulum_Post_L", "Cingulum_Post_R",
|
| 75 |
+
"Hippocampus_L", "Hippocampus_R", "ParaHippocampal_L", "ParaHippocampal_R",
|
| 76 |
+
"Amygdala_L", "Amygdala_R", "Calcarine_L", "Calcarine_R",
|
| 77 |
+
"Cuneus_L", "Cuneus_R", "Lingual_L", "Lingual_R",
|
| 78 |
+
"Occipital_Sup_L", "Occipital_Sup_R", "Occipital_Mid_L", "Occipital_Mid_R",
|
| 79 |
+
"Occipital_Inf_L"
|
| 80 |
+
]
|
| 81 |
+
|
| 82 |
+
# Ensure we have exactly 45 regions
|
| 83 |
+
if len(ATLAS_REGIONS) != 45:
|
| 84 |
+
print(f"Warning: Expected 45 regions but have {len(ATLAS_REGIONS)}. Using generic labels.")
|
| 85 |
+
ATLAS_REGIONS = [f"Region_{i+1}" for i in range(45)]
|
| 86 |
+
|
| 87 |
+
# Utility functions for functional connectivity
|
| 88 |
+
def mat2vec(fc):
|
| 89 |
+
"""Convert matrix to vector"""
|
| 90 |
+
d = fc.shape[0]
|
| 91 |
+
a, b = np.triu_indices(d, 1)
|
| 92 |
+
return fc[a, b]
|
| 93 |
+
|
| 94 |
+
def vec2mat(fc):
|
| 95 |
+
"""Convert vector to matrix"""
|
| 96 |
+
d = int(round((1+(1+8*fc.size)**0.5)/2))
|
| 97 |
+
a, b = np.triu_indices(d, 1)
|
| 98 |
+
mat = np.zeros((d, d))
|
| 99 |
+
mat[a, b] = fc
|
| 100 |
+
mat += mat.T
|
| 101 |
+
ones = np.arange(d)
|
| 102 |
+
mat[ones, ones] = 1
|
| 103 |
+
return mat
|
| 104 |
+
|
| 105 |
+
def load_nii_from_huggingface(file_name):
|
| 106 |
+
"""
|
| 107 |
+
Load a NII file from HuggingFace and return as a nibabel image object.
|
| 108 |
+
Optimized for HuggingFace Spaces.
|
| 109 |
+
|
| 110 |
+
Args:
|
| 111 |
+
file_name: Name of the NII file (e.g., "P01_rs.nii")
|
| 112 |
+
|
| 113 |
+
Returns:
|
| 114 |
+
Nibabel image object
|
| 115 |
+
"""
|
| 116 |
+
print(f"Loading {file_name} from {DATA_REPO}...")
|
| 117 |
+
|
| 118 |
+
# Check cache first to avoid repeated downloads
|
| 119 |
+
cache_path = os.path.join(CACHE_DIR, file_name)
|
| 120 |
+
if os.path.exists(cache_path):
|
| 121 |
+
try:
|
| 122 |
+
print(f"Loading from cache: {cache_path}")
|
| 123 |
+
nii_img = nib.load(cache_path)
|
| 124 |
+
return nii_img
|
| 125 |
+
except Exception as e:
|
| 126 |
+
print(f"Error loading from cache: {e}")
|
| 127 |
+
# If cache is corrupted, remove it
|
| 128 |
+
os.remove(cache_path)
|
| 129 |
+
|
| 130 |
+
try:
|
| 131 |
+
# First try direct download using huggingface_hub
|
| 132 |
+
try:
|
| 133 |
+
# Use hf_hub_download which works well in Spaces
|
| 134 |
+
tmp_path = hf_hub_download(
|
| 135 |
+
repo_id=DATA_REPO,
|
| 136 |
+
filename=file_name,
|
| 137 |
+
repo_type="dataset",
|
| 138 |
+
cache_dir=CACHE_DIR
|
| 139 |
+
)
|
| 140 |
+
|
| 141 |
+
# Load with nibabel
|
| 142 |
+
nii_img = nib.load(tmp_path)
|
| 143 |
+
|
| 144 |
+
# Make a copy in our cache dir for easier access
|
| 145 |
+
if tmp_path != cache_path:
|
| 146 |
+
import shutil
|
| 147 |
+
shutil.copy2(tmp_path, cache_path)
|
| 148 |
+
|
| 149 |
+
return nii_img
|
| 150 |
+
|
| 151 |
+
except Exception as e:
|
| 152 |
+
print(f"Direct download failed: {e}")
|
| 153 |
+
|
| 154 |
+
# Try dataset loading approach
|
| 155 |
+
try:
|
| 156 |
+
dataset = load_dataset(DATA_REPO, data_files=file_name)
|
| 157 |
+
|
| 158 |
+
if "train" in dataset and len(dataset["train"]) > 0:
|
| 159 |
+
if file_name in dataset["train"][0]:
|
| 160 |
+
nii_data = dataset["train"][0][file_name]
|
| 161 |
+
|
| 162 |
+
# Save to our cache
|
| 163 |
+
with open(cache_path, 'wb') as f:
|
| 164 |
+
f.write(nii_data)
|
| 165 |
+
|
| 166 |
+
# Load with nibabel
|
| 167 |
+
nii_img = nib.load(cache_path)
|
| 168 |
+
return nii_img
|
| 169 |
+
|
| 170 |
+
except Exception as e:
|
| 171 |
+
print(f"Dataset loading failed: {e}")
|
| 172 |
+
|
| 173 |
+
# Check if we're in Spaces and the file might be available in the local repository
|
| 174 |
+
if IS_SPACE:
|
| 175 |
+
local_paths = [
|
| 176 |
+
os.path.join("/app/OSFData", file_name), # Standard location in Spaces
|
| 177 |
+
os.path.join("OSFData", file_name), # Repository root
|
| 178 |
+
os.path.join("data", file_name) # Common data directory
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
for path in local_paths:
|
| 182 |
+
if os.path.exists(path):
|
| 183 |
+
print(f"Loading from local Spaces file: {path}")
|
| 184 |
+
nii_img = nib.load(path)
|
| 185 |
+
return nii_img
|
| 186 |
+
|
| 187 |
+
# All methods failed
|
| 188 |
+
print(f"Could not load {file_name} using any method")
|
| 189 |
+
return None
|
| 190 |
+
|
| 191 |
+
except Exception as e:
|
| 192 |
+
print(f"Error loading {file_name}: {e}")
|
| 193 |
+
return None
|
| 194 |
+
|
| 195 |
+
def nii_to_fc_matrix(nii_img, target_size=45):
|
| 196 |
+
"""
|
| 197 |
+
Process a NIfTI image to extract functional connectivity matrix
|
| 198 |
+
|
| 199 |
+
Args:
|
| 200 |
+
nii_img: Nibabel image object
|
| 201 |
+
target_size: Target size for the FC matrix (default 45x45)
|
| 202 |
+
|
| 203 |
+
Returns:
|
| 204 |
+
Functional connectivity matrix of consistent size
|
| 205 |
+
"""
|
| 206 |
+
try:
|
| 207 |
+
# Get time series data
|
| 208 |
+
print(" • Getting time series data...")
|
| 209 |
+
time_series = nii_img.get_fdata()
|
| 210 |
+
print(f" Time series shape: {time_series.shape}")
|
| 211 |
+
|
| 212 |
+
# Ensure we have 4D data (three spatial dimensions + time)
|
| 213 |
+
if len(time_series.shape) < 4:
|
| 214 |
+
print(f" ✗ Error: Expected 4D data but got {len(time_series.shape)}D data")
|
| 215 |
+
return None
|
| 216 |
+
|
| 217 |
+
# Reshape to (voxels, time)
|
| 218 |
+
print(" • Reshaping time series data...")
|
| 219 |
+
orig_shape = time_series.shape
|
| 220 |
+
time_series = time_series.reshape(orig_shape[0]*orig_shape[1]*orig_shape[2], orig_shape[3])
|
| 221 |
+
print(f" Reshaped to {time_series.shape}")
|
| 222 |
+
|
| 223 |
+
# Remove NaN and infinity values
|
| 224 |
+
print(" • Cleaning data (removing NaN and infinity)...")
|
| 225 |
+
nan_count = np.isnan(time_series).sum()
|
| 226 |
+
inf_count = np.isinf(time_series).sum()
|
| 227 |
+
if nan_count > 0 or inf_count > 0:
|
| 228 |
+
print(f" Found {nan_count} NaN values and {inf_count} infinity values")
|
| 229 |
+
time_series = np.nan_to_num(time_series, nan=0, posinf=0, neginf=0)
|
| 230 |
+
|
| 231 |
+
# Filter out voxels with no signal
|
| 232 |
+
print(" • Filtering voxels with no signal...")
|
| 233 |
+
signal_mask = np.std(time_series, axis=1) > 0
|
| 234 |
+
active_voxels = np.sum(signal_mask)
|
| 235 |
+
print(f" Active voxels: {active_voxels} out of {time_series.shape[0]} ({active_voxels/time_series.shape[0]*100:.1f}%)")
|
| 236 |
+
|
| 237 |
+
if active_voxels < 10:
|
| 238 |
+
print(" ✗ Error: Too few active voxels for reliable correlation")
|
| 239 |
+
return None
|
| 240 |
+
|
| 241 |
+
time_series = time_series[signal_mask]
|
| 242 |
+
|
| 243 |
+
# Ensure we don't have too many voxels - sample if needed for consistency
|
| 244 |
+
if time_series.shape[0] > 10000:
|
| 245 |
+
print(f" • Sampling voxels to reduce computational load...")
|
| 246 |
+
np.random.seed(42) # For reproducibility
|
| 247 |
+
sample_indices = np.random.choice(time_series.shape[0], 10000, replace=False)
|
| 248 |
+
time_series = time_series[sample_indices]
|
| 249 |
+
print(f" Sampled to {time_series.shape[0]} voxels")
|
| 250 |
+
|
| 251 |
+
# Calculate correlation matrix
|
| 252 |
+
print(" • Calculating functional connectivity matrix...")
|
| 253 |
+
fc_matrix = np.corrcoef(time_series.T)
|
| 254 |
+
print(f" Raw FC matrix shape: {fc_matrix.shape}")
|
| 255 |
+
|
| 256 |
+
# Validate the matrix
|
| 257 |
+
if np.isnan(fc_matrix).any():
|
| 258 |
+
print(" ✗ Error: FC matrix contains NaN values")
|
| 259 |
+
return None
|
| 260 |
+
|
| 261 |
+
if np.isinf(fc_matrix).any():
|
| 262 |
+
print(" ✗ Error: FC matrix contains infinity values")
|
| 263 |
+
return None
|
| 264 |
+
|
| 265 |
+
# Ensure consistent matrix size (45x45) for the FC matrix
|
| 266 |
+
if fc_matrix.shape[0] != target_size:
|
| 267 |
+
print(f" • Resizing FC matrix to standard {target_size}x{target_size} size...")
|
| 268 |
+
|
| 269 |
+
if fc_matrix.shape[0] > target_size:
|
| 270 |
+
# Take the first target_size x target_size submatrix
|
| 271 |
+
fc_matrix = fc_matrix[:target_size, :target_size]
|
| 272 |
+
else:
|
| 273 |
+
# Pad with zeros
|
| 274 |
+
padded_matrix = np.zeros((target_size, target_size))
|
| 275 |
+
padded_matrix[:fc_matrix.shape[0], :fc_matrix.shape[1]] = fc_matrix
|
| 276 |
+
|
| 277 |
+
# Make sure the diagonal is 1
|
| 278 |
+
np.fill_diagonal(padded_matrix, 1)
|
| 279 |
+
|
| 280 |
+
fc_matrix = padded_matrix
|
| 281 |
+
|
| 282 |
+
print(f" Final FC matrix shape: {fc_matrix.shape}")
|
| 283 |
+
|
| 284 |
+
print(" ✓ Functional connectivity matrix successfully calculated")
|
| 285 |
+
return fc_matrix
|
| 286 |
+
|
| 287 |
+
except Exception as e:
|
| 288 |
+
print(f" ✗ Error processing NIfTI data: {e}")
|
| 289 |
+
return None
|
| 290 |
+
|
| 291 |
+
def get_all_nii_files():
|
| 292 |
+
"""
|
| 293 |
+
Discover all NII files available in the HuggingFace dataset
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
List of NII filenames
|
| 297 |
+
"""
|
| 298 |
+
print("Discovering NII files in SreekarB/OSFData repository...")
|
| 299 |
+
|
| 300 |
+
# Check cache first
|
| 301 |
+
cache_file = os.path.join(CACHE_DIR, "nii_files_list.json")
|
| 302 |
+
if os.path.exists(cache_file):
|
| 303 |
+
try:
|
| 304 |
+
with open(cache_file, 'r') as f:
|
| 305 |
+
nii_files = json.load(f)
|
| 306 |
+
print(f"Loaded {len(nii_files)} NII files from cache")
|
| 307 |
+
return nii_files
|
| 308 |
+
except Exception as e:
|
| 309 |
+
print(f"Error loading from cache: {e}")
|
| 310 |
+
|
| 311 |
+
try:
|
| 312 |
+
# Try to list repository files using HF API
|
| 313 |
+
try:
|
| 314 |
+
api = HfApi()
|
| 315 |
+
files = api.list_repo_files(DATA_REPO, repo_type="dataset")
|
| 316 |
+
nii_files = [f for f in files if f.endswith('.nii')]
|
| 317 |
+
|
| 318 |
+
if nii_files:
|
| 319 |
+
print(f"Found {len(nii_files)} NII files in the repository")
|
| 320 |
+
# Save to cache
|
| 321 |
+
with open(cache_file, 'w') as f:
|
| 322 |
+
json.dump(nii_files, f)
|
| 323 |
+
return nii_files
|
| 324 |
+
except Exception as e:
|
| 325 |
+
print(f"Could not list files via API: {e}")
|
| 326 |
+
|
| 327 |
+
# Try loading demo data to check what files exist
|
| 328 |
+
print("Trying to load demographic data to find subject IDs...")
|
| 329 |
+
try:
|
| 330 |
+
demo_data = load_dataset(DATA_REPO, data_files="FC_graph_covariate_data.csv")
|
| 331 |
+
if "train" in demo_data:
|
| 332 |
+
df = pd.DataFrame(demo_data["train"])
|
| 333 |
+
if "ID" in df.columns:
|
| 334 |
+
subject_ids = df["ID"].unique()
|
| 335 |
+
print(f"Found {len(subject_ids)} subject IDs in demographic data")
|
| 336 |
+
nii_files = [f"{subject_id}_rs.nii" for subject_id in subject_ids]
|
| 337 |
+
|
| 338 |
+
# Save to cache
|
| 339 |
+
with open(cache_file, 'w') as f:
|
| 340 |
+
json.dump(nii_files, f)
|
| 341 |
+
return nii_files
|
| 342 |
+
except Exception as e:
|
| 343 |
+
print(f"Error loading demographic data: {e}")
|
| 344 |
+
|
| 345 |
+
# Fallback to standard pattern - assuming 30 subjects (P01 to P30)
|
| 346 |
+
print("Using default pattern for 30 subjects (P01-P30)...")
|
| 347 |
+
nii_files = [f"P{i:02d}_rs.nii" for i in range(1, 31)]
|
| 348 |
+
|
| 349 |
+
# Save to cache
|
| 350 |
+
with open(cache_file, 'w') as f:
|
| 351 |
+
json.dump(nii_files, f)
|
| 352 |
+
return nii_files
|
| 353 |
+
|
| 354 |
+
except Exception as e:
|
| 355 |
+
print(f"Error discovering NII files: {e}")
|
| 356 |
+
return [f"P{i:02d}_rs.nii" for i in range(1, 31)]
|
| 357 |
+
|
| 358 |
+
def download_model(model_type):
|
| 359 |
+
"""
|
| 360 |
+
Download a pretrained model from the SreekarB/OSFModels repository
|
| 361 |
+
|
| 362 |
+
Args:
|
| 363 |
+
model_type: Type of model to download ('demovae', 'prediction', or 'combined')
|
| 364 |
+
|
| 365 |
+
Returns:
|
| 366 |
+
Path to the downloaded model file
|
| 367 |
+
"""
|
| 368 |
+
if model_type not in MODEL_FILES:
|
| 369 |
+
raise ValueError(f"Unknown model type: {model_type}. Available types: {list(MODEL_FILES.keys())}")
|
| 370 |
+
|
| 371 |
+
model_file = MODEL_FILES[model_type]
|
| 372 |
+
cache_path = os.path.join(CACHE_DIR, model_file)
|
| 373 |
+
|
| 374 |
+
# Check if the model is already cached
|
| 375 |
+
if os.path.exists(cache_path):
|
| 376 |
+
print(f"Using cached model: {model_file}")
|
| 377 |
+
return cache_path
|
| 378 |
+
|
| 379 |
+
print(f"Downloading {model_type} model from {MODEL_REPO}...")
|
| 380 |
+
try:
|
| 381 |
+
# Download the model using huggingface_hub
|
| 382 |
+
downloaded_path = hf_hub_download(
|
| 383 |
+
repo_id=MODEL_REPO,
|
| 384 |
+
filename=model_file,
|
| 385 |
+
repo_type="model",
|
| 386 |
+
cache_dir=CACHE_DIR
|
| 387 |
+
)
|
| 388 |
+
|
| 389 |
+
print(f"Model downloaded to: {downloaded_path}")
|
| 390 |
+
return downloaded_path
|
| 391 |
+
|
| 392 |
+
except Exception as e:
|
| 393 |
+
print(f"Error downloading model: {e}")
|
| 394 |
+
raise
|
| 395 |
+
|
| 396 |
+
def load_demographic_data():
|
| 397 |
+
"""
|
| 398 |
+
Load demographic data from HuggingFace dataset
|
| 399 |
+
|
| 400 |
+
Returns:
|
| 401 |
+
Pandas DataFrame with demographic data
|
| 402 |
+
"""
|
| 403 |
+
print("Loading demographic data from SreekarB/OSFData...")
|
| 404 |
+
|
| 405 |
+
# Cache path for demographic data
|
| 406 |
+
cache_path = os.path.join(CACHE_DIR, "demographic_data.csv")
|
| 407 |
+
|
| 408 |
+
# Check if data is already cached
|
| 409 |
+
if os.path.exists(cache_path):
|
| 410 |
+
try:
|
| 411 |
+
print(f"Loading demographic data from cache: {cache_path}")
|
| 412 |
+
df = pd.read_csv(cache_path)
|
| 413 |
+
print(f"Loaded data for {len(df)} subjects with columns: {', '.join(df.columns)}")
|
| 414 |
+
return df
|
| 415 |
+
except Exception as e:
|
| 416 |
+
print(f"Error loading cached demographic data: {e}")
|
| 417 |
+
|
| 418 |
+
try:
|
| 419 |
+
# Try to download file using huggingface_hub
|
| 420 |
+
try:
|
| 421 |
+
demo_path = hf_hub_download(
|
| 422 |
+
repo_id=DATA_REPO,
|
| 423 |
+
filename="FC_graph_covariate_data.csv",
|
| 424 |
+
repo_type="dataset",
|
| 425 |
+
cache_dir=CACHE_DIR
|
| 426 |
+
)
|
| 427 |
+
|
| 428 |
+
df = pd.read_csv(demo_path)
|
| 429 |
+
df.to_csv(cache_path, index=False) # Save to our cache
|
| 430 |
+
print(f"Loaded demographic data with {len(df)} subjects")
|
| 431 |
+
return df
|
| 432 |
+
|
| 433 |
+
except Exception as e:
|
| 434 |
+
print(f"Error downloading demographic file: {e}")
|
| 435 |
+
|
| 436 |
+
# Try loading using datasets library
|
| 437 |
+
try:
|
| 438 |
+
demo_dataset = load_dataset(DATA_REPO, data_files="FC_graph_covariate_data.csv")
|
| 439 |
+
if "train" in demo_dataset:
|
| 440 |
+
df = pd.DataFrame(demo_dataset["train"])
|
| 441 |
+
df.to_csv(cache_path, index=False) # Save to our cache
|
| 442 |
+
print(f"Loaded demographic data with {len(df)} subjects")
|
| 443 |
+
return df
|
| 444 |
+
except Exception as e:
|
| 445 |
+
print(f"Error loading demographic data via datasets: {e}")
|
| 446 |
+
|
| 447 |
+
# Check if file exists locally in Spaces
|
| 448 |
+
if IS_SPACE:
|
| 449 |
+
local_paths = [
|
| 450 |
+
os.path.join("/app/OSFData", "FC_graph_covariate_data.csv"),
|
| 451 |
+
os.path.join("OSFData", "FC_graph_covariate_data.csv"),
|
| 452 |
+
os.path.join("data", "FC_graph_covariate_data.csv")
|
| 453 |
+
]
|
| 454 |
+
|
| 455 |
+
for path in local_paths:
|
| 456 |
+
if os.path.exists(path):
|
| 457 |
+
df = pd.read_csv(path)
|
| 458 |
+
df.to_csv(cache_path, index=False) # Save to our cache
|
| 459 |
+
print(f"Loaded demographic data from local file: {path}")
|
| 460 |
+
return df
|
| 461 |
+
|
| 462 |
+
raise FileNotFoundError("Could not load demographic data using any method")
|
| 463 |
+
|
| 464 |
+
except Exception as e:
|
| 465 |
+
print(f"Could not load demographic data: {e}")
|
| 466 |
+
# Create a synthetic dataset for testing
|
| 467 |
+
print("Creating synthetic demographic data for testing")
|
| 468 |
+
data = []
|
| 469 |
+
for i in range(1, 31):
|
| 470 |
+
subject_id = f"P{i:02d}"
|
| 471 |
+
data.append({
|
| 472 |
+
"ID": subject_id,
|
| 473 |
+
"age": np.random.randint(40, 80),
|
| 474 |
+
"gender": "M" if np.random.random() > 0.5 else "F",
|
| 475 |
+
"handedness": "R" if np.random.random() > 0.2 else "L",
|
| 476 |
+
"education": np.random.randint(8, 20),
|
| 477 |
+
"mpo": np.random.randint(1, 24),
|
| 478 |
+
"wab_aq": np.random.uniform(20, 90)
|
| 479 |
+
})
|
| 480 |
+
df = pd.DataFrame(data)
|
| 481 |
+
df.to_csv(cache_path, index=False)
|
| 482 |
+
return df
|
| 483 |
+
|
| 484 |
+
def load_models():
|
| 485 |
+
"""
|
| 486 |
+
Load pretrained models from HuggingFace
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
Dictionary containing loaded models
|
| 490 |
+
"""
|
| 491 |
+
if not DEMOVAE_AVAILABLE:
|
| 492 |
+
print("DemoVAE package not available. Cannot load models.")
|
| 493 |
+
return None
|
| 494 |
+
|
| 495 |
+
print("Loading pretrained models from SreekarB/OSFModels...")
|
| 496 |
+
|
| 497 |
+
try:
|
| 498 |
+
# Try to load the combined model first (contains both VAE and prediction model)
|
| 499 |
+
try:
|
| 500 |
+
combined_path = download_model("combined")
|
| 501 |
+
|
| 502 |
+
with open(combined_path, 'rb') as f:
|
| 503 |
+
models = pickle.load(f)
|
| 504 |
+
|
| 505 |
+
print("Successfully loaded combined models")
|
| 506 |
+
return models
|
| 507 |
+
|
| 508 |
+
except Exception as e:
|
| 509 |
+
print(f"Error loading combined model: {e}")
|
| 510 |
+
print("Trying to load individual models...")
|
| 511 |
+
|
| 512 |
+
# Try loading individual models
|
| 513 |
+
models = {}
|
| 514 |
+
|
| 515 |
+
# Load DemoVAE model
|
| 516 |
+
try:
|
| 517 |
+
vae_path = download_model("demovae")
|
| 518 |
+
|
| 519 |
+
# Load the model - adapting based on file type
|
| 520 |
+
if vae_path.endswith('.pt'):
|
| 521 |
+
# PyTorch model file
|
| 522 |
+
demovae_model = DemoVAE()
|
| 523 |
+
demovae_model.load(vae_path)
|
| 524 |
+
else:
|
| 525 |
+
# Pickle file
|
| 526 |
+
with open(vae_path, 'rb') as f:
|
| 527 |
+
demovae_model = pickle.load(f)
|
| 528 |
+
|
| 529 |
+
models["demovae"] = demovae_model
|
| 530 |
+
print("Successfully loaded DemoVAE model")
|
| 531 |
+
|
| 532 |
+
except Exception as e:
|
| 533 |
+
print(f"Error loading DemoVAE model: {e}")
|
| 534 |
+
|
| 535 |
+
# Load prediction model
|
| 536 |
+
try:
|
| 537 |
+
pred_path = download_model("prediction")
|
| 538 |
+
|
| 539 |
+
with open(pred_path, 'rb') as f:
|
| 540 |
+
pred_model = pickle.load(f)
|
| 541 |
+
|
| 542 |
+
models["prediction"] = pred_model
|
| 543 |
+
print("Successfully loaded prediction model")
|
| 544 |
+
|
| 545 |
+
except Exception as e:
|
| 546 |
+
print(f"Error loading prediction model: {e}")
|
| 547 |
+
|
| 548 |
+
if models:
|
| 549 |
+
return models
|
| 550 |
+
else:
|
| 551 |
+
print("Could not load any models")
|
| 552 |
+
return None
|
| 553 |
+
|
| 554 |
+
except Exception as e:
|
| 555 |
+
print(f"Error loading models: {e}")
|
| 556 |
+
return None
|
| 557 |
+
|
| 558 |
+
def load_and_process_data(quick_test=False):
|
| 559 |
+
"""
|
| 560 |
+
Load and process data from OSF dataset on HuggingFace
|
| 561 |
+
|
| 562 |
+
Args:
|
| 563 |
+
quick_test (bool): If True, use a small subset of data for testing
|
| 564 |
+
|
| 565 |
+
Returns:
|
| 566 |
+
Tuple of (FC features, demographics, WAB scores, improvement scores, dataframe)
|
| 567 |
+
"""
|
| 568 |
+
print("Loading and processing OSF data from HuggingFace...")
|
| 569 |
+
|
| 570 |
+
# Step 1: Load demographic data
|
| 571 |
+
demo_df = load_demographic_data()
|
| 572 |
+
if demo_df is None or len(demo_df) == 0:
|
| 573 |
+
raise ValueError("Could not load demographic data")
|
| 574 |
+
|
| 575 |
+
# Step 2: Get list of NII files to process
|
| 576 |
+
nii_files = get_all_nii_files()
|
| 577 |
+
if not nii_files:
|
| 578 |
+
raise ValueError("No NII files found")
|
| 579 |
+
|
| 580 |
+
# Process appropriate number of subjects
|
| 581 |
+
max_subjects = 5 if quick_test else len(nii_files)
|
| 582 |
+
print(f"Processing {'a subset of' if quick_test else 'all'} subjects: {max_subjects}")
|
| 583 |
+
|
| 584 |
+
# Step 3: Process NII files to FC matrices
|
| 585 |
+
fc_data = {}
|
| 586 |
+
|
| 587 |
+
# Use the demo_df to get subject IDs if available
|
| 588 |
+
if 'ID' in demo_df.columns:
|
| 589 |
+
subject_ids = demo_df['ID'].unique()[:max_subjects]
|
| 590 |
+
|
| 591 |
+
for subject_id in tqdm(subject_ids, desc="Processing subjects"):
|
| 592 |
+
nii_file = f"{subject_id}_rs.nii"
|
| 593 |
+
if nii_file in nii_files:
|
| 594 |
+
fc_matrix, fc_vector = process_subject(subject_id, nii_file)
|
| 595 |
+
if fc_matrix is not None and fc_vector is not None:
|
| 596 |
+
fc_data[subject_id] = {
|
| 597 |
+
"matrix": fc_matrix,
|
| 598 |
+
"vector": fc_vector
|
| 599 |
+
}
|
| 600 |
+
else:
|
| 601 |
+
# Use the first max_subjects NII files
|
| 602 |
+
for i, nii_file in enumerate(nii_files[:max_subjects]):
|
| 603 |
+
subject_id = nii_file.split('_')[0] # Extract subject ID (e.g., P01)
|
| 604 |
+
fc_matrix, fc_vector = process_subject(subject_id, nii_file)
|
| 605 |
+
if fc_matrix is not None and fc_vector is not None:
|
| 606 |
+
fc_data[subject_id] = {
|
| 607 |
+
"matrix": fc_matrix,
|
| 608 |
+
"vector": fc_vector
|
| 609 |
+
}
|
| 610 |
+
|
| 611 |
+
if not fc_data:
|
| 612 |
+
raise ValueError("Failed to process any functional connectivity data")
|
| 613 |
+
|
| 614 |
+
# Step 4: Create matrices for model training
|
| 615 |
+
# Extract demographic features
|
| 616 |
+
processed_subjects = sorted(fc_data.keys())
|
| 617 |
+
subject_indices = [demo_df.index[demo_df['ID'] == subject_id].tolist()[0] for subject_id in processed_subjects if subject_id in demo_df['ID'].values]
|
| 618 |
+
|
| 619 |
+
if not subject_indices:
|
| 620 |
+
# If no matching subjects, create dummy demographics
|
| 621 |
+
X_demo = np.zeros((len(processed_subjects), 5)) # age, mpo, education, gender, handedness
|
| 622 |
+
else:
|
| 623 |
+
# Extract demographic features for subjects we have
|
| 624 |
+
X_demo = np.zeros((len(processed_subjects), 5))
|
| 625 |
+
|
| 626 |
+
for i, subject_id in enumerate(processed_subjects):
|
| 627 |
+
if subject_id in demo_df['ID'].values:
|
| 628 |
+
subject_row = demo_df[demo_df['ID'] == subject_id].iloc[0]
|
| 629 |
+
X_demo[i, 0] = subject_row.get('age', 60)
|
| 630 |
+
X_demo[i, 1] = subject_row.get('mpo', 6)
|
| 631 |
+
X_demo[i, 2] = subject_row.get('education', 12)
|
| 632 |
+
X_demo[i, 3] = 1 if subject_row.get('gender', 'M') == 'M' else 0
|
| 633 |
+
X_demo[i, 4] = 1 if subject_row.get('handedness', 'R') == 'R' else 0
|
| 634 |
+
|
| 635 |
+
# Extract FC features - need to handle potentially different vector sizes
|
| 636 |
+
try:
|
| 637 |
+
# First try direct conversion - works if all vectors have the same length
|
| 638 |
+
X_fc = np.array([fc_data[subject_id]["vector"] for subject_id in processed_subjects])
|
| 639 |
+
except ValueError as e:
|
| 640 |
+
print(f"Warning: FC vectors have inconsistent shapes: {e}")
|
| 641 |
+
|
| 642 |
+
# Get all vector lengths
|
| 643 |
+
vector_lengths = [len(fc_data[subject_id]["vector"]) for subject_id in processed_subjects]
|
| 644 |
+
print(f"Vector lengths: min={min(vector_lengths)}, max={max(vector_lengths)}")
|
| 645 |
+
|
| 646 |
+
# Find most common vector length
|
| 647 |
+
from collections import Counter
|
| 648 |
+
length_counts = Counter(vector_lengths)
|
| 649 |
+
most_common_length = length_counts.most_common(1)[0][0]
|
| 650 |
+
print(f"Most common vector length: {most_common_length} (occurs {length_counts[most_common_length]} times)")
|
| 651 |
+
|
| 652 |
+
# Filter to keep only subjects with the most common vector length
|
| 653 |
+
consistent_subjects = [subject_id for subject_id in processed_subjects
|
| 654 |
+
if len(fc_data[subject_id]["vector"]) == most_common_length]
|
| 655 |
+
|
| 656 |
+
if len(consistent_subjects) < 3:
|
| 657 |
+
# If we don't have enough subjects with consistent vector lengths,
|
| 658 |
+
# pad/truncate vectors to the most common length
|
| 659 |
+
print(f"Too few subjects with consistent vector length. Padding/truncating all vectors.")
|
| 660 |
+
padded_vectors = []
|
| 661 |
+
for subject_id in processed_subjects:
|
| 662 |
+
vec = fc_data[subject_id]["vector"]
|
| 663 |
+
if len(vec) < most_common_length:
|
| 664 |
+
# Pad with zeros
|
| 665 |
+
padded_vec = np.zeros(most_common_length)
|
| 666 |
+
padded_vec[:len(vec)] = vec
|
| 667 |
+
else:
|
| 668 |
+
# Truncate
|
| 669 |
+
padded_vec = vec[:most_common_length]
|
| 670 |
+
padded_vectors.append(padded_vec)
|
| 671 |
+
X_fc = np.array(padded_vectors)
|
| 672 |
+
# Keep all subjects
|
| 673 |
+
else:
|
| 674 |
+
# Use only subjects with consistent vector length
|
| 675 |
+
print(f"Using {len(consistent_subjects)} subjects with consistent vector length {most_common_length}")
|
| 676 |
+
X_fc = np.array([fc_data[subject_id]["vector"] for subject_id in consistent_subjects])
|
| 677 |
+
# Update processed_subjects to only include those with consistent vectors
|
| 678 |
+
processed_subjects = consistent_subjects
|
| 679 |
+
|
| 680 |
+
print(f"Final X_fc shape: {X_fc.shape}")
|
| 681 |
+
|
| 682 |
+
# Get WAB scores if available
|
| 683 |
+
if 'wab_aq' in demo_df.columns:
|
| 684 |
+
y_wab = np.array([demo_df[demo_df['ID'] == subject_id]['wab_aq'].values[0]
|
| 685 |
+
if subject_id in demo_df['ID'].values and not pd.isna(demo_df[demo_df['ID'] == subject_id]['wab_aq'].values[0])
|
| 686 |
+
else np.random.uniform(20, 80)
|
| 687 |
+
for subject_id in processed_subjects])
|
| 688 |
+
else:
|
| 689 |
+
# Generate random WAB scores for testing
|
| 690 |
+
y_wab = np.random.uniform(20, 80, len(processed_subjects))
|
| 691 |
+
|
| 692 |
+
# Generate synthetic improvement data if needed
|
| 693 |
+
improvement_data = []
|
| 694 |
+
for subject_id in processed_subjects:
|
| 695 |
+
improvement = np.random.uniform(0, 40) if 'improvement' not in demo_df.columns else None
|
| 696 |
+
if 'improvement' in demo_df.columns and subject_id in demo_df['ID'].values:
|
| 697 |
+
subj_impr = demo_df[demo_df['ID'] == subject_id]['improvement'].values
|
| 698 |
+
if len(subj_impr) > 0 and not pd.isna(subj_impr[0]):
|
| 699 |
+
improvement = subj_impr[0]
|
| 700 |
+
|
| 701 |
+
improvement_data.append({
|
| 702 |
+
'ID': subject_id,
|
| 703 |
+
'improvement': improvement if improvement is not None else np.random.uniform(0, 40)
|
| 704 |
+
})
|
| 705 |
+
|
| 706 |
+
# Convert to DataFrame
|
| 707 |
+
improvement_df = pd.DataFrame(improvement_data)
|
| 708 |
+
|
| 709 |
+
# Merge with demo_df
|
| 710 |
+
merged_df = pd.merge(demo_df, improvement_df, on='ID', how='left')
|
| 711 |
+
|
| 712 |
+
# Get improvement values
|
| 713 |
+
y_improvement = np.array([improvement_df[improvement_df['ID'] == subject_id]['improvement'].values[0]
|
| 714 |
+
for subject_id in processed_subjects])
|
| 715 |
+
|
| 716 |
+
print(f"Processed data for {len(processed_subjects)} subjects")
|
| 717 |
+
print(f"X_fc shape: {X_fc.shape}")
|
| 718 |
+
print(f"X_demo shape: {X_demo.shape}")
|
| 719 |
+
print(f"y_wab shape: {y_wab.shape}")
|
| 720 |
+
print(f"y_improvement shape: {y_improvement.shape}")
|
| 721 |
+
|
| 722 |
+
return X_fc, X_demo, y_wab, y_improvement, merged_df
|
| 723 |
+
|
| 724 |
+
def plot_connectivity_matrix(fc_matrix, subject_id=None, save_path=None, show_labels=True):
    """
    Plot a functional connectivity matrix with region labels and a custom color scheme.

    Args:
        fc_matrix: Square functional connectivity matrix (correlations in [-1, 1]).
        subject_id: Optional subject identifier appended to the plot title.
        save_path: If given, the figure is also written to this path (300 dpi).
        show_labels: Whether to label the axes with AAL atlas region names.
            Labels are only drawn when the matrix is no larger than ATLAS_REGIONS.

    Returns:
        The matplotlib figure object.
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap
    import numpy as np

    # BUGFIX: LinearSegmentedColormap.from_list maps the FIRST color to vmin
    # (-1) and the LAST color to vmax (+1). The original list was
    # [yellow, green, blue], which rendered NEGATIVE correlations yellow and
    # POSITIVE correlations blue -- the opposite of every annotation this
    # function prints on the figure. The order is reversed here so that
    # blue = negative, green = neutral, yellow = positive, matching the legend.
    colors = [(0, 0, 1),   # Blue for negative correlations (vmin = -1)
              (0, 1, 0),   # Green for neutral/weak correlations
              (1, 1, 0)]   # Yellow for positive correlations (vmax = +1)

    n_bins = 256  # Number of discrete colors
    custom_cmap = LinearSegmentedColormap.from_list("BlueGreenYellow", colors, N=n_bins)

    # (Removed a dead "if 'plt' not in locals()" guard: plt is imported
    # unconditionally above, so the guard could never trigger.)

    # Larger canvas when region labels are drawn, so they remain readable.
    fig_size = 12 if show_labels else 8
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))

    # Plot the connectivity matrix with a fixed correlation range.
    im = ax.imshow(fc_matrix, cmap=custom_cmap, vmin=-1, vmax=1)

    # Title, optionally tagged with the subject identifier.
    title = "Functional Connectivity Matrix"
    if subject_id:
        title += f" - Subject {subject_id}"
    ax.set_title(title, fontsize=14)

    # Region labels (only when we have enough atlas names to cover the matrix).
    if show_labels and fc_matrix.shape[0] <= len(ATLAS_REGIONS):
        labels = ATLAS_REGIONS[:fc_matrix.shape[0]]

        ax.set_xticks(np.arange(len(labels)))
        ax.set_yticks(np.arange(len(labels)))
        ax.set_xticklabels(labels, rotation=90, fontsize=8)
        ax.set_yticklabels(labels, fontsize=8)

        # Minor-tick grid draws cell boundaries between regions.
        ax.set_xticks(np.arange(-.5, len(labels), 1), minor=True)
        ax.set_yticks(np.arange(-.5, len(labels), 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
    else:
        # Fall back to numeric axes when labels are off or the matrix is too big.
        ax.set_xlabel("Brain Region Index")
        ax.set_ylabel("Brain Region Index")

    # Colorbar for the correlation scale.
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Correlation Strength", rotation=270, labelpad=15)

    # Legend text explaining the color scheme.
    fig.text(0.01, 0.01, "Color scheme: Yellow (positive correlation), Green (neutral), Blue (negative correlation)",
             fontsize=8, ha='left')

    # Boxed explanation of what the matrix cells mean.
    explanation = (
        "This matrix shows the functional connectivity between brain regions.\n"
        "Each cell represents the correlation of activity between two regions.\n"
        "Positive values (yellow) indicate regions that activate together.\n"
        "Negative values (blue) indicate regions with opposite activation patterns."
    )

    ax.annotate(explanation, xy=(0.5, -0.15), xycoords='axes fraction',
                ha='center', va='center', fontsize=9,
                bbox=dict(boxstyle='round', fc='lavender', alpha=0.8))

    plt.tight_layout()

    # Optionally persist to disk.
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    return fig
|
| 816 |
+
|
| 817 |
+
|
| 818 |
+
def process_subject(subject_id, nii_file):
    """
    Convert a single subject's NII file into a functional connectivity matrix.

    Results are cached as .npy files under CACHE_DIR, so repeated runs skip
    the expensive NII -> FC conversion.

    Args:
        subject_id: Subject identifier.
        nii_file: NII filename.

    Returns:
        Tuple of (FC matrix, FC vector), or (None, None) if processing fails.
    """
    print(f"\nProcessing {subject_id} from {nii_file}...")

    # Cache locations for this subject's matrix and flattened vector.
    matrix_cache = os.path.join(CACHE_DIR, f"{subject_id}_fc_matrix.npy")
    vector_cache = os.path.join(CACHE_DIR, f"{subject_id}_fc_vector.npy")

    # Fast path: both cached artifacts exist and load cleanly.
    if os.path.exists(matrix_cache) and os.path.exists(vector_cache):
        try:
            print(f"Loading cached FC data for {subject_id}")
            return np.load(matrix_cache), np.load(vector_cache)
        except Exception as e:
            # Corrupt cache entry -- fall through and recompute from the NII file.
            print(f"Error loading cached FC data: {e}")

    try:
        # Step 1: load the NII image.
        nii_img = load_nii_from_huggingface(nii_file)
        if nii_img is None:
            print(f"Failed to load {nii_file}. Skipping {subject_id}.")
            return None, None

        # Step 2: region-by-region correlation matrix.
        fc_matrix = nii_to_fc_matrix(nii_img)
        if fc_matrix is None:
            print(f"Failed to process NII data to FC matrix. Skipping {subject_id}.")
            return None, None

        # Step 3: flatten the upper triangle into a feature vector.
        fc_vector = mat2vec(fc_matrix)

        # Persist both representations for future runs.
        np.save(matrix_cache, fc_matrix)
        np.save(vector_cache, fc_vector)

        return fc_matrix, fc_vector

    except Exception as e:
        print(f"Error processing {subject_id}: {e}")
        return None, None
|
| 871 |
+
|
| 872 |
+
def get_connectivity_visualization(fc_matrix, subject_id=None, output_path=None):
    """
    Render a connectivity matrix to a PNG for use in a web interface.

    Args:
        fc_matrix: Functional connectivity matrix.
        subject_id: Optional subject identifier used in the plot title.
        output_path: If given (truthy), the PNG is written to this path and
            the path is returned; otherwise an in-memory buffer is returned.

    Returns:
        The output path (when output_path is given) or a BytesIO holding
        the PNG bytes, positioned at the start.
    """
    import io
    import matplotlib.pyplot as plt

    # Build the figure with full region labels.
    fig = plot_connectivity_matrix(fc_matrix, subject_id=subject_id, show_labels=True)

    # Disk mode: save and hand back the caller-supplied path.
    if output_path:
        fig.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(fig)
        return output_path

    # Memory mode: serialize the figure into a rewound BytesIO buffer.
    png_buffer = io.BytesIO()
    fig.savefig(png_buffer, format='png', dpi=300, bbox_inches='tight')
    plt.close(fig)
    png_buffer.seek(0)

    return png_buffer
|
| 903 |
+
|
| 904 |
+
def main():
    """
    Main function for processing functional connectivity data from HuggingFace.
    Optimized for HuggingFace Spaces.

    Pipeline: load demographics -> enumerate NII files -> compute FC matrices
    per subject -> cache vectors/IDs -> visualize the first subject -> and,
    when the DemoVAE package is available, run an example WAB AQ prediction.
    """
    print(f"OSF DemoVAE Adapter - Running in {'HuggingFace Spaces' if IS_SPACE else 'local mode'}")
    print("="*50)

    # Step 1: Load demographic data; abort early if unavailable or empty.
    demo_df = load_demographic_data()
    if demo_df is None or len(demo_df) == 0:
        print("Error: Could not load demographic data. Exiting.")
        return

    # Step 2: Get list of NII files to process; abort if none were found.
    nii_files = get_all_nii_files()
    if not nii_files:
        print("Error: No NII files found. Exiting.")
        return

    print(f"Found {len(nii_files)} NII files to process")

    # Step 3: Process NII files to FC matrices, keyed by subject ID.
    fc_data = {}

    # Process each subject - limit to 5 for testing in Spaces to save resources
    max_subjects = 5 if IS_SPACE else len(nii_files)

    # Prefer the subject IDs from the demographics table when available;
    # otherwise derive IDs from the NII filenames.
    if 'ID' in demo_df.columns:
        subject_ids = demo_df['ID'].unique()[:max_subjects]
        print(f"Processing {len(subject_ids)} subjects from demographic data")

        for subject_id in tqdm(subject_ids, desc="Processing subjects"):
            # Expected filename convention: "<ID>_rs.nii" (resting state).
            nii_file = f"{subject_id}_rs.nii"
            if nii_file in nii_files:
                fc_matrix, fc_vector = process_subject(subject_id, nii_file)
                # Only keep subjects whose processing fully succeeded.
                if fc_matrix is not None and fc_vector is not None:
                    fc_data[subject_id] = {
                        "matrix": fc_matrix,
                        "vector": fc_vector
                    }
            else:
                print(f"Warning: No NII file found for subject {subject_id}")
    else:
        # Process the first max_subjects NII files
        for i, nii_file in enumerate(nii_files[:max_subjects]):
            subject_id = nii_file.split('_')[0]  # Extract subject ID (e.g., P01)
            fc_matrix, fc_vector = process_subject(subject_id, nii_file)
            if fc_matrix is not None and fc_vector is not None:
                fc_data[subject_id] = {
                    "matrix": fc_matrix,
                    "vector": fc_vector
                }

    # Step 4: Print summary
    print("\n" + "="*50)
    print(f"Successfully processed {len(fc_data)} subjects")

    if fc_data:
        # Save combined FC vectors for modeling.
        # NOTE(review): np.array over per-subject vectors assumes they all
        # have equal length; subjects with differing lengths would produce
        # an object array here -- confirm upstream guarantees this.
        fc_vectors = np.array([fc_data[subject_id]["vector"] for subject_id in sorted(fc_data.keys())])
        os.makedirs(os.path.join(CACHE_DIR, "processed"), exist_ok=True)
        np.save(os.path.join(CACHE_DIR, "processed", "all_fc_vectors.npy"), fc_vectors)

        # Save subject IDs (one per line, sorted to match fc_vectors order).
        with open(os.path.join(CACHE_DIR, "processed", "subject_ids.txt"), "w") as f:
            for subject_id in sorted(fc_data.keys()):
                f.write(f"{subject_id}\n")

        # Show example statistics for the first processed subject.
        first_subject = list(fc_data.keys())[0]
        fc_matrix = fc_data[first_subject]["matrix"]
        print(f"\nExample FC matrix for {first_subject}:")
        print(f"Shape: {fc_matrix.shape}")
        print(f"Min value: {fc_matrix.min():.4f}")
        print(f"Max value: {fc_matrix.max():.4f}")
        print(f"Mean value: {fc_matrix.mean():.4f}")

        # Create visualization of the first subject's FC matrix.
        try:
            # Get the matrix (re-fetch; same value as above).
            fc_matrix = fc_data[first_subject]["matrix"]

            # Create the plot with the new color scheme and region labels.
            os.makedirs(os.path.join(CACHE_DIR, "visualizations"), exist_ok=True)
            plot_path = os.path.join(CACHE_DIR, "visualizations", f"{first_subject}_fc_matrix.png")

            print(f"\nCreating visualization of functional connectivity for {first_subject}...")
            fig = plot_connectivity_matrix(
                fc_matrix,
                subject_id=first_subject,
                save_path=plot_path,
                show_labels=True
            )

            # Console explanation of how to read the matrix.
            print("\nFunctional Connectivity Matrix Explanation:")
            print("-------------------------------------------")
            print("This matrix represents functional connectivity between 45 brain regions.")
            print("Each cell (i,j) shows the correlation between activity in regions i and j.")
            print("- Yellow cells: Strong positive correlation (regions activate together)")
            print("- Green cells: Neutral or weak correlation")
            print("- Blue cells: Negative correlation (when one region activates, the other deactivates)")
            print("\nThe brain regions are labeled according to the AAL atlas:")

            # Collect notable connections from the upper triangle only
            # (j > i avoids the diagonal and duplicate symmetric pairs).
            strong_pos = []
            strong_neg = []

            for i in range(fc_matrix.shape[0]):
                for j in range(i+1, fc_matrix.shape[0]):
                    val = fc_matrix[i, j]
                    if val > 0.7:
                        strong_pos.append((i, j, val))
                    elif val < -0.4:
                        strong_neg.append((i, j, val))

            # Print up to three strongest positive connections.
            if strong_pos:
                print("\nExamples of strongly positively connected regions:")
                for i, j, val in sorted(strong_pos, key=lambda x: x[2], reverse=True)[:3]:
                    if i < len(ATLAS_REGIONS) and j < len(ATLAS_REGIONS):
                        print(f"  • {ATLAS_REGIONS[i]} and {ATLAS_REGIONS[j]}: {val:.2f}")

            # Print up to three strongest negative connections.
            if strong_neg:
                print("\nExamples of negatively connected regions:")
                for i, j, val in sorted(strong_neg, key=lambda x: x[2])[:3]:
                    if i < len(ATLAS_REGIONS) and j < len(ATLAS_REGIONS):
                        print(f"  • {ATLAS_REGIONS[i]} and {ATLAS_REGIONS[j]}: {val:.2f}")

            print(f"\nVisualization saved to: {plot_path}")

        except Exception as e:
            print(f"Error creating visualization: {e}")

        # Step 5: Try to load models (if DemoVAE is available)
        if DEMOVAE_AVAILABLE:
            models = load_models()
            if models:
                print("\nModels loaded successfully")

                # If we have both models, try prediction on an example subject
                if "demovae" in models and "prediction" in models:
                    try:
                        # Get the first subject's FC vector
                        fc_vec = fc_data[first_subject]["vector"]

                        # Get demographic features from demo_df
                        if 'ID' in demo_df.columns:
                            subject_row = demo_df[demo_df['ID'] == first_subject]
                            if not subject_row.empty:
                                # Extract demographic features.
                                # NOTE(review): the conditional chains below parse as
                                # "1 if <value test> else (0 if <column test> else 0.5)",
                                # so a missing 'gender'/'handedness' column would raise
                                # a KeyError on subject_row['gender'] BEFORE the
                                # 'in subject_row' guard is evaluated -- the fallback
                                # 0.5 is effectively unreachable. Confirm intent.
                                demo_features = np.array([
                                    subject_row['age'].values[0],
                                    subject_row['mpo'].values[0] if 'mpo' in subject_row else 6,
                                    subject_row['education'].values[0] if 'education' in subject_row else 12,
                                    1 if subject_row['gender'].values[0] == 'M' else 0 if 'gender' in subject_row else 0.5,
                                    1 if subject_row['handedness'].values[0] == 'R' else 0 if 'handedness' in subject_row else 0.5
                                ]).reshape(1, -1)

                                # Get latent representation from DemoVAE
                                z = models["demovae"].get_latents(fc_vec.reshape(1, -1))

                                # Combine with demographic features for prediction
                                X_combined = np.hstack([z, demo_features])

                                # Make prediction
                                predicted_wab = models["prediction"].predict(X_combined)[0]

                                actual_wab = subject_row['wab_aq'].values[0] if 'wab_aq' in subject_row else None

                                print(f"\nPrediction for {first_subject}:")
                                print(f"Predicted WAB AQ score: {predicted_wab:.2f}")
                                if actual_wab is not None:
                                    print(f"Actual WAB AQ score: {actual_wab:.2f}")
                                    print(f"Prediction error: {abs(predicted_wab - actual_wab):.2f}")

                                # Generate a second visualization showing the relationship between
                                # predicted aphasia recovery and specific brain region connectivity
                                try:
                                    # Create a visualization showing the most important connections for prediction
                                    if 'feature_importances_' in dir(models["prediction"]):
                                        print("\nCreating visualization of important connections for prediction...")

                                        # Get feature importances from the Random Forest model
                                        importances = models["prediction"].feature_importances_

                                        # The first elements of importances correspond to the latent variables
                                        # and we can't directly map those to regions. So we'll compute a score for each region.

                                        # Create a heatmap version of the FC matrix, with cells colored by importance
                                        # NOTE(review): this zeros_like result is immediately
                                        # overwritten two lines below -- dead assignment.
                                        importance_matrix = np.zeros_like(fc_matrix)

                                        # Simple approach: use the original matrix but scale the color by overall importance
                                        # (this is a simplification since we can't directly map latent vars to regions)
                                        avg_importance = np.mean(importances[:z.shape[1]])
                                        importance_matrix = fc_matrix * avg_importance

                                        # Plot this matrix
                                        imp_plot_path = os.path.join(CACHE_DIR, "visualizations",
                                                                     f"{first_subject}_importance_matrix.png")

                                        # Use original connectivity but with a different title.
                                        # NOTE(review): importance_matrix is computed but the
                                        # plot below renders fc_matrix instead -- confirm which
                                        # matrix was intended here.
                                        fig = plot_connectivity_matrix(
                                            fc_matrix,
                                            subject_id=f"{first_subject} (Regions Influencing Prediction)",
                                            save_path=imp_plot_path,
                                            show_labels=True
                                        )

                                        print(f"Prediction importance visualization saved to: {imp_plot_path}")
                                except Exception as e:
                                    print(f"Error creating importance visualization: {e}")
                    except Exception as e:
                        print(f"Error making prediction: {e}")
    else:
        print("No FC matrices were successfully processed")

    print("\nProcessing complete!")
    print(f"Data cached in: {CACHE_DIR}")
    print("Ready for use in HuggingFace Spaces!")
|
| 1127 |
+
|
| 1128 |
+
|
| 1129 |
+
def train_demovae_model(X_fc, X_demo, y, save_model=True, model_path="osf_demovae_model.pt"):
    """
    Train a DemoVAE model on OSF data using the PIP implementation.

    Args:
        X_fc: Functional connectivity features (subjects x features).
        X_demo: Demographic features (subjects x 5), ordered
            [age, mpo, education, gender, handedness].
        y: Target variable (WAB AQ scores).
        save_model: Whether to persist the trained model to disk.
        model_path: Destination path used when save_model is True.

    Returns:
        Tuple of (model, z_train, z_test, X_fc_test, X_demo_test, y_test).

    Raises:
        ImportError: If the DemoVAE package is not available.
    """
    if not DEMOVAE_AVAILABLE:
        raise ImportError("DemoVAE package is not available. Cannot train model.")

    print("Training DemoVAE model on OSF data...")

    # Hold out 20% of subjects for evaluation.
    from sklearn.model_selection import train_test_split
    X_fc_train, X_fc_test, X_demo_train, X_demo_test, y_train, y_test = train_test_split(
        X_fc, X_demo, y, test_size=0.2, random_state=42
    )

    # Declared semantics of each demographic column:
    # [age, mpo, education, gender, handedness]
    all_demo_types = ['continuous', 'continuous', 'continuous', 'categorical', 'categorical']

    # A categorical column with only one observed level in the training split
    # cannot be modeled as categorical, so it is demoted to continuous.
    demo_types = []
    demo_train_processed = []
    demo_test_processed = []
    for col, declared_type in enumerate(all_demo_types):
        train_col = X_demo_train[:, col]
        if declared_type == 'categorical' and len(np.unique(train_col)) == 1:
            print(f"Warning: Column {col} has only one category. Treating as continuous.")
            demo_types.append('continuous')
        else:
            demo_types.append(declared_type)
        # Both branches pass the raw column through unchanged.
        demo_train_processed.append(train_col)
        demo_test_processed.append(X_demo_test[:, col])

    print(f"Using demographic types: {demo_types}")

    # Hyperparameters tuned for the (small) OSF cohort; batch size is capped
    # by the number of training subjects.
    model = DemoVAE(
        latent_dim=30,
        loss_rec_mult=100,
        loss_decor_mult=10,
        loss_pred_mult=0.1,
        nepochs=300,
        pperiod=50,
        bsize=min(32, len(X_fc_train)),
        use_cuda=torch.cuda.is_available()
    )

    # Fit; if categorical handling fails, retry with everything continuous.
    try:
        model.fit(X_fc_train, demo_train_processed, demo_types)
    except Exception as e:
        print(f"Error training model: {str(e)}")
        # Fall back to treating all demographics as continuous if categorical fails
        print("Falling back to all continuous demographics")
        demo_types = ['continuous'] * len(all_demo_types)
        model.fit(X_fc_train, demo_train_processed, demo_types)

    # Latent codes for both splits.
    z_train = model.get_latents(X_fc_train)
    z_test = model.get_latents(X_fc_test)

    if save_model:
        model.save(model_path)
        print(f"Model saved to {model_path}")

    return model, z_train, z_test, X_fc_test, X_demo_test, y_test
|
| 1213 |
+
|
| 1214 |
+
|
| 1215 |
+
def generate_custom_fc(demo_values, model, prediction_model=None, visualize=True):
    """
    Generate a custom FC matrix for a patient with given demographics.

    Args:
        demo_values: Dict with optional keys 'age', 'mpo', 'education',
            'gender', 'handedness'; defaults are 60, 6, 16, 'male', 'right'.
        model: Trained DemoVAE model (FC samples come from model.transform).
        prediction_model: Optional aphasia prediction model with .predict.
        visualize: Whether to render and save a heatmap of the matrix.

    Returns:
        Tuple of (FC matrix, predicted aphasia score or None,
        visualization path or None).

    Raises:
        ImportError: If the DemoVAE package is not available.
    """
    if not DEMOVAE_AVAILABLE:
        raise ImportError("DemoVAE package is not available. Cannot generate custom FC matrix.")

    # Extract demographic values with defaults.
    age = demo_values.get('age', 60)
    mpo = demo_values.get('mpo', 6)
    education = demo_values.get('education', 16)
    gender = 1 if demo_values.get('gender', 'male').lower() == 'male' else 0
    handedness = 1 if demo_values.get('handedness', 'right').lower() == 'right' else 0

    # Format for the model: one row, plus per-column lists as DemoVAE expects.
    custom_demo = np.array([[age, mpo, education, gender, handedness]])
    custom_demo_lists = [custom_demo[:, i] for i in range(custom_demo.shape[1])]
    demo_types = ['continuous', 'continuous', 'continuous', 'categorical', 'categorical']

    # Random latent vector, seeded for reproducibility.
    # NOTE(review): custom_z is only used for the score prediction below;
    # model.transform() samples its own latent code, so the generated FC and
    # the predicted score are not derived from the same latent -- confirm
    # whether that decoupling is intended.
    np.random.seed(42)
    custom_z = np.random.randn(1, model.latent_dim)
    # (Cleanup: removed dead locals from the original -- custom_z_torch and
    # demo_torch were built, moved to CUDA, and never used; likewise a title
    # string was assembled but never passed to the plotting call.)

    # Generate one FC sample conditioned on the demographics.
    custom_fc = model.transform(1, custom_demo_lists, demo_types)
    custom_fc = custom_fc[0]  # Single requested example

    # A 1000-element output is trimmed to the 990 upper-triangle entries of a
    # 45x45 matrix (45 * 44 / 2 == 990) before reshaping.
    if custom_fc.size == 1000:
        custom_fc = custom_fc[:990]

    fc_matrix = vec2mat(custom_fc)

    # Ensure the matrix is 45x45 so it lines up with the AAL region labels.
    if fc_matrix.shape[0] != 45:
        temp_matrix = np.zeros((45, 45))
        min_dim = min(fc_matrix.shape[0], 45)
        temp_matrix[:min_dim, :min_dim] = fc_matrix[:min_dim, :min_dim]
        np.fill_diagonal(temp_matrix, 1)  # Self-correlation is always 1
        fc_matrix = temp_matrix

    # Predict aphasia score if a prediction model is provided.
    predicted_score = None
    if prediction_model is not None:
        X_combined = np.hstack([custom_z, custom_demo])
        predicted_score = float(prediction_model.predict(X_combined)[0])
        predicted_score = max(0, min(100, predicted_score))  # Clip to valid WAB AQ range

    # Create visualization if requested.
    viz_path = None
    if visualize:
        try:
            # Descriptive ID encodes the demographics in the filename.
            custom_id = f"custom_age{age}_mpo{mpo}_edu{education}_{'M' if gender == 1 else 'F'}_{'R' if handedness == 1 else 'L'}"

            os.makedirs(os.path.join(CACHE_DIR, "visualizations"), exist_ok=True)
            viz_path = os.path.join(CACHE_DIR, "visualizations", f"{custom_id}_fc_matrix.png")

            fig = plot_connectivity_matrix(
                fc_matrix,
                subject_id=custom_id,
                save_path=viz_path,
                show_labels=True
            )

            print(f"Generated custom FC matrix visualization saved to: {viz_path}")

        except Exception as e:
            print(f"Error creating visualization for custom FC matrix: {e}")

    return fc_matrix, predicted_score, viz_path
|
| 1313 |
+
|
| 1314 |
+
|
| 1315 |
+
def predict_aphasia_recovery(z_test, X_demo_test, y_test):
    """
    Fit a Random Forest mapping latent + demographic features to WAB AQ scores.

    NOTE(review): the forest is fit and evaluated on the SAME samples, so the
    printed RMSE/R² are in-sample (optimistic) figures rather than held-out
    performance -- confirm this is intended before citing them.

    Args:
        z_test: Latent representations from the VAE (samples x latent_dim).
        X_demo_test: Demographic features (samples x n_demo).
        y_test: Target WAB AQ scores (samples,).

    Returns:
        Tuple of (fitted model, in-sample predictions, RMSE, R²).
    """
    banner = "=" * 50
    print(banner)
    print("TRAINING RANDOM FOREST FOR APHASIA SCORE PREDICTION")
    print(banner)
    print("This is the second stage of the prediction pipeline:")
    print("1. VAE model extracts latent representations from FC data")
    print("2. Now training Random Forest to predict WAB AQ scores")
    print(f"Input features: {z_test.shape[1]} latent variables + {X_demo_test.shape[1]} demographic features")
    print(f"Training samples: {len(y_test)}")

    # Feature matrix: latent-code columns followed by demographic columns.
    X_combined = np.hstack([z_test, X_demo_test])

    from sklearn.ensemble import RandomForestRegressor
    from sklearn.metrics import mean_squared_error, r2_score
    import math

    print("Training Random Forest Regressor with 100 trees...")
    model = RandomForestRegressor(n_estimators=100, random_state=42)
    model.fit(X_combined, y_test)
    print("Random Forest training complete!")

    # Predict on the very data the forest was fit on (see note above).
    y_pred = model.predict(X_combined)

    # Evaluation metrics.
    rmse_val = math.sqrt(mean_squared_error(y_test, y_pred))
    r2 = r2_score(y_test, y_pred)

    print(f"Random Forest Prediction Results:")
    print(f"RMSE: {rmse_val:.4f}")
    print(f"R²: {r2:.4f}")
    print(banner)
    print("Random Forest training successful!")
    print(banner)

    return model, y_pred, rmse_val, r2
|
| 1366 |
+
|
| 1367 |
+
|
| 1368 |
+
# Script entry point: run the full pipeline (main() is defined earlier in this file).
if __name__ == "__main__":
    main()
|
pip/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
pip/.gitignore
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
dist/*
|
pip/.pypirc
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[distutils]
|
| 2 |
+
index-servers =
|
| 3 |
+
pypi
|
| 4 |
+
|
| 5 |
+
[pypi]
|
| 6 |
+
username = __token__
|
pip/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
MIT License
|
| 2 |
+
|
| 3 |
+
Copyright (c) 2024 Anton Orlichenko
|
| 4 |
+
|
| 5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
| 6 |
+
of this software and associated documentation files (the "Software"), to deal
|
| 7 |
+
in the Software without restriction, including without limitation the rights
|
| 8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
| 9 |
+
copies of the Software, and to permit persons to whom the Software is
|
| 10 |
+
furnished to do so, subject to the following conditions:
|
| 11 |
+
|
| 12 |
+
The above copyright notice and this permission notice shall be included in all
|
| 13 |
+
copies or substantial portions of the Software.
|
| 14 |
+
|
| 15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
| 16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
| 17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
| 18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
| 19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
| 20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
| 21 |
+
SOFTWARE.
|
pip/README.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Installable version of DemoVAE demographic-conditioned variational autoencoder for fMRI data.
|
| 2 |
+
|
| 3 |
+
For use with pip.
|
| 4 |
+
|
| 5 |
+
Perform fMRI distribution sampling, remove confounds, and harmonize multi-site data.
|
| 6 |
+
|
| 7 |
+
Supports FC, ALFF, and ReHO data.
|
pip/pyproject.toml
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[project]
|
| 2 |
+
name = "demovae"
|
| 3 |
+
version = "0.0.6"
|
| 4 |
+
authors = [
|
| 5 |
+
{ name="Anton Orlichenko", email="aorliche@gmail.com" }
|
| 6 |
+
]
|
| 7 |
+
description = "A demographic-conditioned variational autoencoder for fMRI distribution sampling, removal of confounds, and multi-site harmonization. Works with FC, ALFF, or ReHO data."
|
| 8 |
+
readme = "README.md"
|
| 9 |
+
dependencies = [ "numpy", "torch", "scikit-learn" ]
|
| 10 |
+
requires-python = ">=3.8"
|
| 11 |
+
classifiers = [
|
| 12 |
+
"Programming Language :: Python :: 3",
|
| 13 |
+
"License :: OSI Approved :: MIT License",
|
| 14 |
+
"Operating System :: OS Independent",
|
| 15 |
+
]
|
| 16 |
+
|
| 17 |
+
[project.urls]
|
| 18 |
+
Homepage = "https://github.com/aorliche/demo-vae/"
|
| 19 |
+
Issues = "https://github.com/aorliche/demo-vae/issues"
|
| 20 |
+
|
| 21 |
+
[build-system]
|
| 22 |
+
requires = ["hatchling"]
|
| 23 |
+
build-backend = "hatchling.build"
|
pip/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
numpy>=1.20.0
|
| 2 |
+
pandas>=1.3.0
|
| 3 |
+
torch>=1.9.0
|
| 4 |
+
matplotlib>=3.4.0
|
| 5 |
+
scikit-learn>=1.0.0
|
| 6 |
+
tqdm>=4.62.0
|
| 7 |
+
nibabel>=3.2.0
|
| 8 |
+
gradio>=3.50.0
|
| 9 |
+
pillow>=9.0.0
|
| 10 |
+
datasets>=2.10.0
|
| 11 |
+
huggingface_hub>=0.16.0
|
pip/src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
pip/src/demovae/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from demovae.sklearn import DemoVAE
|
pip/src/demovae/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (226 Bytes). View file
|
|
|
pip/src/demovae/__pycache__/model.cpython-311.pyc
ADDED
|
Binary file (15.5 kB). View file
|
|
|
pip/src/demovae/__pycache__/sklearn.cpython-311.pyc
ADDED
|
Binary file (6.42 kB). View file
|
|
|
pip/src/demovae/model.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from sklearn.linear_model import Ridge
|
| 10 |
+
from sklearn.linear_model import LogisticRegression
|
| 11 |
+
|
| 12 |
+
def to_torch(x):
    """Convert a numpy array into a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.to(torch.float32)
|
| 14 |
+
|
| 15 |
+
def to_cuda(x, use_cuda):
    """Move *x* to the GPU when use_cuda is truthy; otherwise return it unchanged."""
    return x.cuda() if use_cuda else x
|
| 20 |
+
|
| 21 |
+
def to_numpy(x):
    """Detach a tensor from the autograd graph and return it as a CPU numpy array."""
    detached = x.detach()
    return detached.cpu().numpy()
|
| 23 |
+
|
| 24 |
+
class VAE(nn.Module):
    """Demographic-conditioned VAE.

    A two-layer MLP encoder maps inputs to a latent code; a two-layer MLP
    decoder reconstructs inputs from the latent code concatenated with a
    demographic vector.
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        """Build encoder/decoder layers.

        Args:
            input_dim: Dimensionality of the input features (e.g. flattened FC).
            latent_dim: Size of the latent code.
            demo_dim: Dimensionality of the demographic conditioning vector.
            use_cuda: Place layers (and generated tensors) on the GPU.
        """
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda
        # NOTE: layer attribute names are part of saved state_dict keys; do not rename.
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)

    def enc(self, x):
        """Encode inputs to latent codes (deterministic; no reparameterized sampling)."""
        hidden = F.relu(self.enc1(x))
        return self.enc2(hidden)

    def gen(self, n):
        """Draw *n* latent codes from the standard-normal prior."""
        noise = torch.randn(n, self.latent_dim).float()
        return to_cuda(noise, self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes conditioned on demographic vectors."""
        conditioned = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        hidden = F.relu(self.dec1(conditioned))
        return self.dec2(hidden)
|
| 53 |
+
|
| 54 |
+
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between a and b; `mean` lets callers swap the reduction."""
    squared_diff = (a - b) ** 2
    return mean(squared_diff) ** 0.5
|
| 56 |
+
|
| 57 |
+
def latent_loss(z, use_cuda=True):
    """Push the latent batch toward the standard-normal prior.

    Penalizes deviation of the Gram matrix z^T z from len(z)*I (unit
    covariance, decorrelated dims) and of the batch mean from zero.

    Returns:
        (covariance loss, mean loss, Gram matrix, batch mean)
    """
    dim = z.shape[-1]
    gram = z.T @ z
    batch_mean = torch.mean(z, dim=0)
    gram_target = to_cuda(torch.eye(dim).float(), use_cuda) * len(z)
    mean_target = to_cuda(torch.zeros(dim).float(), use_cuda)
    return rmse(gram, gram_target), rmse(batch_mean, mean_target), gram, batch_mean
|
| 65 |
+
|
| 66 |
+
def decor_loss(z, demo, use_cuda=True):
    """Penalize correlation between each demographic column and the latent dims.

    For every demographic column, computes a normalized projection of the
    mean-centered column onto each latent dimension and takes its RMSE
    against zero.

    Returns:
        (stacked per-column losses, list of per-column projection vectors)
    """
    # Target of all-zero projections, shared across columns.
    zero_target = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    projections = []
    losses = []
    for col in range(demo.shape[1]):
        # Centering does not change std, so normalization matches the raw column.
        centered = demo[:, col] - torch.mean(demo[:, col])
        proj = torch.einsum('n,nz->z', centered, z)
        proj = proj / torch.std(centered)
        proj = proj / torch.einsum('nz,nz->z', z, z)
        losses.append(rmse(proj, zero_target))
        projections.append(proj)
    return torch.stack(losses), projections
|
| 81 |
+
|
| 82 |
+
def pretty(x):
    """Format a scalar (tensor or number) rounded to 4 decimal places."""
    value = round(float(x), 4)
    return str(value)
|
| 84 |
+
|
| 85 |
+
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Assemble demographics into a (n_subjects, demo_dim) float tensor.

    Continuous variables contribute one column each; categorical variables are
    one-hot encoded with one column per value seen at training time
    (the value lists stored in pred_stats).

    Raises:
        Exception: if a categorical value was not seen during training.
    """
    columns = []
    for demo_idx, (d, t, s) in enumerate(zip(demo, demo_types, pred_stats)):
        if t == 'continuous':
            columns.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            # Refuse values the guidance models were never trained on.
            for dd in d:
                if dd not in s:
                    print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            # One indicator column per known category value.
            for ss in s:
                indicator = torch.zeros(len(d))
                indicator[(d == ss).astype('bool')] = 1
                columns.append(to_cuda(indicator, use_cuda))
    return torch.stack(columns).permute(1, 0)
|
| 104 |
+
|
| 105 |
+
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult, loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train the demographic-conditioned VAE.

    Each batch combines four objectives: (1) reconstruction of real samples,
    (2) latent prior matching (covariance + mean), (3) decorrelation of the
    latent code from the demographics, and (4) a guidance loss that checks
    generated samples against frozen linear predictors (Ridge / logistic
    regression) fit on the real data before training starts.

    Side effects: stores per-demographic statistics on ret_obj.pred_stats
    (mean/std for continuous variables, sorted value list for categorical),
    which the caller (DemoVAE) later uses for encoding demographics.
    """
    # Get linear predictors for demographics
    pred_w = []
    pred_i = []
    # Pred stats are mean and std for continuous, and a list of all values for categorical
    pred_stats = []
    for i,d,t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxilliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            # Ridge regression from input features to the continuous demographic.
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            # Binary
            if len(reg.coef_) == 1:
                # sklearn stores one weight vector for binary problems;
                # duplicate it with flipped sign so each class gets its own score.
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            # Categorical
            else:
                # One weight/intercept entry per class (one-vs-rest).
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        else:
            print(f'demographic type "{t}" not "continuous" or "categorical"')
            raise Exception('Bad demographic type')
        print(' done')
    ret_obj.pred_stats = pred_stats
    # Convert input to pytorch
    print('Converting input to pytorch')
    x = to_cuda(to_torch(x), vae.use_cuda)
    # Convert demographics to pytorch
    print('Converting demographics to pytorch')
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    # Training loop
    print('Beginning VAE training')
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        # NOTE(review): batches are taken in fixed order each epoch (no shuffling).
        for bs in range(0,len(x),bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()
            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)
            # Sample demographics
            # Draw 100 synthetic demographic vectors: Gaussian around the training
            # mean/std for continuous variables, a single random one-hot category
            # (shared across the whole batch) for categorical variables.
            demo_gen = []
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    mu = s[0]
                    std = s[1]
                    dd = torch.randn(100).float()
                    dd = dd*std+mu
                    dd = to_cuda(dd, vae.use_cuda)
                    demo_gen.append(dd)
                elif t == 'categorical':
                    idx = random.randint(0, len(s)-1)
                    for i in range(len(s)):
                        if idx == i:
                            dd = torch.ones(100).float()
                        else:
                            dd = torch.zeros(100).float()
                        dd = to_cuda(dd, vae.use_cuda)
                        demo_gen.append(dd)
            demo_gen = torch.stack(demo_gen).permute(1,0)
            # Generate
            z = vae.gen(100)
            y = vae.dec(z, demo_gen)
            # Regressor/classifier guidance loss
            # Generated samples should be scored by the frozen linear predictors
            # as matching the demographics they were conditioned on.
            losses_pred = []
            idcs = []
            dg_idx = 0
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                    loss = rmse(demo_gen[:,dg_idx], yy)
                    losses_pred.append(loss)
                    idcs.append(float(demo_gen[0,dg_idx]))
                    dg_idx += 1
                elif t == 'categorical':
                    loss = 0
                    # Cross-entropy per one-hot column against its linear score.
                    for i in range(len(s)):
                        yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                        loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
                        idcs.append(int(demo_gen[0,dg_idx]))
                        dg_idx += 1
                    losses_pred.append(loss)
            # Weighted sum of all objectives; one optimizer step per batch.
            total_loss = loss_C_mult*loss_C + loss_mu_mult*loss_mu + loss_rec_mult*loss_rec + loss_decor_mult*loss_decor + loss_pred_mult*sum(losses_pred)
            total_loss.backward()
            optim.step()
        # Progress report every pperiod epochs (values are from the last batch).
        if e%pperiod == 0 or e == nepochs-1:
            print(f'Epoch {e} ', end='')
            print(f'ReconLoss {pretty(loss_rec)} ', end='')
            print(f'CovarianceLoss {pretty(loss_C)} ', end='')
            print(f'MeanLoss {pretty(loss_mu)} ', end='')
            print(f'DecorLoss {pretty(loss_decor)} ', end='')
            losses_pred = [pretty(loss) for loss in losses_pred]
            print(f'GuidanceTargets {idcs} GuidanceLosses {losses_pred} ', end='')
            print()
    print('Training complete.')
|
| 220 |
+
|
| 221 |
+
|
pip/src/demovae/sklearn.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from demovae.model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
|
| 3 |
+
|
| 4 |
+
from sklearn.base import BaseEstimator
|
| 5 |
+
|
| 6 |
+
# For saving
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class DemoVAE(BaseEstimator):
    """Scikit-learn style wrapper around the demographic-conditioned VAE.

    fit() trains the VAE on a feature matrix with per-variable demographic
    arrays; transform() decodes either real samples (array input) or prior
    samples (integer count) conditioned on demographics; save()/load()
    round-trip the trained model through a torch checkpoint.
    """

    def __init__(self, **params):
        """Accept any subset of the hyperparameters; missing ones take defaults."""
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        """Return the default hyperparameter dictionary."""
        return dict(latent_dim=60,          # Latent dimension
                    use_cuda=True,          # GPU acceleration
                    nepochs=3000,           # Training epochs
                    pperiod=100,            # Epochs between printing updates
                    bsize=1000,             # Batch size
                    loss_C_mult=1,          # Covariance loss (KL div)
                    loss_mu_mult=1,         # Mean loss (KL div)
                    loss_rec_mult=100,      # Reconstruction loss
                    loss_decor_mult=10,     # Latent-demographic decorrelation loss
                    loss_pred_mult=0.001,   # Classifier/regressor guidance loss
                    alpha=100,              # Regularization for continuous guidance models
                    LR_C=100,               # Regularization for categorical guidance models
                    lr=1e-4,                # Learning rate
                    weight_decay=0,         # L2 regularization for VAE model
                    )

    def get_params(self, **params):
        """Return the current hyperparameters as a dict (sklearn convention)."""
        return dict(latent_dim=self.latent_dim,
                    use_cuda=self.use_cuda,
                    nepochs=self.nepochs,
                    pperiod=self.pperiod,
                    bsize=self.bsize,
                    loss_C_mult=self.loss_C_mult,
                    loss_mu_mult=self.loss_mu_mult,
                    loss_rec_mult=self.loss_rec_mult,
                    loss_decor_mult=self.loss_decor_mult,
                    loss_pred_mult=self.loss_pred_mult,
                    alpha=self.alpha,
                    LR_C=self.LR_C,
                    lr=self.lr,
                    weight_decay=self.weight_decay,
                    )

    def set_params(self, **params):
        """Set any known hyperparameters from **params; unspecified keys get defaults.

        NOTE: keys not in get_default_params() are silently ignored.
        """
        dft = DemoVAE.get_default_params()
        for key in dft:
            if key in params:
                setattr(self, key, params[key])
            else:
                setattr(self, key, dft[key])
        return self

    def fit(self, x, demo, demo_types, **kwargs):
        """Train the VAE.

        Args:
            x: Feature matrix, shape (n_samples, n_features).
            demo: List of per-variable demographic arrays, each length n_samples.
            demo_types: Parallel list of 'continuous' or 'categorical'.

        Raises:
            Exception: on an unknown demographic type or a categorical
                variable with only one level.
        """
        # Conditioning width: 1 column per continuous variable, one-hot
        # (one column per level) for each categorical variable.
        demo_dim = 0
        for d, t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                ll = len(set(list(d)))
                if ll == 1:
                    print('Only one type of category for categorical variable')
                    raise Exception('Bad categorical')
                demo_dim += ll
            else:
                print(f'demographic type "{t}" not "continuous" or "categorical"')
                raise Exception('Bad demographic type')
        # Save parameters needed for save()/load().
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        # Create model
        self.vae = VAE(x.shape[1], self.latent_dim, demo_dim, self.use_cuda)
        # Train model (train_vae also stores self.pred_stats via ret_obj).
        train_vae(self.vae, x, demo, demo_types,
                  self.nepochs, self.pperiod, self.bsize,
                  self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult, self.loss_decor_mult, self.loss_pred_mult,
                  self.lr, self.weight_decay, self.alpha, self.LR_C,
                  self)
        return self

    def transform(self, x, demo, demo_types, **kwargs):
        """Decode samples conditioned on demographics.

        If x is an int, draws that many latent codes from the prior
        (generation); otherwise encodes the real data x first (harmonization
        / confound removal). Returns the decoded features as a numpy array.
        """
        if isinstance(x, int):
            # Generate
            z = self.vae.gen(x)
        else:
            # Get latents for real data
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def fit_transform(self, x, demo, demo_types, **kwargs):
        """Convenience: fit on (x, demo) then transform the same data."""
        self.fit(x, demo, demo_types)
        return self.transform(x, demo, demo_types)

    def get_latents(self, x):
        """Encode x and return the latent codes as a numpy array."""
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        """Serialize hyperparameters, demographic stats, and VAE weights to path."""
        params = self.get_params()
        dct = dict(pred_stats=self.pred_stats,
                   params=params,
                   input_dim=self.input_dim,
                   demo_dim=self.demo_dim,
                   model_state_dict=self.vae.state_dict())
        torch.save(dct, path)

    def load(self, path):
        """Restore a model previously written by save().

        Fix: pass map_location so checkpoints trained on GPU load on
        CPU-only machines, and downgrade use_cuda when CUDA is unavailable.
        Fix: restore input_dim/demo_dim so save() works after load().
        """
        map_location = None if torch.cuda.is_available() else 'cpu'
        dct = torch.load(path, map_location=map_location)
        self.pred_stats = dct['pred_stats']
        self.set_params(**dct['params'])
        # Honor the saved setting only when this machine actually has CUDA.
        use_cuda = dct['params']['use_cuda'] and torch.cuda.is_available()
        self.use_cuda = use_cuda
        # Restore dimensions so a loaded model can be re-saved.
        self.input_dim = dct['input_dim']
        self.demo_dim = dct['demo_dim']
        self.vae = VAE(dct['input_dim'],
                       dct['params']['latent_dim'],
                       dct['demo_dim'],
                       use_cuda)
        self.vae.load_state_dict(dct['model_state_dict'])
|
| 122 |
+
|
| 123 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.7.0
|
| 2 |
+
numpy>=1.19.0
|
| 3 |
+
pandas>=1.1.0
|
| 4 |
+
scikit-learn>=0.24.0
|
| 5 |
+
matplotlib>=3.3.0
|
| 6 |
+
nibabel>=3.2.0
|
| 7 |
+
tqdm>=4.50.0
|
| 8 |
+
gradio>=3.50.0
|
| 9 |
+
natsort>=8.0.0
|
| 10 |
+
requests>=2.25.0
|
| 11 |
+
datasets>=2.15.0
|
| 12 |
+
huggingface_hub>=0.10.0
|