Spaces:
Sleeping
Sleeping
Upload 12 files
Browse files- README .md +70 -0
- app.py +93 -258
- config.py +3 -16
- data_preprocessing.py +483 -1138
- gitattributes +35 -0
- main.py +233 -459
- requirements.txt +6 -2
- test_hf_download.py +2 -2
- utils.py +61 -78
- vae_model.py +131 -336
- visualization.py +32 -509
README .md
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: Aphasia fMRI VAE Analysis
|
| 3 |
+
emoji: 🧠
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: pink
|
| 6 |
+
sdk: gradio
|
| 7 |
+
sdk_version: 5.20.1
|
| 8 |
+
app_file: app.py
|
| 9 |
+
pinned: false
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Aphasia fMRI to FC Analysis using VAE
|
| 13 |
+
|
| 14 |
+
This demo performs functional connectivity analysis on fMRI data using a Variational Autoencoder (VAE) approach. It's designed to work with aphasia patient data, analyzing brain connectivity patterns and their relationship to demographic variables.
|
| 15 |
+
|
| 16 |
+
## About the Model
|
| 17 |
+
|
| 18 |
+
This application implements a VAE model that:
|
| 19 |
+
1. Takes functional connectivity (FC) matrices derived from fMRI data
|
| 20 |
+
2. Learns a lower-dimensional latent representation of brain connectivity
|
| 21 |
+
3. Conditions the generation process on demographic variables (age, sex, time post-stroke, WAB scores)
|
| 22 |
+
4. Allows analysis of relationships between brain connectivity patterns and demographic variables
|
| 23 |
+
|
| 24 |
+
## Dataset
|
| 25 |
+
|
| 26 |
+
This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
|
| 27 |
+
|
| 28 |
+
- NIfTI files in P01_rs.nii format containing fMRI data
|
| 29 |
+
- Demographic information directly in the dataset:
|
| 30 |
+
- ID: Subject identifier
|
| 31 |
+
- wab_aq: Aphasia quotient score (severity measure)
|
| 32 |
+
- age: Subject age
|
| 33 |
+
- mpo: Months post onset
|
| 34 |
+
- education: Years of education
|
| 35 |
+
- gender: Subject gender
|
| 36 |
+
- handedness: Subject handedness (ignored in this analysis)
|
| 37 |
+
|
| 38 |
+
The application processes the NIfTI files using the Power 264 atlas to create functional connectivity matrices that are then analyzed by the VAE model.
|
| 39 |
+
|
| 40 |
+
## How to Use
|
| 41 |
+
|
| 42 |
+
1. **Configure Parameters**:
|
| 43 |
+
- **Data Source**: By default, it uses the SreekarB/OSFData HuggingFace dataset
|
| 44 |
+
- **Latent Dimensions**: Controls the size of the latent space (default: 32)
|
| 45 |
+
- **Number of Epochs**: Training iterations (default: 200 for demo)
|
| 46 |
+
- **Batch Size**: Training batch size (default: 16)
|
| 47 |
+
|
| 48 |
+
2. **Start Training**:
|
| 49 |
+
- Click the "Start Training" button to begin the analysis
|
| 50 |
+
- The training progress will be displayed in the Status area
|
| 51 |
+
|
| 52 |
+
3. **View Results**:
|
| 53 |
+
- The VAE will learn latent representations of brain connectivity
|
| 54 |
+
- Results will show correlations between demographic variables and latent brain patterns
|
| 55 |
+
- The visualization shows original FC, reconstructed FC, and a new FC matrix generated from specific demographic values
|
| 56 |
+
|
| 57 |
+
## Outputs
|
| 58 |
+
|
| 59 |
+
The application produces visualizations showing:
|
| 60 |
+
- Original FC matrix
|
| 61 |
+
- Reconstructed FC matrix
|
| 62 |
+
- Generated FC matrix (based on specific demographic inputs)
|
| 63 |
+
- Correlation plots between latent variables and demographic features
|
| 64 |
+
|
| 65 |
+
## Technical Details
|
| 66 |
+
|
| 67 |
+
- Framework: PyTorch
|
| 68 |
+
- Interface: Gradio
|
| 69 |
+
- Dataset: HuggingFace Datasets API
|
| 70 |
+
- Analysis: Custom implementation of conditional VAE with demographic conditioning
|
app.py
CHANGED
|
@@ -1,265 +1,100 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Simplified app for Huggingface Spaces.
|
| 3 |
-
Provides a simple UI for VAE training and visualization.
|
| 4 |
-
"""
|
| 5 |
-
import os
|
| 6 |
import gradio as gr
|
| 7 |
-
|
| 8 |
-
import
|
| 9 |
-
import matplotlib.pyplot as plt
|
| 10 |
-
from vae_model import DemoVAE, plot_learning_curves
|
| 11 |
-
import time
|
| 12 |
-
import tempfile
|
| 13 |
-
import logging
|
| 14 |
-
|
| 15 |
-
# Set up logging
|
| 16 |
-
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
|
| 17 |
-
logger = logging.getLogger(__name__)
|
| 18 |
-
|
| 19 |
-
# Make sure directories exist
|
| 20 |
-
os.makedirs('models', exist_ok=True)
|
| 21 |
-
os.makedirs('results', exist_ok=True)
|
| 22 |
-
|
| 23 |
-
# Global app state
|
| 24 |
-
app_state = {
|
| 25 |
-
'vae': None,
|
| 26 |
-
'latents': None,
|
| 27 |
-
'demographics': None,
|
| 28 |
-
'fc_data': None,
|
| 29 |
-
'vae_trained': False
|
| 30 |
-
}
|
| 31 |
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
return matrix
|
| 45 |
|
| 46 |
-
def
|
| 47 |
-
"
|
| 48 |
-
|
| 49 |
-
#
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
#
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 81 |
|
| 82 |
-
#
|
| 83 |
-
|
|
|
|
| 84 |
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
elif 'wab_aq' in demo_df.columns:
|
| 119 |
-
wab = demo_df['wab_aq'].values
|
| 120 |
-
else:
|
| 121 |
-
wab = np.random.normal(65, 15, len(X))
|
| 122 |
-
logger.warning("WAB score column not found, using synthetic data")
|
| 123 |
-
demographics.append(wab)
|
| 124 |
-
|
| 125 |
-
else:
|
| 126 |
-
logger.info("No demographics file provided, using synthetic data")
|
| 127 |
-
demographics = [
|
| 128 |
-
np.random.normal(60, 10, len(X)), # age
|
| 129 |
-
np.random.choice(['M', 'F'], len(X)), # sex
|
| 130 |
-
np.random.normal(24, 12, len(X)), # months post stroke
|
| 131 |
-
np.random.normal(65, 15, len(X)) # WAB score
|
| 132 |
-
]
|
| 133 |
-
|
| 134 |
-
app_state['demographics'] = demographics
|
| 135 |
-
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 136 |
-
|
| 137 |
-
except Exception as e:
|
| 138 |
-
logger.error(f"Error processing demographics: {e}")
|
| 139 |
-
return f"Error processing demographics: {str(e)}", None, None
|
| 140 |
-
|
| 141 |
-
# Initialize model
|
| 142 |
-
progress(0.3, "Initializing model...")
|
| 143 |
-
model = DemoVAE(nepochs=epochs, batch_size=batch_size, latent_dim=latent_dim)
|
| 144 |
-
|
| 145 |
-
# Train model
|
| 146 |
-
progress(0.4, "Training VAE model...")
|
| 147 |
-
train_losses, val_losses = model.fit(X, demographics, demo_types)
|
| 148 |
-
|
| 149 |
-
# Save model
|
| 150 |
-
progress(0.7, "Saving model...")
|
| 151 |
-
model.save('models/vae_model.pt')
|
| 152 |
-
app_state['vae'] = model
|
| 153 |
-
app_state['vae_trained'] = True
|
| 154 |
-
|
| 155 |
-
# Generate latent representations
|
| 156 |
-
progress(0.8, "Generating latent representations...")
|
| 157 |
-
latents = model.get_latents(X)
|
| 158 |
-
app_state['latents'] = latents
|
| 159 |
-
np.save('results/latents.npy', latents)
|
| 160 |
-
|
| 161 |
-
# Create visualizations
|
| 162 |
-
progress(0.9, "Creating visualizations...")
|
| 163 |
-
|
| 164 |
-
# Learning curves
|
| 165 |
-
learning_fig = plot_learning_curves(model.train_losses, model.val_losses)
|
| 166 |
-
learning_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
| 167 |
-
learning_fig.savefig(learning_img.name)
|
| 168 |
-
plt.close(learning_fig)
|
| 169 |
-
|
| 170 |
-
# FC visualization
|
| 171 |
-
progress(0.95, "Creating FC visualizations...")
|
| 172 |
-
reconstructed = model.transform(X, demographics, demo_types)
|
| 173 |
-
np.save('results/reconstructed.npy', reconstructed)
|
| 174 |
-
|
| 175 |
-
generated = model.transform(1, [d[0] for d in demographics], demo_types)
|
| 176 |
-
np.save('results/generated.npy', generated)
|
| 177 |
-
|
| 178 |
-
fc_fig, axes = plt.subplots(1, 3, figsize=(15, 5))
|
| 179 |
-
original_matrix = vector_to_matrix(X[0])
|
| 180 |
-
recon_matrix = vector_to_matrix(reconstructed[0])
|
| 181 |
-
gen_matrix = vector_to_matrix(generated[0])
|
| 182 |
-
|
| 183 |
-
# Plot matrices
|
| 184 |
-
titles = ['Original', 'Reconstructed', 'Generated']
|
| 185 |
-
for i, matrix in enumerate([original_matrix, recon_matrix, gen_matrix]):
|
| 186 |
-
im = axes[i].imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
|
| 187 |
-
axes[i].set_title(titles[i])
|
| 188 |
-
axes[i].axis('off')
|
| 189 |
-
|
| 190 |
-
fc_fig.subplots_adjust(right=0.8)
|
| 191 |
-
cbar_ax = fc_fig.add_axes([0.85, 0.15, 0.05, 0.7])
|
| 192 |
-
fc_fig.colorbar(im, cax=cbar_ax)
|
| 193 |
-
|
| 194 |
-
fc_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
| 195 |
-
fc_fig.savefig(fc_img.name)
|
| 196 |
-
plt.close(fc_fig)
|
| 197 |
-
|
| 198 |
-
progress(1.0, "Training complete!")
|
| 199 |
-
return "Training completed successfully!", learning_img.name, fc_img.name
|
| 200 |
-
|
| 201 |
-
except Exception as e:
|
| 202 |
-
logger.error(f"Error in VAE training: {str(e)}")
|
| 203 |
-
return f"Error: {str(e)}", None, None
|
| 204 |
-
|
| 205 |
-
def create_ui():
|
| 206 |
-
"""Create the Gradio UI"""
|
| 207 |
-
with gr.Blocks(title="FC Matrix VAE Demo") as app:
|
| 208 |
-
gr.Markdown("# Functional Connectivity VAE Demo")
|
| 209 |
-
gr.Markdown("Upload FC matrices and train a VAE model to analyze them.")
|
| 210 |
-
|
| 211 |
-
with gr.Tab("Train VAE"):
|
| 212 |
-
with gr.Row():
|
| 213 |
-
with gr.Column():
|
| 214 |
-
fc_file = gr.File(label="FC Matrix File (CSV or NPY)")
|
| 215 |
-
demo_file = gr.File(label="Demographics File (CSV, optional)")
|
| 216 |
-
|
| 217 |
-
with gr.Row():
|
| 218 |
-
epochs = gr.Slider(5, 100, 20, step=5, label="Training Epochs")
|
| 219 |
-
latent_dim = gr.Slider(8, 64, 16, step=4, label="Latent Dimension")
|
| 220 |
-
batch_size = gr.Slider(4, 32, 8, step=4, label="Batch Size")
|
| 221 |
-
|
| 222 |
-
train_btn = gr.Button("Train VAE Model")
|
| 223 |
-
status = gr.Textbox(label="Status")
|
| 224 |
-
|
| 225 |
-
with gr.Column():
|
| 226 |
-
learning_plot = gr.Image(label="Learning Curves")
|
| 227 |
-
fc_plot = gr.Image(label="FC Matrices")
|
| 228 |
-
|
| 229 |
-
train_btn.click(
|
| 230 |
-
fn=train_vae,
|
| 231 |
-
inputs=[fc_file, demo_file, epochs, latent_dim, batch_size],
|
| 232 |
-
outputs=[status, learning_plot, fc_plot]
|
| 233 |
-
)
|
| 234 |
-
|
| 235 |
-
with gr.Tab("About"):
|
| 236 |
-
gr.Markdown("""
|
| 237 |
-
## About this App
|
| 238 |
-
|
| 239 |
-
This app trains a Variational Autoencoder (VAE) on functional connectivity (FC) matrices.
|
| 240 |
-
|
| 241 |
-
### Features:
|
| 242 |
-
* Load FC matrices from CSV or NPY files
|
| 243 |
-
* Incorporate demographic data (age, sex, etc.)
|
| 244 |
-
* Visualize learning curves
|
| 245 |
-
* Compare original, reconstructed and generated FC matrices
|
| 246 |
-
|
| 247 |
-
### Input Format:
|
| 248 |
-
* FC matrices should be provided as vectors (flattened upper triangular portion of symmetric matrices)
|
| 249 |
-
* Demographics file should be CSV with columns for age, sex, months_post_stroke, and wab_score
|
| 250 |
-
|
| 251 |
-
### Model Architecture:
|
| 252 |
-
* Simple feedforward VAE with demographic conditioning
|
| 253 |
-
* Latent space can be specified (default 16 dimensions)
|
| 254 |
-
* MSE reconstruction loss
|
| 255 |
-
""")
|
| 256 |
-
|
| 257 |
-
return app
|
| 258 |
|
| 259 |
-
# For local testing
|
| 260 |
if __name__ == "__main__":
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
# For Huggingface Spaces
|
| 265 |
-
demo = create_ui()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import gradio as gr
|
| 2 |
+
from main import run_fc_analysis
|
| 3 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
+
def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
|
| 6 |
+
"""Run the full VAE analysis pipeline"""
|
| 7 |
+
fig = run_fc_analysis(
|
| 8 |
+
data_dir=data_source,
|
| 9 |
+
demographic_file=None, # We're now getting demographics directly from the dataset
|
| 10 |
+
latent_dim=latent_dim,
|
| 11 |
+
nepochs=nepochs,
|
| 12 |
+
bsize=bsize,
|
| 13 |
+
save_model=True,
|
| 14 |
+
use_hf_dataset=use_hf_dataset
|
| 15 |
+
)
|
| 16 |
+
return fig, "Analysis complete! VAE model has been trained and demographic relationships analyzed."
|
|
|
|
| 17 |
|
| 18 |
+
def create_interface():
|
| 19 |
+
with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
|
| 20 |
+
gr.Markdown("""
|
| 21 |
+
# Aphasia fMRI to FC Analysis using VAE
|
| 22 |
+
|
| 23 |
+
This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
|
| 24 |
+
|
| 25 |
+
## Dataset Information
|
| 26 |
+
By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
|
| 27 |
+
- ID: Subject identifier
|
| 28 |
+
- wab_aq: Aphasia severity score
|
| 29 |
+
- age: Age of the subject
|
| 30 |
+
- mpo: Months post onset
|
| 31 |
+
- education: Years of education
|
| 32 |
+
- gender: Subject gender
|
| 33 |
+
- handedness: Subject handedness (ignored in the analysis)
|
| 34 |
+
""")
|
| 35 |
+
|
| 36 |
+
with gr.Row():
|
| 37 |
+
with gr.Column(scale=1):
|
| 38 |
+
# Configuration parameters
|
| 39 |
+
data_source = gr.Textbox(
|
| 40 |
+
label="Data Source (HF Dataset ID or Local Directory)",
|
| 41 |
+
value="SreekarB/OSFData"
|
| 42 |
+
)
|
| 43 |
+
latent_dim = gr.Slider(
|
| 44 |
+
minimum=8, maximum=64, step=8,
|
| 45 |
+
label="Latent Dimensions", value=32
|
| 46 |
+
)
|
| 47 |
+
nepochs = gr.Slider(
|
| 48 |
+
minimum=100, maximum=5000, step=100,
|
| 49 |
+
label="Number of Epochs", value=200 # Reduced for faster demos
|
| 50 |
+
)
|
| 51 |
+
bsize = gr.Slider(
|
| 52 |
+
minimum=8, maximum=64, step=8,
|
| 53 |
+
label="Batch Size", value=16
|
| 54 |
+
)
|
| 55 |
+
use_hf_dataset = gr.Checkbox(
|
| 56 |
+
label="Use HuggingFace Dataset", value=True
|
| 57 |
+
)
|
| 58 |
|
| 59 |
+
# Training button
|
| 60 |
+
train_button = gr.Button("Start Training", variant="primary")
|
| 61 |
+
status_text = gr.Textbox(label="Status", value="Ready to start training")
|
| 62 |
|
| 63 |
+
with gr.Column(scale=2):
|
| 64 |
+
# Output plot
|
| 65 |
+
output_plot = gr.Plot(label="Analysis Results")
|
| 66 |
+
|
| 67 |
+
# Link the training button to the analysis function
|
| 68 |
+
train_button.click(
|
| 69 |
+
fn=gradio_fc_analysis,
|
| 70 |
+
inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
|
| 71 |
+
outputs=[output_plot, status_text]
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Add examples
|
| 75 |
+
gr.Examples(
|
| 76 |
+
examples=[
|
| 77 |
+
["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
|
| 78 |
+
],
|
| 79 |
+
inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Add explanation of the workflow
|
| 83 |
+
gr.Markdown("""
|
| 84 |
+
## How this works
|
| 85 |
+
|
| 86 |
+
1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
|
| 87 |
+
2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
|
| 88 |
+
3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
|
| 89 |
+
4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
|
| 90 |
+
5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
|
| 91 |
+
|
| 92 |
+
Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
|
| 93 |
+
""")
|
| 94 |
+
|
| 95 |
+
return iface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
|
|
|
|
| 97 |
if __name__ == "__main__":
|
| 98 |
+
iface = create_interface()
|
| 99 |
+
iface.launch(share=True)
|
| 100 |
+
|
|
|
|
|
|
config.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
# Model configuration
|
| 2 |
MODEL_CONFIG = {
|
| 3 |
'latent_dim': 32,
|
| 4 |
-
'nepochs':
|
| 5 |
-
'bsize':
|
| 6 |
'loss_rec_mult': 100,
|
| 7 |
'loss_decor_mult': 10,
|
| 8 |
'lr': 1e-4
|
|
@@ -18,20 +18,7 @@ PREPROCESS_CONFIG = {
|
|
| 18 |
|
| 19 |
# Dataset configuration
|
| 20 |
DATASET_CONFIG = {
|
| 21 |
-
'name': 'SreekarB/
|
| 22 |
'split': 'train'
|
| 23 |
}
|
| 24 |
|
| 25 |
-
# Prediction configuration
|
| 26 |
-
PREDICTION_CONFIG = {
|
| 27 |
-
'n_estimators': 100,
|
| 28 |
-
'max_depth': None,
|
| 29 |
-
'cv_folds': 5,
|
| 30 |
-
'default_outcome': 'wab_aq',
|
| 31 |
-
'save_path': 'results/treatment_predictor.joblib',
|
| 32 |
-
'skip_behavioral_data': True, # Set to True to skip processing behavioral_data.csv
|
| 33 |
-
'use_synthetic_nifti': False, # Set to False to use only real NIfTI data
|
| 34 |
-
'use_synthetic_fc': False, # Set to False to use only real FC matrices
|
| 35 |
-
'strict_real_data': True, # Set to True to strictly use real data only
|
| 36 |
-
'no_mock_data': True # Set to True to prevent using any mock or synthetic data
|
| 37 |
-
}
|
|
|
|
| 1 |
# Model configuration
|
| 2 |
MODEL_CONFIG = {
|
| 3 |
'latent_dim': 32,
|
| 4 |
+
'nepochs': 1000,
|
| 5 |
+
'bsize': 16,
|
| 6 |
'loss_rec_mult': 100,
|
| 7 |
'loss_decor_mult': 10,
|
| 8 |
'lr': 1e-4
|
|
|
|
| 18 |
|
| 19 |
# Dataset configuration
|
| 20 |
DATASET_CONFIG = {
|
| 21 |
+
'name': 'SreekarB/OSFData',
|
| 22 |
'split': 'train'
|
| 23 |
}
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
data_preprocessing.py
CHANGED
|
@@ -1,1212 +1,557 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
-
import
|
| 4 |
-
import
|
| 5 |
-
import pickle
|
| 6 |
-
import hashlib
|
| 7 |
-
import warnings
|
| 8 |
-
import re
|
| 9 |
-
from nilearn import input_data, connectome, datasets
|
| 10 |
from nilearn.image import load_img
|
| 11 |
import nibabel as nib
|
| 12 |
-
|
| 13 |
-
from config import PREPROCESS_CONFIG, PREDICTION_CONFIG
|
| 14 |
-
|
| 15 |
-
# Create cache directory if it doesn't exist
|
| 16 |
-
CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache')
|
| 17 |
-
os.makedirs(CACHE_DIR, exist_ok=True)
|
| 18 |
-
os.makedirs(os.path.join(CACHE_DIR, 'time_series'), exist_ok=True)
|
| 19 |
-
os.makedirs(os.path.join(CACHE_DIR, 'fc_matrices'), exist_ok=True)
|
| 20 |
-
os.makedirs(os.path.join(CACHE_DIR, 'latents'), exist_ok=True)
|
| 21 |
-
os.makedirs(os.path.join(CACHE_DIR, 'maskers'), exist_ok=True)
|
| 22 |
-
os.makedirs(os.path.join(CACHE_DIR, 'atlas'), exist_ok=True)
|
| 23 |
-
|
| 24 |
-
# Cache the atlas coordinates globally for efficient access
|
| 25 |
-
REGIONAL_COORDS = None
|
| 26 |
-
|
| 27 |
-
# Initialize the Power atlas coordinates
|
| 28 |
-
def _init_atlas_coords():
|
| 29 |
-
global REGIONAL_COORDS
|
| 30 |
-
if REGIONAL_COORDS is None:
|
| 31 |
-
try:
|
| 32 |
-
atlas_path = os.path.join(CACHE_DIR, 'atlas', 'power_2011_coords.npy')
|
| 33 |
-
if os.path.exists(atlas_path):
|
| 34 |
-
REGIONAL_COORDS = np.load(atlas_path)
|
| 35 |
-
else:
|
| 36 |
-
from nilearn import datasets
|
| 37 |
-
power = datasets.fetch_coords_power_2011()
|
| 38 |
-
REGIONAL_COORDS = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
| 39 |
-
# Save for future use
|
| 40 |
-
np.save(atlas_path, REGIONAL_COORDS)
|
| 41 |
-
print(f"Initialized Power atlas coordinates with {len(REGIONAL_COORDS)} ROIs")
|
| 42 |
-
except Exception as e:
|
| 43 |
-
print(f"Error initializing atlas coordinates: {e}")
|
| 44 |
-
# Fallback to a simple set of coordinates if needed
|
| 45 |
-
REGIONAL_COORDS = np.array([
|
| 46 |
-
[0, 0, 0], [10, 0, 0], [0, 10, 0], [0, 0, 10]
|
| 47 |
-
])
|
| 48 |
-
print("WARNING: Using fallback coordinates due to initialization error")
|
| 49 |
-
|
| 50 |
-
# Initialize atlas coordinates at module load time
|
| 51 |
-
_init_atlas_coords()
|
| 52 |
-
|
| 53 |
-
def get_file_hash(file_path):
|
| 54 |
-
"""Generate a hash for a file to use as a cache key"""
|
| 55 |
-
try:
|
| 56 |
-
hasher = hashlib.md5()
|
| 57 |
-
with open(file_path, 'rb') as f:
|
| 58 |
-
# Read in chunks to handle large files
|
| 59 |
-
for chunk in iter(lambda: f.read(4096), b""):
|
| 60 |
-
hasher.update(chunk)
|
| 61 |
-
return hasher.hexdigest()
|
| 62 |
-
except Exception as e:
|
| 63 |
-
print(f"Error hashing file {file_path}: {e}")
|
| 64 |
-
# Fallback to filename-based hash if file reading fails
|
| 65 |
-
return hashlib.md5(os.path.basename(file_path).encode()).hexdigest()
|
| 66 |
|
| 67 |
-
def
|
| 68 |
"""
|
| 69 |
-
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
Returns:
|
| 76 |
-
coords: Array of coordinates for the atlas
|
| 77 |
-
"""
|
| 78 |
-
global REGIONAL_COORDS
|
| 79 |
|
| 80 |
-
# If we have already initialized the coordinates, use them
|
| 81 |
-
if REGIONAL_COORDS is not None:
|
| 82 |
-
return REGIONAL_COORDS
|
| 83 |
-
|
| 84 |
-
# Otherwise, use the initialization function
|
| 85 |
-
_init_atlas_coords()
|
| 86 |
-
return REGIONAL_COORDS
|
| 87 |
-
|
| 88 |
-
def get_cached_masker(radius, use_cache=True):
|
| 89 |
-
"""
|
| 90 |
-
Get a NiftiSpheresMasker with the specified radius, using cache if available
|
| 91 |
-
|
| 92 |
-
Args:
|
| 93 |
-
radius: Sphere radius in mm
|
| 94 |
-
use_cache: Whether to use/create cache
|
| 95 |
-
|
| 96 |
Returns:
|
| 97 |
-
|
|
|
|
|
|
|
| 98 |
"""
|
| 99 |
-
|
| 100 |
-
return None
|
| 101 |
-
|
| 102 |
-
# Create a cache key for this masker configuration
|
| 103 |
-
# We use radius and other PREPROCESS_CONFIG values that affect the masker
|
| 104 |
-
config_str = (f"radius={radius},"
|
| 105 |
-
f"tr={PREPROCESS_CONFIG['t_r']},"
|
| 106 |
-
f"high_pass={PREPROCESS_CONFIG['high_pass']},"
|
| 107 |
-
f"low_pass={PREPROCESS_CONFIG['low_pass']}")
|
| 108 |
-
|
| 109 |
-
masker_key = hashlib.md5(config_str.encode()).hexdigest()
|
| 110 |
-
masker_path = os.path.join(CACHE_DIR, 'maskers', f"{masker_key}.pkl")
|
| 111 |
|
| 112 |
-
#
|
| 113 |
-
if
|
|
|
|
|
|
|
| 114 |
try:
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 120 |
except Exception as e:
|
| 121 |
-
print(f"Error
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
print(f"Processing fMRI file: {fmri_file}")
|
| 139 |
-
|
| 140 |
-
# Make sure os is imported to avoid reference error
|
| 141 |
-
import os
|
| 142 |
-
|
| 143 |
-
# Check if cached FC matrix exists
|
| 144 |
-
if use_cache:
|
| 145 |
-
file_hash = get_file_hash(fmri_file)
|
| 146 |
-
fc_cache_path = os.path.join(CACHE_DIR, 'fc_matrices', f"{file_hash}.npy")
|
| 147 |
|
| 148 |
-
|
| 149 |
-
print(f"Loading cached FC matrix for {os.path.basename(fmri_file)}")
|
| 150 |
-
try:
|
| 151 |
-
fc_triu = np.load(fc_cache_path)
|
| 152 |
-
print(f"Successfully loaded cached FC matrix, shape: {fc_triu.shape}")
|
| 153 |
-
return fc_triu
|
| 154 |
-
except Exception as e:
|
| 155 |
-
print(f"Error loading cached FC matrix: {e}, recalculating...")
|
| 156 |
-
|
| 157 |
-
# Use Power 264 atlas with caching
|
| 158 |
-
coords = get_cached_atlas_coords(use_cache=use_cache)
|
| 159 |
-
|
| 160 |
-
# FIRST: Try to normalize the NIfTI file to MNI space for better compatibility
|
| 161 |
-
try:
|
| 162 |
-
print("First attempting to register NIfTI file to MNI space...")
|
| 163 |
-
from nilearn import image
|
| 164 |
import tempfile
|
|
|
|
|
|
|
| 165 |
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
# Check if it's a 4D file with sufficient time points
|
| 170 |
-
if len(orig_img.shape) < 4:
|
| 171 |
-
print("Cannot preprocess: Not a 4D file (no time dimension)")
|
| 172 |
-
elif orig_img.shape[3] < 20:
|
| 173 |
-
print(f"Warning: Very few time points ({orig_img.shape[3]}), results may be unreliable")
|
| 174 |
-
|
| 175 |
-
# Create a preprocessing directory
|
| 176 |
-
preproc_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'preproc')
|
| 177 |
-
os.makedirs(preproc_dir, exist_ok=True)
|
| 178 |
-
|
| 179 |
-
# Generate a filename for the preprocessed file
|
| 180 |
-
basename = os.path.basename(fmri_file)
|
| 181 |
-
preproc_file = os.path.join(preproc_dir, f"mni_registered_{basename}")
|
| 182 |
|
| 183 |
-
print(f"MNI registration steps for {basename}:")
|
| 184 |
-
|
| 185 |
-
# Step 1: Get the MNI152 template for reference
|
| 186 |
-
from nilearn.datasets import load_mni152_template
|
| 187 |
-
template = load_mni152_template()
|
| 188 |
-
print("1. Loaded MNI152 template as reference")
|
| 189 |
-
|
| 190 |
-
# Step 2: For 4D data, we'll work with the mean image for registration
|
| 191 |
-
if len(orig_img.shape) == 4:
|
| 192 |
-
mean_img = image.mean_img(orig_img)
|
| 193 |
-
print("2. Extracted mean image from 4D volume for registration")
|
| 194 |
-
else:
|
| 195 |
-
mean_img = orig_img
|
| 196 |
-
|
| 197 |
-
# Step 3: Register to MNI space (target resolution of 3mm)
|
| 198 |
-
print("3. Registering to MNI space with 3mm resolution...")
|
| 199 |
-
reg_img = image.resample_to_img(orig_img, template, interpolation='linear')
|
| 200 |
-
print(f" Original dimensions: {orig_img.shape}, New dimensions: {reg_img.shape}")
|
| 201 |
-
|
| 202 |
-
# Step 4: Save the preprocessed image
|
| 203 |
-
print(f"4. Saving MNI-registered image to {preproc_file}...")
|
| 204 |
-
reg_img.to_filename(preproc_file)
|
| 205 |
-
print("MNI registration complete")
|
| 206 |
-
|
| 207 |
-
# Now try to process this MNI-registered file
|
| 208 |
-
mni_fmri_file = preproc_file
|
| 209 |
-
except Exception as reg_err:
|
| 210 |
-
print(f"Error during MNI registration: {reg_err}")
|
| 211 |
-
print("Continuing with original NIfTI file")
|
| 212 |
-
mni_fmri_file = fmri_file
|
| 213 |
-
|
| 214 |
-
# Try different atlas radiuses if the default one has issues
|
| 215 |
-
# Include more radius options and make sure they're unique and sorted
|
| 216 |
-
radius_options = list(set([
|
| 217 |
-
PREPROCESS_CONFIG['radius'],
|
| 218 |
-
5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, # Add more options
|
| 219 |
-
4, 3 # Smaller options as last resort
|
| 220 |
-
]))
|
| 221 |
-
radius_options.sort() # Sort for consistent attempts
|
| 222 |
-
print(f"Will try these radius options in order: {radius_options}")
|
| 223 |
-
|
| 224 |
-
# Try each radius option
|
| 225 |
-
for radius in radius_options:
|
| 226 |
try:
|
| 227 |
-
|
| 228 |
-
|
| 229 |
-
|
| 230 |
-
|
| 231 |
-
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
|
| 234 |
-
|
| 235 |
-
print(f"
|
| 236 |
-
try:
|
| 237 |
-
time_series = np.load(ts_cache_path)
|
| 238 |
-
print(f"Successfully loaded cached time series, shape: {time_series.shape}")
|
| 239 |
-
except Exception as e:
|
| 240 |
-
print(f"Error loading cached time series: {e}, recalculating...")
|
| 241 |
-
time_series = None
|
| 242 |
-
else:
|
| 243 |
-
time_series = None
|
| 244 |
-
else:
|
| 245 |
-
time_series = None
|
| 246 |
-
|
| 247 |
-
# If no cached time series, calculate it
|
| 248 |
-
if time_series is None:
|
| 249 |
-
# Try to get a cached masker first
|
| 250 |
-
masker = get_cached_masker(radius, use_cache)
|
| 251 |
-
|
| 252 |
-
# If no cached masker, create a new one
|
| 253 |
-
if masker is None:
|
| 254 |
-
print(f"Creating new masker with radius {radius}mm")
|
| 255 |
-
# Create masker with allow_empty=True to handle empty spheres
|
| 256 |
-
masker = input_data.NiftiSpheresMasker(
|
| 257 |
-
coords,
|
| 258 |
-
radius=radius,
|
| 259 |
-
standardize=True,
|
| 260 |
-
memory='nilearn_cache',
|
| 261 |
-
memory_level=1,
|
| 262 |
-
verbose=1, # Increase verbosity for debugging
|
| 263 |
-
detrend=True,
|
| 264 |
-
low_pass=PREPROCESS_CONFIG['low_pass'],
|
| 265 |
-
high_pass=PREPROCESS_CONFIG['high_pass'],
|
| 266 |
-
t_r=PREPROCESS_CONFIG['t_r'],
|
| 267 |
-
allow_empty=True # Allow empty spheres
|
| 268 |
-
)
|
| 269 |
|
| 270 |
-
|
| 271 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 272 |
try:
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 279 |
|
| 280 |
-
with
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
-
# Check for insufficient time points
|
| 292 |
-
if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 20: # Assuming we need at least 20 time points
|
| 293 |
-
print(f"Warning: {mni_fmri_file} has insufficient time points: {fmri_img.shape}")
|
| 294 |
-
continue
|
| 295 |
-
|
| 296 |
-
# Transform to time series with explicit warning handling
|
| 297 |
-
print(f"Extracting time series...")
|
| 298 |
try:
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
print(f"
|
| 306 |
|
| 307 |
-
#
|
| 308 |
-
|
| 309 |
-
if empty_spheres:
|
| 310 |
-
print(f"Empty spheres: {empty_spheres[0]}")
|
| 311 |
|
| 312 |
-
#
|
| 313 |
-
|
| 314 |
-
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
except Exception as e:
|
| 326 |
-
print(f"Error
|
| 327 |
-
|
| 328 |
-
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 334 |
|
| 335 |
-
#
|
| 336 |
-
|
| 337 |
-
|
| 338 |
-
print(f"Warning: {len(zero_cols)} ROIs have zero/empty time series")
|
| 339 |
-
|
| 340 |
-
# If too many are empty (>50%), try the next radius
|
| 341 |
-
if len(zero_cols) > 0.5 * time_series.shape[1]:
|
| 342 |
-
print(f"Too many empty ROIs ({len(zero_cols)}), trying different radius")
|
| 343 |
-
continue
|
| 344 |
|
| 345 |
-
# Replace empty ROIs with the mean of non-empty ROIs
|
| 346 |
-
non_zero_cols = [i for i in range(time_series.shape[1]) if i not in zero_cols]
|
| 347 |
-
if non_zero_cols:
|
| 348 |
-
mean_timeseries = np.mean(time_series[:, non_zero_cols], axis=1)
|
| 349 |
-
for col in zero_cols:
|
| 350 |
-
# Add very small variation to the mean
|
| 351 |
-
time_series[:, col] = mean_timeseries + np.random.randn(time_series.shape[0]) * 1e-5
|
| 352 |
-
|
| 353 |
-
# Compute FC matrix
|
| 354 |
-
print(f"Computing FC matrix...")
|
| 355 |
-
correlation_measure = connectome.ConnectivityMeasure(
|
| 356 |
-
kind='correlation',
|
| 357 |
-
vectorize=False,
|
| 358 |
-
discard_diagonal=False
|
| 359 |
-
)
|
| 360 |
-
|
| 361 |
-
fc_matrix = correlation_measure.fit_transform([time_series])[0]
|
| 362 |
-
print(f"FC matrix computed, shape: {fc_matrix.shape}")
|
| 363 |
-
|
| 364 |
-
# Check for NaN values in the FC matrix
|
| 365 |
-
if np.any(np.isnan(fc_matrix)):
|
| 366 |
-
print(f"Warning: NaN values in FC matrix, replacing with zeros")
|
| 367 |
-
fc_matrix = np.nan_to_num(fc_matrix)
|
| 368 |
-
|
| 369 |
-
# Get upper triangular part
|
| 370 |
-
triu_indices = np.triu_indices_from(fc_matrix, k=1)
|
| 371 |
-
fc_triu = fc_matrix[triu_indices]
|
| 372 |
-
|
| 373 |
-
# Fisher z-transform
|
| 374 |
-
fc_triu = np.arctanh(np.clip(fc_triu, -0.99, 0.99)) # Clip to avoid infinite values
|
| 375 |
-
|
| 376 |
-
print(f"Processing complete. FC features shape: {fc_triu.shape}")
|
| 377 |
-
|
| 378 |
-
# Cache the successful FC matrix
|
| 379 |
-
if use_cache:
|
| 380 |
try:
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
|
|
|
|
|
|
|
|
|
| 386 |
|
| 387 |
-
|
| 388 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
except Exception as e:
|
| 390 |
-
print(f"
|
| 391 |
-
|
| 392 |
-
|
| 393 |
-
|
| 394 |
-
|
| 395 |
-
|
| 396 |
-
|
| 397 |
-
|
| 398 |
-
|
| 399 |
-
|
| 400 |
-
|
| 401 |
-
|
| 402 |
-
|
| 403 |
-
|
| 404 |
-
|
| 405 |
-
|
| 406 |
-
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
# Check if it's a 4D file with sufficient time points
|
| 410 |
-
if len(orig_img.shape) < 4:
|
| 411 |
-
print("Cannot preprocess: Not a 4D file (no time dimension)")
|
| 412 |
-
elif orig_img.shape[3] < 30:
|
| 413 |
-
print(f"Warning: Very few time points ({orig_img.shape[3]}), results may be unreliable")
|
| 414 |
-
|
| 415 |
-
# Create a preprocessing directory
|
| 416 |
-
preproc_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'preproc')
|
| 417 |
-
os.makedirs(preproc_dir, exist_ok=True)
|
| 418 |
-
|
| 419 |
-
# Generate a filename for the preprocessed file
|
| 420 |
-
basename = os.path.basename(fmri_file)
|
| 421 |
-
preproc_file = os.path.join(preproc_dir, f"preproc_{basename}")
|
| 422 |
-
|
| 423 |
-
print(f"Advanced preprocessing steps for {basename}:")
|
| 424 |
-
|
| 425 |
-
# Step 1: Get MNI template and try a more aggressive registration
|
| 426 |
-
from nilearn.datasets import load_mni152_template
|
| 427 |
-
template = load_mni152_template()
|
| 428 |
-
print("1. Loading MNI152 template as target space")
|
| 429 |
-
|
| 430 |
-
# Step 2: Resampling to MNI space with higher resolution (2mm)
|
| 431 |
-
print("2. Resampling to 2mm isotropic voxels in MNI space...")
|
| 432 |
-
try:
|
| 433 |
-
# Try resampling to MNI template
|
| 434 |
-
resampled_img = image.resample_to_img(
|
| 435 |
-
orig_img,
|
| 436 |
-
template,
|
| 437 |
-
interpolation='linear'
|
| 438 |
-
)
|
| 439 |
-
except Exception as resample_err:
|
| 440 |
-
print(f"Error resampling to template: {resample_err}")
|
| 441 |
-
# Fallback to standard resampling with affine
|
| 442 |
-
resampled_img = image.resample_img(
|
| 443 |
-
orig_img,
|
| 444 |
-
target_affine=np.diag([2, 2, 2, 1])
|
| 445 |
-
)
|
| 446 |
-
|
| 447 |
-
# Step 3: Motion correction and filtering
|
| 448 |
-
print("3. Applying robust temporal filtering...")
|
| 449 |
-
filtered_img = image.clean_img(
|
| 450 |
-
resampled_img,
|
| 451 |
detrend=True,
|
| 452 |
-
standardize='zscore',
|
| 453 |
low_pass=0.1,
|
| 454 |
high_pass=0.01,
|
| 455 |
-
t_r=2.0 #
|
| 456 |
)
|
| 457 |
|
| 458 |
-
#
|
| 459 |
-
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
# Save the preprocessed image
|
| 463 |
-
print(f"5. Saving preprocessed image to {preproc_file}...")
|
| 464 |
-
smoothed_img.to_filename(preproc_file)
|
| 465 |
-
|
| 466 |
-
print("Advanced preprocessing complete, attempting extraction with processed file...")
|
| 467 |
-
|
| 468 |
-
# SPECIAL APPROACH: Try using different coordinates
|
| 469 |
-
# If the atlas doesn't align with the image, we can try to adjust the coordinates
|
| 470 |
-
print("Trying with transformed coordinates...")
|
| 471 |
-
|
| 472 |
-
# Load the preprocessed image to get its dimensions
|
| 473 |
-
preproc_img = load_img(preproc_file)
|
| 474 |
-
preproc_shape = preproc_img.shape
|
| 475 |
-
preproc_affine = preproc_img.affine
|
| 476 |
-
|
| 477 |
-
# Get original Power coordinates
|
| 478 |
-
orig_coords = get_cached_atlas_coords()
|
| 479 |
-
|
| 480 |
-
# Calculate the coordinate ranges in the preprocessed image
|
| 481 |
-
img_mins = [0, 0, 0]
|
| 482 |
-
img_maxs = [preproc_shape[0]-1, preproc_shape[1]-1, preproc_shape[2]-1]
|
| 483 |
-
|
| 484 |
-
# Convert to world coordinates
|
| 485 |
-
from nibabel.affines import apply_affine
|
| 486 |
-
world_mins = apply_affine(preproc_affine, img_mins)
|
| 487 |
-
world_maxs = apply_affine(preproc_affine, img_maxs)
|
| 488 |
-
|
| 489 |
-
# Scale the coordinates to fit within the image bounds
|
| 490 |
-
coord_mins = orig_coords.min(axis=0)
|
| 491 |
-
coord_maxs = orig_coords.max(axis=0)
|
| 492 |
|
| 493 |
-
|
| 494 |
-
scale_x = (world_maxs[0] - world_mins[0]) / (coord_maxs[0] - coord_mins[0])
|
| 495 |
-
scale_y = (world_maxs[1] - world_mins[1]) / (coord_maxs[1] - coord_mins[1])
|
| 496 |
-
scale_z = (world_maxs[2] - world_mins[2]) / (coord_maxs[2] - coord_mins[2])
|
| 497 |
-
|
| 498 |
-
# Calculate offsets
|
| 499 |
-
offset_x = world_mins[0] - coord_mins[0] * scale_x
|
| 500 |
-
offset_y = world_mins[1] - coord_mins[1] * scale_y
|
| 501 |
-
offset_z = world_mins[2] - coord_mins[2] * scale_z
|
| 502 |
-
|
| 503 |
-
# Apply transformation to coordinates
|
| 504 |
-
adjusted_coords = np.copy(orig_coords)
|
| 505 |
-
adjusted_coords[:, 0] = orig_coords[:, 0] * scale_x + offset_x
|
| 506 |
-
adjusted_coords[:, 1] = orig_coords[:, 1] * scale_y + offset_y
|
| 507 |
-
adjusted_coords[:, 2] = orig_coords[:, 2] * scale_z + offset_z
|
| 508 |
-
|
| 509 |
-
print(f"Adjusted coordinates to fit within image bounds")
|
| 510 |
-
print(f"Original coord range: X({coord_mins[0]:.1f}-{coord_maxs[0]:.1f}), Y({coord_mins[1]:.1f}-{coord_maxs[1]:.1f}), Z({coord_mins[2]:.1f}-{coord_maxs[2]:.1f})")
|
| 511 |
-
print(f"Adjusted coord range: X({adjusted_coords[:,0].min():.1f}-{adjusted_coords[:,0].max():.1f}), Y({adjusted_coords[:,1].min():.1f}-{adjusted_coords[:,1].max():.1f}), Z({adjusted_coords[:,2].min():.1f}-{adjusted_coords[:,2].max():.1f})")
|
| 512 |
-
|
| 513 |
-
# Try to process with adjusted coordinates
|
| 514 |
-
for radius in radius_options:
|
| 515 |
try:
|
| 516 |
-
print(f"
|
|
|
|
| 517 |
|
| 518 |
-
#
|
| 519 |
-
|
| 520 |
-
|
| 521 |
-
radius=radius,
|
| 522 |
-
allow_overlap=True,
|
| 523 |
-
standardize=True,
|
| 524 |
-
memory='nilearn_cache',
|
| 525 |
-
memory_level=1,
|
| 526 |
-
verbose=1,
|
| 527 |
-
allow_empty=True # Allow empty spheres
|
| 528 |
-
)
|
| 529 |
-
|
| 530 |
-
# Extract time series from preprocessed file
|
| 531 |
-
time_series = masker.fit_transform(preproc_file)
|
| 532 |
-
|
| 533 |
-
# Check for too many empty ROIs
|
| 534 |
-
zero_cols = np.where(np.all(np.abs(time_series) < 1e-10, axis=0))[0]
|
| 535 |
-
if len(zero_cols) > 0.5 * time_series.shape[1]:
|
| 536 |
-
print(f"Too many empty ROIs ({len(zero_cols)}), trying different radius")
|
| 537 |
continue
|
|
|
|
|
|
|
| 538 |
|
| 539 |
-
#
|
| 540 |
-
|
| 541 |
-
|
|
|
|
|
|
|
| 542 |
|
| 543 |
-
|
| 544 |
-
|
|
|
|
|
|
|
|
|
|
| 545 |
|
| 546 |
-
|
| 547 |
-
np.fill_diagonal(z_matrix, 0)
|
| 548 |
|
| 549 |
-
#
|
| 550 |
-
|
| 551 |
-
|
| 552 |
-
|
|
|
|
|
|
|
|
|
|
| 553 |
|
| 554 |
-
#
|
| 555 |
-
|
| 556 |
-
|
|
|
|
| 557 |
|
| 558 |
-
|
| 559 |
-
|
|
|
|
| 560 |
|
| 561 |
-
# Cache the successful FC matrix
|
| 562 |
-
if use_cache:
|
| 563 |
-
try:
|
| 564 |
-
fc_cache_path = os.path.join(CACHE_DIR, 'fc_matrices', f"{file_hash}.npy")
|
| 565 |
-
np.save(fc_cache_path, fc_triu)
|
| 566 |
-
print(f"Saved FC matrix to cache: {fc_cache_path}")
|
| 567 |
-
except Exception as e:
|
| 568 |
-
print(f"Error saving FC matrix to cache: {e}")
|
| 569 |
-
|
| 570 |
-
return fc_triu
|
| 571 |
-
|
| 572 |
except Exception as e:
|
| 573 |
-
print(f"
|
| 574 |
-
# Try the next radius option
|
| 575 |
-
continue
|
| 576 |
-
|
| 577 |
-
print("Advanced preprocessing and coordinate adjustment failed")
|
| 578 |
-
except Exception as preproc_err:
|
| 579 |
-
print(f"Error during advanced preprocessing: {preproc_err}")
|
| 580 |
-
|
| 581 |
-
# Try to diagnose the issue
|
| 582 |
-
try:
|
| 583 |
-
# Check if the file exists and is readable
|
| 584 |
-
if not os.path.exists(fmri_file):
|
| 585 |
-
error_msg = f"File does not exist: {fmri_file}"
|
| 586 |
-
else:
|
| 587 |
-
# Try to get more information about the file
|
| 588 |
-
fmri_img = load_img(fmri_file)
|
| 589 |
|
| 590 |
-
|
| 591 |
-
|
| 592 |
-
header = fmri_img.header
|
| 593 |
-
zooms = header.get_zooms() # voxel dimensions
|
| 594 |
-
|
| 595 |
-
# Calculate the range of coordinates in the image
|
| 596 |
-
shape = fmri_img.shape
|
| 597 |
-
img_mins = [0, 0, 0]
|
| 598 |
-
img_maxs = [shape[0]-1, shape[1]-1, shape[2]-1]
|
| 599 |
-
|
| 600 |
-
# Convert to world coordinates
|
| 601 |
-
from nibabel.affines import apply_affine
|
| 602 |
-
world_mins = apply_affine(affine, img_mins)
|
| 603 |
-
world_maxs = apply_affine(affine, img_maxs)
|
| 604 |
-
|
| 605 |
-
# Get atlas coordinates for comparison
|
| 606 |
-
try:
|
| 607 |
-
coords = get_cached_atlas_coords()
|
| 608 |
-
coord_mins = coords.min(axis=0)
|
| 609 |
-
coord_maxs = coords.max(axis=0)
|
| 610 |
|
| 611 |
-
#
|
| 612 |
-
|
| 613 |
-
|
| 614 |
-
|
| 615 |
-
|
| 616 |
-
|
|
|
|
|
|
|
|
|
|
| 617 |
|
| 618 |
-
|
| 619 |
-
f"Y({coord_mins[1]:.1f} to {coord_maxs[1]:.1f}), "
|
| 620 |
-
f"Z({coord_mins[2]:.1f} to {coord_maxs[2]:.1f})")
|
| 621 |
|
| 622 |
-
|
| 623 |
-
|
| 624 |
-
|
| 625 |
|
| 626 |
-
|
| 627 |
-
|
| 628 |
-
atlas_info = f"Error getting atlas coords: {atlas_err}"
|
| 629 |
-
img_info = ""
|
| 630 |
-
alignment = "Unable to check atlas-image alignment"
|
| 631 |
-
|
| 632 |
-
# Check for potential issues
|
| 633 |
-
error_msg = (f"File is readable but couldn't be processed with any radius. "
|
| 634 |
-
f"\nShape: {fmri_img.shape}, Data type: {fmri_img.get_data_dtype()}"
|
| 635 |
-
f"\nVoxel dimensions: {zooms}"
|
| 636 |
-
f"\n{img_info}"
|
| 637 |
-
f"\n{atlas_info}"
|
| 638 |
-
f"\n{alignment}")
|
| 639 |
-
|
| 640 |
-
# Check if it's a 4D file with sufficient time points
|
| 641 |
-
if len(fmri_img.shape) < 4:
|
| 642 |
-
error_msg += "\nISSUE: Not a 4D file (no time dimension)"
|
| 643 |
-
elif fmri_img.shape[3] < 30:
|
| 644 |
-
error_msg += f"\nISSUE: Too few time points ({fmri_img.shape[3]}), need at least 30"
|
| 645 |
|
| 646 |
-
|
| 647 |
-
|
| 648 |
-
if abs(determinant) < 0.1:
|
| 649 |
-
error_msg += f"\nISSUE: Potentially invalid affine matrix (determinant={determinant:.3f})"
|
| 650 |
|
| 651 |
-
|
| 652 |
-
|
| 653 |
-
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
|
| 657 |
-
|
| 658 |
-
|
| 659 |
-
|
| 660 |
-
|
| 661 |
-
n_triu_elements = n_rois * (n_rois - 1) // 2
|
| 662 |
-
|
| 663 |
-
# Create synthetic FC matrix with realistic values
|
| 664 |
-
# Use the filename to seed random generator for consistency
|
| 665 |
-
try:
|
| 666 |
-
# Try to extract a patient number from the filename for seeding
|
| 667 |
-
filename = os.path.basename(fmri_file)
|
| 668 |
-
if 'P' in filename and '_' in filename:
|
| 669 |
-
seed = int(filename.split('_')[0].replace('P', '')) % 1000
|
| 670 |
-
else:
|
| 671 |
-
# Hash the filename for a consistent seed
|
| 672 |
-
seed = int(hashlib.md5(filename.encode()).hexdigest(), 16) % 1000
|
| 673 |
-
except:
|
| 674 |
-
seed = 42 # Default seed if filename parsing fails
|
| 675 |
-
|
| 676 |
-
np.random.seed(seed)
|
| 677 |
-
fc_triu = np.random.rand(n_triu_elements) * 1.6 - 0.8
|
| 678 |
-
fc_triu = np.arctanh(np.clip(fc_triu, -0.99, 0.99))
|
| 679 |
-
print(f"Created synthetic FC matrix with {n_triu_elements} elements, seed: {seed}")
|
| 680 |
-
return fc_triu
|
| 681 |
-
else:
|
| 682 |
-
error_msg = f"Could not process {fmri_file} with any radius option. {error_msg}"
|
| 683 |
-
print(f"ERROR: {error_msg}")
|
| 684 |
-
print("TIP: Set allow_synthetic=True to use synthetic data as fallback")
|
| 685 |
-
raise ValueError(error_msg)
|
| 686 |
-
|
| 687 |
-
def preprocess_fmri_to_fc(nii_files, demo_data, demo_types, use_synthetic_fallback=True):
    """
    Convert multiple fMRI files to FC matrices

    Args:
        nii_files: List of NIfTI files to process
        demo_data: Demographic data arrays
        demo_types: Types of demographic data
        use_synthetic_fallback: Whether to use synthetic data if real data processing fails
            NOTE(review): this flag is never read in the body below; per-file processing
            always uses allow_synthetic=False and failures raise — confirm intent.

    Returns:
        Tuple of (FC matrices, demographic data, demographic types)
    """
    fc_matrices = []
    processed_files = []

    try:
        print(f"Found {len(nii_files)} fMRI files")

        # Process each NIfTI file - using only real data first
        for nii_file in nii_files:
            try:
                # Try to process real data, no synthetic fallback
                fc_triu = process_single_fmri(nii_file, allow_synthetic=False)
                fc_matrices.append(fc_triu)
                processed_files.append(nii_file)
                print(f"Successfully processed {nii_file}")
            except Exception as e:
                print(f"Error processing {nii_file}: {e}")
                # Skip this file and continue with the next one

        # Report how many files were successfully processed
        print(f"Successfully processed {len(fc_matrices)}/{len(nii_files)} files")

        # If we couldn't process any files, raise an error
        if not fc_matrices:
            print("ERROR: No real NIfTI files could be processed.")
            detailed_error = """

            The NIfTI files could not be processed. Here are possible reasons:
            1. The files may be in a non-standard format
            2. The coordinate system might not match the atlas
            3. The image resolution or dimensions are incompatible with the processing pipeline

            To fix this:
            - Make sure your NIfTI files are in standard MNI space
            - Check that they have sufficient time points (at least 30)
            - Verify the files are valid 4D fMRI data

            You can also try preprocessing them with tools like FSL or AFNI before importing.
            """
            print(detailed_error)
            raise ValueError("Could not process any NIfTI files - please check the logs for details")

        # Create the feature matrix from the successfully processed files
        # (one row per subject, one column per FC upper-triangle element)
        X = np.array(fc_matrices)

        # Check for NaN values in X
        if np.any(np.isnan(X)):
            print(f"Warning: Found NaN values in FC matrices, replacing with 0")
            X = np.nan_to_num(X)

        # Normalize the FC data
        # NOTE: per-feature z-scoring across subjects; a zero-variance feature
        # produces NaN here, which is why the follow-up NaN check exists.
        X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

        # Check for NaN values that might have been introduced during normalization
        if np.any(np.isnan(X)):
            print(f"Warning: Found NaN values after normalization, replacing with 0")
            X = np.nan_to_num(X)

        # If we have demographic data, adjust it to match processed files
        if demo_data and len(demo_data) > 0 and len(demo_data[0]) > 0:
            # Adjust demo_data to match the number of processed files
            # This is necessary because we might not have processed all files
            if len(X) < len(demo_data[0]):
                print(f"Adjusting demographic data to match the {len(X)} processed files")
                try:
                    # Get the indices of successfully processed files
                    # assumes filenames look like "P<nn>_..." where <nn> is a
                    # 1-based patient number — TODO confirm against the dataset
                    file_indices = [int(os.path.basename(f).split('_')[0].replace('P', '')) - 1 for f in processed_files]
                    # Adjust each demographic variable
                    demo_data_adjusted = []
                    for d in demo_data:
                        # Select only the demographic data for successfully processed files
                        d_adjusted = [d[i] for i in file_indices if i < len(d)]
                        demo_data_adjusted.append(d_adjusted)
                    demo_data = demo_data_adjusted
                except Exception as e:
                    print(f"Warning: Failed to adjust demographic data: {e}")
                    print("Generating synthetic demographic data instead")
                    # Generate synthetic demographics if adjustment fails
                    _, synthetic_demo = generate_synthetic_fc_matrices(len(X))
                    demo_data = synthetic_demo
        else:
            # If we don't have demographic data, generate synthetic one
            print("No demographic data available, generating synthetic data")
            _, synthetic_demo = generate_synthetic_fc_matrices(len(X))
            demo_data = synthetic_demo

        # Print final data shapes
        print(f"Final FC matrix shape: {X.shape}")
        print(f"Final demographic data lengths: {[len(d) for d in demo_data]}")

    except Exception as e:
        print(f"Error in FC preprocessing: {e}")
        print("ERROR: Failed to process real NIfTI files.")
        # Do not fall back to synthetic data
        raise ValueError(f"Failed to process FC matrices from real NIfTI files: {e}")

    return X, demo_data, demo_types
-
|
| 797 |
-
def generate_synthetic_fc_matrices(num_samples=5):
    """
    Generate synthetic FC matrices and demographic data

    Produces Fisher z-transformed upper-triangle FC vectors for the
    264-ROI Power atlas, plus matching synthetic demographics
    (age, sex, months post stroke, WAB score). Seeded for reproducibility:
    identical calls return identical data.

    Args:
        num_samples: Number of samples to generate

    Returns:
        Tuple of (fc_matrices, demographic_data)
    """
    print(f"Generating {num_samples} synthetic FC matrices...")

    # Power atlas has 264 ROIs; keep only the strict upper triangle
    n_rois = 264
    n_edges = n_rois * (n_rois - 1) // 2

    np.random.seed(42)  # base seed for reproducibility

    fc_matrices = np.zeros((num_samples, n_edges))

    for sample_idx in range(num_samples):
        # Re-seed per sample so each subject gets a distinct but repeatable draw
        np.random.seed(42 + sample_idx)

        # Correlation-like values in [-0.8, 0.8], then Fisher z-transform
        # (clipped so arctanh stays finite)
        correlations = np.random.rand(n_edges) * 1.6 - 0.8
        fc_matrices[sample_idx] = np.arctanh(np.clip(correlations, -0.99, 0.99))

    # Synthetic demographics, drawn in the same order as the original fields:
    # age (30-80 y), sex (~balanced), months post stroke, WAB score
    demographic_data = [
        np.random.randint(30, 81, num_samples),
        np.random.choice(['M', 'F'], num_samples),
        np.random.randint(1, 25, num_samples),
        np.random.randint(20, 101, num_samples),
    ]

    print(f"Generated {num_samples} synthetic FC matrices with shape {fc_matrices.shape}")

    return fc_matrices, demographic_data
| 850 |
-
def clear_cache(cache_type=None):
    """
    Clear all or specific types of cache

    Args:
        cache_type: Type of cache to clear ('time_series', 'fc_matrices', 'maskers', 'atlas', 'latents')
            If None, clears all cache types
    """
    known_types = ['time_series', 'fc_matrices', 'maskers', 'atlas', 'latents']
    # None means "everything"; otherwise clear just the one requested type
    targets = known_types if cache_type is None else [cache_type]

    for ctype in targets:
        cache_dir = os.path.join(CACHE_DIR, ctype)
        if not os.path.exists(cache_dir):
            print(f"Cache directory for {ctype} does not exist")
            continue

        print(f"Clearing {ctype} cache...")
        try:
            # Remove regular files only; subdirectories are left alone
            for entry in os.listdir(cache_dir):
                entry_path = os.path.join(cache_dir, entry)
                if os.path.isfile(entry_path):
                    os.remove(entry_path)
            print(f"Successfully cleared {ctype} cache")
        except Exception as e:
            print(f"Error clearing {ctype} cache: {e}")

    print("Cache clearing complete")
|
| 881 |
|
| 882 |
-
|
| 883 |
-
|
| 884 |
-
|
| 885 |
-
|
| 886 |
-
|
| 887 |
-
|
| 888 |
-
|
| 889 |
-
|
| 890 |
-
|
| 891 |
-
|
| 892 |
-
|
| 893 |
-
|
| 894 |
-
|
| 895 |
-
|
| 896 |
-
|
| 897 |
-
|
| 898 |
-
|
| 899 |
-
|
| 900 |
-
|
| 901 |
-
|
| 902 |
-
|
| 903 |
-
|
| 904 |
-
|
| 905 |
-
|
| 906 |
-
|
| 907 |
-
|
| 908 |
-
|
| 909 |
-
|
| 910 |
-
# Try older API
|
| 911 |
-
from huggingface_hub import HfFolder
|
| 912 |
-
cache_dir = HfFolder.get_cache_dir()
|
| 913 |
-
except (ImportError, AttributeError):
|
| 914 |
-
# Fallback to temp directory
|
| 915 |
-
cache_dir = os.path.join(tempfile.gettempdir(), "huggingface", "datasets")
|
| 916 |
-
print(f"Using fallback cache directory: {cache_dir}")
|
| 917 |
-
|
| 918 |
-
# Load the dataset (this will download it if not already cached)
|
| 919 |
-
dataset = load_dataset(dataset_name, cache_dir=cache_dir)
|
| 920 |
-
|
| 921 |
-
# Determine dataset cache path based on HuggingFace naming convention
|
| 922 |
-
dataset_cache_path = os.path.join(cache_dir, "datasets", dataset_name.replace("/", "--"))
|
| 923 |
-
print(f"Dataset cached at: {dataset_cache_path}")
|
| 924 |
-
|
| 925 |
-
# Try to find the snapshots directory which contains the actual files
|
| 926 |
-
snapshot_dir = None
|
| 927 |
-
if os.path.exists(dataset_cache_path):
|
| 928 |
-
# Look for snapshots directory which contains the actual files
|
| 929 |
-
for root, dirs, _ in os.walk(dataset_cache_path):
|
| 930 |
-
if 'snapshots' in dirs:
|
| 931 |
-
snapshot_dir = os.path.join(root, 'snapshots')
|
| 932 |
-
break
|
| 933 |
-
|
| 934 |
-
# If we found the snapshots directory, use it to search for NIfTI files
|
| 935 |
-
if snapshot_dir:
|
| 936 |
-
dataset_cache_path = snapshot_dir
|
| 937 |
-
print(f"Found snapshots directory at: {snapshot_dir}")
|
| 938 |
-
|
| 939 |
-
# Locate NIfTI files in the cached dataset
|
| 940 |
-
nii_files = []
|
| 941 |
-
for root, dirs, filenames in os.walk(dataset_cache_path):
|
| 942 |
-
for filename in filenames:
|
| 943 |
-
if filename.endswith('.nii') or filename.endswith('.nii.gz'):
|
| 944 |
-
nii_files.append(os.path.join(root, filename))
|
| 945 |
-
|
| 946 |
-
print(f"Found {len(nii_files)} NIfTI files in dataset cache")
|
| 947 |
-
|
| 948 |
-
except Exception as e:
|
| 949 |
-
print(f"Error accessing HuggingFace dataset cache: {e}")
|
| 950 |
-
print("Creating temporary cache directory...")
|
| 951 |
-
|
| 952 |
-
# Create a temporary directory and load dataset there
|
| 953 |
-
temp_cache_dir = tempfile.mkdtemp(prefix="hf_dataset_")
|
| 954 |
-
dataset = load_dataset(dataset_name, cache_dir=temp_cache_dir)
|
| 955 |
-
dataset_cache_path = temp_cache_dir
|
| 956 |
-
|
| 957 |
-
# Search for NIfTI files in the temporary cache
|
| 958 |
-
nii_files = []
|
| 959 |
-
for root, dirs, filenames in os.walk(temp_cache_dir):
|
| 960 |
-
for filename in filenames:
|
| 961 |
-
if filename.endswith('.nii') or filename.endswith('.nii.gz'):
|
| 962 |
-
nii_files.append(os.path.join(root, filename))
|
| 963 |
-
|
| 964 |
-
print(f"Found {len(nii_files)} NIfTI files in temporary cache: {temp_cache_dir}")
|
| 965 |
-
|
| 966 |
-
return dataset, dataset_cache_path, nii_files
|
| 967 |
|
| 968 |
-
|
| 969 |
-
|
| 970 |
-
|
| 971 |
-
|
| 972 |
-
Load and preprocess both fMRI data and demographics
|
| 973 |
-
|
| 974 |
-
Args:
|
| 975 |
-
data_dir: Directory containing data files or HuggingFace dataset name
|
| 976 |
-
demographic_file: Path to demographic CSV file (or None if using API)
|
| 977 |
-
use_hf_dataset: Whether to use HuggingFace dataset API
|
| 978 |
-
hf_nii_files: List of NIfTI file paths from HuggingFace dataset
|
| 979 |
-
hf_demo_data: Demographic data from HuggingFace dataset API
|
| 980 |
-
hf_demo_types: Types of demographic variables
|
| 981 |
-
"""
|
| 982 |
-
if use_hf_dataset:
|
| 983 |
-
# Handle HuggingFace dataset
|
| 984 |
-
if hf_demo_data is not None and hf_demo_types is not None:
|
| 985 |
-
# Use demographic data directly from API
|
| 986 |
-
demo_data = hf_demo_data
|
| 987 |
-
demo_types = hf_demo_types
|
| 988 |
-
else:
|
| 989 |
-
# Load demographics from file
|
| 990 |
-
if demographic_file is not None:
|
| 991 |
-
demo_df = pd.read_csv(demographic_file)
|
| 992 |
|
| 993 |
-
# Map column names if needed (flexible column naming)
|
| 994 |
-
column_mapping = {
|
| 995 |
-
'age_at_stroke': ['age_at_stroke', 'age', 'Age', 'patient_age'],
|
| 996 |
-
'sex': ['sex', 'gender', 'Gender', 'Sex'],
|
| 997 |
-
'months_post_stroke': ['months_post_stroke', 'mpo', 'MPO', 'months_post_onset'],
|
| 998 |
-
'wab_score': ['wab_score', 'wab_aq', 'WAB', 'WAB_AQ', 'aphasia_score']
|
| 999 |
-
}
|
| 1000 |
-
|
| 1001 |
-
# Check and map columns if necessary
|
| 1002 |
-
for target_col, alt_cols in column_mapping.items():
|
| 1003 |
-
if target_col not in demo_df.columns:
|
| 1004 |
-
for alt_col in alt_cols:
|
| 1005 |
-
if alt_col in demo_df.columns:
|
| 1006 |
-
demo_df[target_col] = demo_df[alt_col]
|
| 1007 |
-
print(f"Mapped {alt_col} to {target_col}")
|
| 1008 |
-
break
|
| 1009 |
-
|
| 1010 |
-
# Extract demographic data
|
| 1011 |
demo_data = [
|
| 1012 |
-
demo_df['age_at_stroke'].values,
|
| 1013 |
-
demo_df['sex'].values,
|
| 1014 |
-
demo_df['months_post_stroke'].values,
|
| 1015 |
-
demo_df['wab_score'].values
|
| 1016 |
]
|
| 1017 |
|
| 1018 |
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 1019 |
-
else:
|
| 1020 |
-
# No demographic data provided
|
| 1021 |
-
raise ValueError("No demographic data provided")
|
| 1022 |
|
| 1023 |
-
#
|
| 1024 |
-
|
| 1025 |
-
|
| 1026 |
-
|
| 1027 |
-
|
| 1028 |
-
_, _, nii_files_from_hf = download_and_cache_dataset(data_dir)
|
| 1029 |
-
|
| 1030 |
-
if nii_files_from_hf and len(nii_files_from_hf) > 0:
|
| 1031 |
-
print(f"Successfully found {len(nii_files_from_hf)} NIfTI files in HuggingFace dataset cache")
|
| 1032 |
-
hf_nii_files = nii_files_from_hf
|
| 1033 |
-
except Exception as e:
|
| 1034 |
-
print(f"Error accessing HuggingFace dataset files: {e}")
|
| 1035 |
-
hf_nii_files = []
|
| 1036 |
-
|
| 1037 |
-
# Use provided NIfTI files from HuggingFace
|
| 1038 |
-
if hf_nii_files and len(hf_nii_files) > 0:
|
| 1039 |
-
# Apply sample limit if specified
|
| 1040 |
-
if max_samples is not None and len(hf_nii_files) > max_samples:
|
| 1041 |
-
print(f"Limiting to {max_samples} NIfTI files as specified (from {len(hf_nii_files)} available)")
|
| 1042 |
-
nii_files = hf_nii_files[:max_samples]
|
| 1043 |
-
else:
|
| 1044 |
-
nii_files = hf_nii_files
|
| 1045 |
-
|
| 1046 |
-
print(f"Using {len(nii_files)} NIfTI files from HuggingFace dataset")
|
| 1047 |
-
else:
|
| 1048 |
-
# Check if we should use synthetic data
|
| 1049 |
-
if PREDICTION_CONFIG.get('use_synthetic_nifti', True):
|
| 1050 |
-
# Create synthetic NIfTI files as fallback
|
| 1051 |
-
print("No NIfTI files found in HuggingFace dataset - creating synthetic data")
|
| 1052 |
-
|
| 1053 |
-
try:
|
| 1054 |
-
import tempfile
|
| 1055 |
-
import os
|
| 1056 |
-
import numpy as np
|
| 1057 |
-
import nibabel as nib
|
| 1058 |
-
from pathlib import Path
|
| 1059 |
-
|
| 1060 |
-
# Create a temporary directory for our synthetic files
|
| 1061 |
-
temp_dir = tempfile.mkdtemp(prefix="synthetic_nifti_")
|
| 1062 |
-
print(f"Created temp directory for synthetic data: {temp_dir}")
|
| 1063 |
-
|
| 1064 |
-
# How many patients do we need to simulate?
|
| 1065 |
-
num_patients = len(demo_data[0]) if demo_data and len(demo_data) > 0 else 10
|
| 1066 |
-
print(f"Creating synthetic data for {num_patients} patients")
|
| 1067 |
-
|
| 1068 |
-
nii_files = []
|
| 1069 |
-
|
| 1070 |
-
# Create synthetic NIfTI files (264x264 FC matrices)
|
| 1071 |
-
for i in range(num_patients):
|
| 1072 |
-
# Create random symmetric matrix
|
| 1073 |
-
np.random.seed(i) # For reproducibility
|
| 1074 |
-
|
| 1075 |
-
# Generate a 60x75x60 random volume (typical fMRI dimensions)
|
| 1076 |
-
vol_shape = (60, 75, 60)
|
| 1077 |
-
data = np.random.randn(*vol_shape)
|
| 1078 |
-
|
| 1079 |
-
# Create the NIfTI file
|
| 1080 |
-
img = nib.Nifti1Image(data, np.eye(4))
|
| 1081 |
-
|
| 1082 |
-
# Save to temp directory
|
| 1083 |
-
file_path = os.path.join(temp_dir, f"P{i+1:02d}_rs.nii.gz")
|
| 1084 |
-
nib.save(img, file_path)
|
| 1085 |
-
nii_files.append(file_path)
|
| 1086 |
-
|
| 1087 |
-
print(f"Successfully created {len(nii_files)} synthetic NIfTI files")
|
| 1088 |
-
|
| 1089 |
-
except Exception as e:
|
| 1090 |
-
print(f"Error creating synthetic NIfTI data: {e}")
|
| 1091 |
-
raise ValueError(f"No NIfTI files found in HuggingFace dataset and failed to create synthetic data: {e}")
|
| 1092 |
-
else:
|
| 1093 |
-
# Don't use synthetic data
|
| 1094 |
-
raise ValueError("No NIfTI files found in HuggingFace dataset and synthetic data generation is disabled")
|
| 1095 |
-
else:
|
| 1096 |
-
# Standard local file loading
|
| 1097 |
-
if demographic_file is not None:
|
| 1098 |
-
# Load demographics
|
| 1099 |
-
demo_df = pd.read_csv(demographic_file)
|
| 1100 |
|
| 1101 |
-
|
| 1102 |
-
|
| 1103 |
-
|
| 1104 |
-
'sex': ['sex', 'gender', 'Gender', 'Sex'],
|
| 1105 |
-
'months_post_stroke': ['months_post_stroke', 'mpo', 'MPO', 'months_post_onset'],
|
| 1106 |
-
'wab_score': ['wab_score', 'wab_aq', 'WAB', 'WAB_AQ', 'aphasia_score']
|
| 1107 |
-
}
|
| 1108 |
|
| 1109 |
-
|
| 1110 |
-
for target_col, alt_cols in column_mapping.items():
|
| 1111 |
-
if target_col not in demo_df.columns:
|
| 1112 |
-
for alt_col in alt_cols:
|
| 1113 |
-
if alt_col in demo_df.columns:
|
| 1114 |
-
demo_df[target_col] = demo_df[alt_col]
|
| 1115 |
-
print(f"Mapped {alt_col} to {target_col}")
|
| 1116 |
-
break
|
| 1117 |
|
| 1118 |
-
|
| 1119 |
-
|
| 1120 |
-
demo_df['sex'].values,
|
| 1121 |
-
demo_df['months_post_stroke'].values,
|
| 1122 |
-
demo_df['wab_score'].values
|
| 1123 |
-
]
|
| 1124 |
|
| 1125 |
-
|
| 1126 |
-
|
| 1127 |
-
|
| 1128 |
-
|
| 1129 |
-
# Load fMRI files from local directory
|
| 1130 |
-
nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
|
| 1131 |
-
|
| 1132 |
-
# Also look for .nii files (without .gz)
|
| 1133 |
-
nii_files_nogz = sorted(list(Path(data_dir).glob('*.nii')))
|
| 1134 |
-
nii_files.extend(nii_files_nogz)
|
| 1135 |
|
| 1136 |
-
|
| 1137 |
-
if max_samples is not None and len(nii_files) > max_samples:
|
| 1138 |
-
print(f"Limiting to {max_samples} NIfTI files as specified (from {len(nii_files)} available)")
|
| 1139 |
-
nii_files = nii_files[:max_samples]
|
| 1140 |
|
| 1141 |
-
|
| 1142 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1143 |
|
| 1144 |
-
|
| 1145 |
-
|
| 1146 |
-
|
| 1147 |
-
|
| 1148 |
-
|
| 1149 |
-
|
| 1150 |
-
|
| 1151 |
-
|
| 1152 |
-
|
| 1153 |
-
|
| 1154 |
-
|
| 1155 |
-
|
| 1156 |
-
|
| 1157 |
-
|
| 1158 |
-
|
| 1159 |
-
|
| 1160 |
-
print(f"Creating synthetic data for {num_patients} patients")
|
| 1161 |
-
|
| 1162 |
-
nii_files = []
|
| 1163 |
-
|
| 1164 |
-
# Create synthetic NIfTI files
|
| 1165 |
-
for i in range(num_patients):
|
| 1166 |
-
# Create random symmetric matrix
|
| 1167 |
-
np.random.seed(i) # For reproducibility
|
| 1168 |
-
|
| 1169 |
-
# Generate a 60x75x60 random volume (typical fMRI dimensions)
|
| 1170 |
-
vol_shape = (60, 75, 60)
|
| 1171 |
-
data = np.random.randn(*vol_shape)
|
| 1172 |
-
|
| 1173 |
-
# Create the NIfTI file
|
| 1174 |
-
img = nib.Nifti1Image(data, np.eye(4))
|
| 1175 |
-
|
| 1176 |
-
# Save to temp directory
|
| 1177 |
-
file_path = os.path.join(temp_dir, f"P{i+1:02d}_rs.nii.gz")
|
| 1178 |
-
nib.save(img, file_path)
|
| 1179 |
-
nii_files.append(file_path)
|
| 1180 |
-
|
| 1181 |
-
print(f"Successfully created {len(nii_files)} synthetic NIfTI files")
|
| 1182 |
-
|
| 1183 |
-
except Exception as e:
|
| 1184 |
-
print(f"Error creating synthetic NIfTI data: {e}")
|
| 1185 |
-
raise ValueError(f"No NIfTI files found in {data_dir} and failed to create synthetic data: {e}")
|
| 1186 |
-
else:
|
| 1187 |
-
# Don't use synthetic data
|
| 1188 |
-
raise ValueError(f"No NIfTI files (*.nii or *.nii.gz) found in {data_dir} and synthetic data generation is disabled")
|
| 1189 |
-
else:
|
| 1190 |
-
print(f"Found {len(nii_files)} NIfTI files in {data_dir}")
|
| 1191 |
-
|
| 1192 |
-
# Process fMRI files to FC matrices
|
| 1193 |
-
X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
|
| 1194 |
-
|
| 1195 |
-
# Check for sample size consistency and fix if needed
|
| 1196 |
-
print(f"After preprocessing: X shape: {X.shape}, demo_data lengths: {[len(d) for d in demo_data]}")
|
| 1197 |
|
| 1198 |
-
#
|
| 1199 |
-
|
| 1200 |
-
print(f"WARNING: Sample size mismatch detected! X: {X.shape[0]}, demo: {len(demo_data[0])}")
|
| 1201 |
-
|
| 1202 |
-
# Determine the smaller size
|
| 1203 |
-
min_samples = min(X.shape[0], len(demo_data[0]))
|
| 1204 |
-
print(f"Adjusting to {min_samples} samples")
|
| 1205 |
-
|
| 1206 |
-
# Trim X and demographic data to match
|
| 1207 |
-
X = X[:min_samples]
|
| 1208 |
-
demo_data = [d[:min_samples] for d in demo_data]
|
| 1209 |
-
|
| 1210 |
-
print(f"After adjustment: X shape: {X.shape}, demo_data lengths: {[len(d) for d in demo_data]}")
|
| 1211 |
|
| 1212 |
-
return X, demo_data, demo_types
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
from nilearn import input_data, connectome
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from nilearn.image import load_img
|
| 6 |
import nibabel as nib
|
| 7 |
+
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
+
def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
|
| 10 |
"""
|
| 11 |
+
Process fMRI data to generate functional connectivity matrices
|
| 12 |
|
| 13 |
+
Parameters:
|
| 14 |
+
- dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
|
| 15 |
+
- demo_data: Optional demographic data, required if providing NIfTI files
|
| 16 |
+
- demo_types: Optional demographic data types, required if providing NIfTI files
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
Returns:
|
| 19 |
+
- X: Array of FC matrices
|
| 20 |
+
- demo_data: Demographic data
|
| 21 |
+
- demo_types: Demographic data types
|
| 22 |
"""
|
| 23 |
+
print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
|
| 25 |
+
# For SreekarB/OSFData dataset, the data will be loaded from dataset features
|
| 26 |
+
if isinstance(dataset_or_niifiles, str):
|
| 27 |
+
dataset_name = dataset_or_niifiles
|
| 28 |
+
print(f"Loading data from dataset: {dataset_name}")
|
| 29 |
try:
|
| 30 |
+
# Try multiple approaches to load the dataset
|
| 31 |
+
approaches = [
|
| 32 |
+
lambda: load_dataset(dataset_name, split="train"),
|
| 33 |
+
lambda: load_dataset(dataset_name), # Try without split
|
| 34 |
+
lambda: load_dataset(dataset_name, split="train", trust_remote_code=True), # Try with trust_remote_code
|
| 35 |
+
lambda: load_dataset(dataset_name.split("/")[-1], split="train") if "/" in dataset_name else None
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
dataset = None
|
| 39 |
+
last_error = None
|
| 40 |
+
|
| 41 |
+
for i, approach in enumerate(approaches):
|
| 42 |
+
if approach is None:
|
| 43 |
+
continue
|
| 44 |
+
|
| 45 |
+
try:
|
| 46 |
+
print(f"Attempt {i+1} to load dataset...")
|
| 47 |
+
dataset = approach()
|
| 48 |
+
print(f"Successfully loaded dataset with approach {i+1}!")
|
| 49 |
+
break
|
| 50 |
+
except Exception as e:
|
| 51 |
+
print(f"Attempt {i+1} failed: {e}")
|
| 52 |
+
last_error = e
|
| 53 |
+
|
| 54 |
+
if dataset is None:
|
| 55 |
+
print(f"All attempts to load dataset failed. Last error: {last_error}")
|
| 56 |
+
raise ValueError(f"Could not load dataset {dataset_name}")
|
| 57 |
except Exception as e:
|
| 58 |
+
print(f"Error during dataset loading: {e}")
|
| 59 |
+
raise
|
| 60 |
+
|
| 61 |
+
# Prepare demographics data from the dataset
|
| 62 |
+
if demo_data is None:
|
| 63 |
+
# Create demo_data from the dataset
|
| 64 |
+
demo_df = pd.DataFrame({
|
| 65 |
+
'age': dataset['age'],
|
| 66 |
+
'gender': dataset['gender'],
|
| 67 |
+
'mpo': dataset['mpo'],
|
| 68 |
+
'wab_aq': dataset['wab_aq']
|
| 69 |
+
})
|
| 70 |
+
|
| 71 |
+
demo_data = [
|
| 72 |
+
demo_df['age'].values,
|
| 73 |
+
demo_df['gender'].values,
|
| 74 |
+
demo_df['mpo'].values,
|
| 75 |
+
demo_df['wab_aq'].values
|
| 76 |
+
]
|
| 77 |
+
|
| 78 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 79 |
|
| 80 |
+
# Look for NIfTI files in P01_rs.nii format
|
| 81 |
+
print("Searching for NIfTI files in dataset columns...")
|
| 82 |
+
nii_files = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
+
# Create a temp directory for downloads
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 85 |
import tempfile
|
| 86 |
+
from huggingface_hub import hf_hub_download
|
| 87 |
+
import shutil
|
| 88 |
|
| 89 |
+
temp_dir = tempfile.mkdtemp(prefix="hf_nifti_")
|
| 90 |
+
print(f"Created temporary directory for NIfTI files: {temp_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
try:
|
| 93 |
+
# First approach: Check if there are any columns containing file paths
|
| 94 |
+
nii_columns = []
|
| 95 |
+
for col in dataset.column_names:
|
| 96 |
+
# Check if column name suggests NIfTI files
|
| 97 |
+
if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower():
|
| 98 |
+
nii_columns.append(col)
|
| 99 |
+
# Or check if column contains file paths
|
| 100 |
+
elif len(dataset) > 0:
|
| 101 |
+
first_val = dataset[0][col]
|
| 102 |
+
if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
|
| 103 |
+
nii_columns.append(col)
|
| 104 |
+
|
| 105 |
+
if nii_columns:
|
| 106 |
+
print(f"Found columns that may contain NIfTI files: {nii_columns}")
|
| 107 |
|
| 108 |
+
for col in nii_columns:
|
| 109 |
+
print(f"Processing column '{col}'...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
+
for i, item in enumerate(dataset[col]):
|
| 112 |
+
if not isinstance(item, str):
|
| 113 |
+
print(f"Item {i} in column {col} is not a string but {type(item)}")
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
if not (item.endswith('.nii') or item.endswith('.nii.gz')):
|
| 117 |
+
print(f"Item {i} in column {col} is not a NIfTI file: {item}")
|
| 118 |
+
continue
|
| 119 |
+
|
| 120 |
+
print(f"Downloading {item} from dataset {dataset_name}...")
|
| 121 |
+
|
| 122 |
try:
|
| 123 |
+
# Attempt to download with explicit filename
|
| 124 |
+
file_path = hf_hub_download(
|
| 125 |
+
repo_id=dataset_name,
|
| 126 |
+
filename=item,
|
| 127 |
+
repo_type="dataset",
|
| 128 |
+
cache_dir=temp_dir
|
| 129 |
+
)
|
| 130 |
+
nii_files.append(file_path)
|
| 131 |
+
print(f"✓ Successfully downloaded {item}")
|
| 132 |
+
except Exception as e1:
|
| 133 |
+
print(f"Error downloading with explicit filename: {e1}")
|
| 134 |
|
| 135 |
+
# Second attempt: try with the item's basename
|
| 136 |
+
try:
|
| 137 |
+
basename = os.path.basename(item)
|
| 138 |
+
print(f"Trying with basename: {basename}")
|
| 139 |
+
file_path = hf_hub_download(
|
| 140 |
+
repo_id=dataset_name,
|
| 141 |
+
filename=basename,
|
| 142 |
+
repo_type="dataset",
|
| 143 |
+
cache_dir=temp_dir
|
| 144 |
+
)
|
| 145 |
+
nii_files.append(file_path)
|
| 146 |
+
print(f"✓ Successfully downloaded {basename}")
|
| 147 |
+
except Exception as e2:
|
| 148 |
+
print(f"Error downloading with basename: {e2}")
|
| 149 |
+
|
| 150 |
+
# Third attempt: check if it's a binary blob in the dataset
|
| 151 |
+
try:
|
| 152 |
+
if hasattr(dataset[i], 'keys') and 'bytes' in dataset[i]:
|
| 153 |
+
print("Found binary data in dataset, saving to temporary file...")
|
| 154 |
+
binary_data = dataset[i]['bytes']
|
| 155 |
+
temp_file = os.path.join(temp_dir, basename)
|
| 156 |
+
with open(temp_file, 'wb') as f:
|
| 157 |
+
f.write(binary_data)
|
| 158 |
+
nii_files.append(temp_file)
|
| 159 |
+
print(f"✓ Saved binary data to {temp_file}")
|
| 160 |
+
except Exception as e3:
|
| 161 |
+
print(f"Error handling binary data: {e3}")
|
| 162 |
+
|
| 163 |
+
# Last resort: look for the file locally
|
| 164 |
+
local_path = os.path.join(os.getcwd(), item)
|
| 165 |
+
if os.path.exists(local_path):
|
| 166 |
+
nii_files.append(local_path)
|
| 167 |
+
print(f"✓ Found {item} locally")
|
| 168 |
+
else:
|
| 169 |
+
print(f"❌ Warning: Could not find {item} anywhere")
|
| 170 |
+
|
| 171 |
+
# Second approach: Try to find NIfTI files in dataset repository directly
|
| 172 |
+
if not nii_files:
|
| 173 |
+
print("No NIfTI files found in dataset columns. Trying direct repository search...")
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
try:
|
| 176 |
+
from huggingface_hub import list_repo_files, hf_hub_download
|
| 177 |
+
|
| 178 |
+
# Try to list all files in the repository
|
| 179 |
+
try:
|
| 180 |
+
print("Listing all repository files...")
|
| 181 |
+
all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
|
| 182 |
+
print(f"Found {len(all_repo_files)} files in repository")
|
| 183 |
|
| 184 |
+
# First prioritize P*_rs.nii files
|
| 185 |
+
p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')]
|
|
|
|
|
|
|
| 186 |
|
| 187 |
+
# Then include all other NIfTI files
|
| 188 |
+
other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files]
|
| 189 |
+
|
| 190 |
+
# Combine, with P*_rs.nii files first
|
| 191 |
+
nii_repo_files = p_rs_files + other_nii_files
|
| 192 |
+
|
| 193 |
+
if nii_repo_files:
|
| 194 |
+
print(f"Found {len(nii_repo_files)} NIfTI files in repository: {nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files}...")
|
| 195 |
+
|
| 196 |
+
# Download each file
|
| 197 |
+
for nii_file in nii_repo_files:
|
| 198 |
+
try:
|
| 199 |
+
file_path = hf_hub_download(
|
| 200 |
+
repo_id=dataset_name,
|
| 201 |
+
filename=nii_file,
|
| 202 |
+
repo_type="dataset",
|
| 203 |
+
cache_dir=temp_dir
|
| 204 |
+
)
|
| 205 |
+
nii_files.append(file_path)
|
| 206 |
+
print(f"✓ Downloaded {nii_file}")
|
| 207 |
+
except Exception as e:
|
| 208 |
+
print(f"Error downloading {nii_file}: {e}")
|
| 209 |
except Exception as e:
|
| 210 |
+
print(f"Error listing repository files: {e}")
|
| 211 |
+
print("Will try alternative approaches...")
|
| 212 |
+
|
| 213 |
+
# If repo listing fails, try with common NIfTI file patterns directly
|
| 214 |
+
if not nii_files:
|
| 215 |
+
print("Trying common NIfTI file patterns...")
|
| 216 |
+
|
| 217 |
+
# Focus specifically on P*_rs.nii pattern
|
| 218 |
+
patterns = []
|
| 219 |
+
|
| 220 |
+
# Generate P01_rs.nii through P30_rs.nii
|
| 221 |
+
for i in range(1, 31): # Try subjects 1-30
|
| 222 |
+
patterns.append(f"P{i:02d}_rs.nii")
|
| 223 |
+
|
| 224 |
+
# Also try with .nii.gz extension
|
| 225 |
+
for i in range(1, 31):
|
| 226 |
+
patterns.append(f"P{i:02d}_rs.nii.gz")
|
| 227 |
+
|
| 228 |
+
# Include a few other common patterns as fallbacks
|
| 229 |
+
patterns.extend([
|
| 230 |
+
"sub-01_task-rest_bold.nii.gz", # BIDS format
|
| 231 |
+
"fmri.nii.gz", "bold.nii.gz",
|
| 232 |
+
"rest.nii.gz"
|
| 233 |
+
])
|
| 234 |
+
|
| 235 |
+
for pattern in patterns:
|
| 236 |
+
try:
|
| 237 |
+
print(f"Trying to download {pattern}...")
|
| 238 |
+
file_path = hf_hub_download(
|
| 239 |
+
repo_id=dataset_name,
|
| 240 |
+
filename=pattern,
|
| 241 |
+
repo_type="dataset",
|
| 242 |
+
cache_dir=temp_dir
|
| 243 |
+
)
|
| 244 |
+
nii_files.append(file_path)
|
| 245 |
+
print(f"✓ Successfully downloaded {pattern}")
|
| 246 |
+
except Exception as e:
|
| 247 |
+
print(f"× Failed to download {pattern}")
|
| 248 |
+
|
| 249 |
+
# If we still couldn't find any files, check if data files are nested
|
| 250 |
+
if not nii_files:
|
| 251 |
+
print("Checking for nested data files...")
|
| 252 |
+
nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"]
|
| 253 |
+
|
| 254 |
+
for path in nested_paths:
|
| 255 |
+
for pattern in patterns:
|
| 256 |
+
nested_file = f"{path}{pattern}"
|
| 257 |
+
try:
|
| 258 |
+
print(f"Trying to download {nested_file}...")
|
| 259 |
+
file_path = hf_hub_download(
|
| 260 |
+
repo_id=dataset_name,
|
| 261 |
+
filename=nested_file,
|
| 262 |
+
repo_type="dataset",
|
| 263 |
+
cache_dir=temp_dir
|
| 264 |
+
)
|
| 265 |
+
nii_files.append(file_path)
|
| 266 |
+
print(f"✓ Successfully downloaded {nested_file}")
|
| 267 |
+
# If we found one file in this directory, try to find all files in it
|
| 268 |
+
try:
|
| 269 |
+
all_files_in_dir = [f for f in all_repo_files if f.startswith(path)]
|
| 270 |
+
nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')]
|
| 271 |
+
print(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}")
|
| 272 |
+
|
| 273 |
+
for nii_file in nii_files_in_dir:
|
| 274 |
+
if nii_file != nested_file: # Skip the one we already downloaded
|
| 275 |
+
try:
|
| 276 |
+
file_path = hf_hub_download(
|
| 277 |
+
repo_id=dataset_name,
|
| 278 |
+
filename=nii_file,
|
| 279 |
+
repo_type="dataset",
|
| 280 |
+
cache_dir=temp_dir
|
| 281 |
+
)
|
| 282 |
+
nii_files.append(file_path)
|
| 283 |
+
print(f"✓ Downloaded {nii_file}")
|
| 284 |
+
except Exception as e:
|
| 285 |
+
print(f"Error downloading {nii_file}: {e}")
|
| 286 |
+
except Exception as e:
|
| 287 |
+
print(f"Error finding additional files in {path}: {e}")
|
| 288 |
+
except Exception as e:
|
| 289 |
+
pass
|
| 290 |
+
|
| 291 |
+
except Exception as e:
|
| 292 |
+
print(f"Error during repository exploration: {e}")
|
| 293 |
|
| 294 |
+
# If we still don't have any files, try to search for P*_rs.nii pattern specifically
|
| 295 |
+
if not nii_files:
|
| 296 |
+
print("Trying to find files matching P*_rs.nii pattern specifically...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 298 |
try:
|
| 299 |
+
# List all files in the repository (if we haven't already)
|
| 300 |
+
if not 'all_repo_files' in locals():
|
| 301 |
+
from huggingface_hub import list_repo_files
|
| 302 |
+
try:
|
| 303 |
+
all_repo_files = list_repo_files(dataset_name, repo_type="dataset")
|
| 304 |
+
except Exception as e:
|
| 305 |
+
print(f"Error listing repo files: {e}")
|
| 306 |
+
all_repo_files = []
|
| 307 |
|
| 308 |
+
# Look for files matching the pattern exactly (P*_rs.nii)
|
| 309 |
+
pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')]
|
| 310 |
+
|
| 311 |
+
# If we don't find any exact matches, try a more relaxed pattern
|
| 312 |
+
if not pattern_files:
|
| 313 |
+
pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()]
|
| 314 |
+
|
| 315 |
+
if pattern_files:
|
| 316 |
+
print(f"Found {len(pattern_files)} files matching rs.nii pattern")
|
| 317 |
+
|
| 318 |
+
# Download each file
|
| 319 |
+
for pattern_file in pattern_files:
|
| 320 |
+
try:
|
| 321 |
+
file_path = hf_hub_download(
|
| 322 |
+
repo_id=dataset_name,
|
| 323 |
+
filename=pattern_file,
|
| 324 |
+
repo_type="dataset",
|
| 325 |
+
cache_dir=temp_dir
|
| 326 |
+
)
|
| 327 |
+
nii_files.append(file_path)
|
| 328 |
+
print(f"✓ Downloaded {pattern_file}")
|
| 329 |
+
except Exception as e:
|
| 330 |
+
print(f"Error downloading {pattern_file}: {e}")
|
| 331 |
+
except Exception as e:
|
| 332 |
+
print(f"Error searching for pattern files: {e}")
|
| 333 |
+
|
| 334 |
+
print(f"Found total of {len(nii_files)} NIfTI files")
|
| 335 |
except Exception as e:
|
| 336 |
+
print(f"Unexpected error during NIfTI file search: {e}")
|
| 337 |
+
import traceback
|
| 338 |
+
traceback.print_exc()
|
| 339 |
+
|
| 340 |
+
# If we found NIfTI files, process them to FC matrices
|
| 341 |
+
if nii_files:
|
| 342 |
+
print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
|
| 343 |
+
|
| 344 |
+
# Load Power 264 atlas
|
| 345 |
+
from nilearn import datasets
|
| 346 |
+
power = datasets.fetch_coords_power_2011()
|
| 347 |
+
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
| 348 |
+
|
| 349 |
+
masker = input_data.NiftiSpheresMasker(
|
| 350 |
+
coords, radius=5,
|
| 351 |
+
standardize=True,
|
| 352 |
+
memory='nilearn_cache', memory_level=1,
|
| 353 |
+
verbose=0,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 354 |
detrend=True,
|
|
|
|
| 355 |
low_pass=0.1,
|
| 356 |
high_pass=0.01,
|
| 357 |
+
t_r=2.0 # Adjust TR according to your data
|
| 358 |
)
|
| 359 |
|
| 360 |
+
# Process fMRI data and compute FC matrices
|
| 361 |
+
fc_matrices = []
|
| 362 |
+
valid_files = 0
|
| 363 |
+
total_files = len(nii_files)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
|
| 365 |
+
for nii_file in nii_files:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 366 |
try:
|
| 367 |
+
print(f"Processing {nii_file}...")
|
| 368 |
+
fmri_img = load_img(nii_file)
|
| 369 |
|
| 370 |
+
# Check image dimensions
|
| 371 |
+
if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 10:
|
| 372 |
+
print(f"Warning: {nii_file} has insufficient time points: {fmri_img.shape}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 373 |
continue
|
| 374 |
+
|
| 375 |
+
time_series = masker.fit_transform(fmri_img)
|
| 376 |
|
| 377 |
+
# Validate time series data
|
| 378 |
+
if np.isnan(time_series).any() or np.isinf(time_series).any():
|
| 379 |
+
print(f"Warning: {nii_file} contains NaN or Inf values after masking")
|
| 380 |
+
# Replace NaNs with zeros for this file
|
| 381 |
+
time_series = np.nan_to_num(time_series)
|
| 382 |
|
| 383 |
+
correlation_measure = connectome.ConnectivityMeasure(
|
| 384 |
+
kind='correlation',
|
| 385 |
+
vectorize=False,
|
| 386 |
+
discard_diagonal=False
|
| 387 |
+
)
|
| 388 |
|
| 389 |
+
fc_matrix = correlation_measure.fit_transform([time_series])[0]
|
|
|
|
| 390 |
|
| 391 |
+
# Check for invalid correlation values
|
| 392 |
+
if np.isnan(fc_matrix).any():
|
| 393 |
+
print(f"Warning: {nii_file} produced NaN correlation values")
|
| 394 |
+
continue
|
| 395 |
+
|
| 396 |
+
triu_indices = np.triu_indices_from(fc_matrix, k=1)
|
| 397 |
+
fc_triu = fc_matrix[triu_indices]
|
| 398 |
|
| 399 |
+
# Fisher z-transform with proper bounds check
|
| 400 |
+
# Clip correlation values to valid range for arctanh
|
| 401 |
+
fc_triu_clipped = np.clip(fc_triu, -0.999, 0.999)
|
| 402 |
+
fc_triu = np.arctanh(fc_triu_clipped)
|
| 403 |
|
| 404 |
+
fc_matrices.append(fc_triu)
|
| 405 |
+
valid_files += 1
|
| 406 |
+
print(f"Successfully processed {nii_file} to FC matrix")
|
| 407 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 408 |
except Exception as e:
|
| 409 |
+
print(f"Error processing {nii_file}: {e}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
|
| 411 |
+
if fc_matrices:
|
| 412 |
+
print(f"Successfully processed {valid_files} out of {total_files} files")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 413 |
|
| 414 |
+
# Ensure all matrices have the same dimensions
|
| 415 |
+
dims = [m.shape[0] for m in fc_matrices]
|
| 416 |
+
if len(set(dims)) > 1:
|
| 417 |
+
print(f"Warning: FC matrices have inconsistent dimensions: {dims}")
|
| 418 |
+
# Use the most common dimension
|
| 419 |
+
from collections import Counter
|
| 420 |
+
most_common_dim = Counter(dims).most_common(1)[0][0]
|
| 421 |
+
print(f"Using most common dimension: {most_common_dim}")
|
| 422 |
+
fc_matrices = [m for m in fc_matrices if m.shape[0] == most_common_dim]
|
| 423 |
|
| 424 |
+
X = np.array(fc_matrices)
|
|
|
|
|
|
|
| 425 |
|
| 426 |
+
# Normalize the FC data
|
| 427 |
+
mean_x = np.mean(X, axis=0)
|
| 428 |
+
std_x = np.std(X, axis=0)
|
| 429 |
|
| 430 |
+
# Handle zero standard deviation
|
| 431 |
+
std_x[std_x == 0] = 1.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 432 |
|
| 433 |
+
X = (X - mean_x) / std_x
|
| 434 |
+
print(f"Created FC matrices with shape {X.shape}")
|
|
|
|
|
|
|
| 435 |
|
| 436 |
+
# Make sure demo_data matches the number of FC matrices
|
| 437 |
+
if len(demo_data[0]) != X.shape[0]:
|
| 438 |
+
print(f"Warning: Number of subjects in demographic data ({len(demo_data[0])}) " +
|
| 439 |
+
f"doesn't match number of FC matrices ({X.shape[0]})")
|
| 440 |
+
# Adjust demo_data to match FC matrices
|
| 441 |
+
indices = list(range(min(len(demo_data[0]), X.shape[0])))
|
| 442 |
+
X = X[indices]
|
| 443 |
+
demo_data = [d[indices] for d in demo_data]
|
| 444 |
+
|
| 445 |
+
return X, demo_data, demo_types
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 446 |
|
| 447 |
+
print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
|
| 448 |
+
# Return a placeholder with the right demographics but empty FC
|
| 449 |
+
n_subjects = len(dataset)
|
| 450 |
+
n_rois = 264
|
| 451 |
+
fc_dim = (n_rois * (n_rois - 1)) // 2
|
| 452 |
+
X = np.zeros((n_subjects, fc_dim))
|
| 453 |
+
print(f"Created placeholder FC matrices with shape {X.shape}")
|
| 454 |
+
return X, demo_data, demo_types
|
| 455 |
+
|
| 456 |
+
elif isinstance(dataset_or_niifiles, str):
|
| 457 |
+
# Handle real dataset with actual fMRI data
|
| 458 |
+
dataset = load_dataset(dataset_or_niifiles, split="train")
|
| 459 |
+
|
| 460 |
+
# Load Power 264 atlas
|
| 461 |
+
from nilearn import datasets
|
| 462 |
+
power = datasets.fetch_coords_power_2011()
|
| 463 |
+
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
| 464 |
+
|
| 465 |
+
masker = input_data.NiftiSpheresMasker(
|
| 466 |
+
coords, radius=5,
|
| 467 |
+
standardize=True,
|
| 468 |
+
memory='nilearn_cache', memory_level=1,
|
| 469 |
+
verbose=0,
|
| 470 |
+
detrend=True,
|
| 471 |
+
low_pass=0.1,
|
| 472 |
+
high_pass=0.01,
|
| 473 |
+
t_r=2.0 # Adjust TR according to your data
|
| 474 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 475 |
|
| 476 |
+
# Load demographic data if needed
|
| 477 |
+
if demo_data is None:
|
| 478 |
+
if 'demographics' in dataset.features:
|
| 479 |
+
demo_df = pd.DataFrame(dataset['demographics'])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 480 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
demo_data = [
|
| 482 |
+
demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
|
| 483 |
+
demo_df['sex'].values if 'sex' in demo_df.columns else [],
|
| 484 |
+
demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
|
| 485 |
+
demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
|
| 486 |
]
|
| 487 |
|
| 488 |
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
|
|
|
|
|
|
|
|
|
| 489 |
|
| 490 |
+
# Process fMRI data and compute FC matrices
|
| 491 |
+
fc_matrices = []
|
| 492 |
+
for nii_file in dataset['nii_files']:
|
| 493 |
+
fmri_img = load_img(nii_file)
|
| 494 |
+
time_series = masker.fit_transform(fmri_img)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 495 |
|
| 496 |
+
correlation_measure = connectome.ConnectivityMeasure(
|
| 497 |
+
kind='correlation', vectorize=False, discard_diagonal=False
|
| 498 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 499 |
|
| 500 |
+
fc_matrix = correlation_measure.fit_transform([time_series])[0]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 501 |
|
| 502 |
+
triu_indices = np.triu_indices_from(fc_matrix, k=1)
|
| 503 |
+
fc_triu = fc_matrix[triu_indices]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 504 |
|
| 505 |
+
fc_triu = np.arctanh(fc_triu) # Fisher z-transform
|
| 506 |
+
|
| 507 |
+
fc_matrices.append(fc_triu)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 508 |
|
| 509 |
+
X = np.array(fc_matrices)
|
|
|
|
|
|
|
|
|
|
| 510 |
|
| 511 |
+
elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
|
| 512 |
+
# Handle a list of NIfTI files
|
| 513 |
+
# Similar processing as above but with local files
|
| 514 |
+
print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
|
| 515 |
+
|
| 516 |
+
# Load Power 264 atlas
|
| 517 |
+
from nilearn import datasets
|
| 518 |
+
power = datasets.fetch_coords_power_2011()
|
| 519 |
+
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
| 520 |
+
|
| 521 |
+
masker = input_data.NiftiSpheresMasker(
|
| 522 |
+
coords, radius=5,
|
| 523 |
+
standardize=True,
|
| 524 |
+
memory='nilearn_cache', memory_level=1,
|
| 525 |
+
verbose=0,
|
| 526 |
+
detrend=True,
|
| 527 |
+
low_pass=0.1,
|
| 528 |
+
high_pass=0.01,
|
| 529 |
+
t_r=2.0
|
| 530 |
+
)
|
| 531 |
+
|
| 532 |
+
fc_matrices = []
|
| 533 |
+
for nii_file in dataset_or_niifiles:
|
| 534 |
+
fmri_img = load_img(nii_file)
|
| 535 |
+
time_series = masker.fit_transform(fmri_img)
|
| 536 |
|
| 537 |
+
correlation_measure = connectome.ConnectivityMeasure(
|
| 538 |
+
kind='correlation', vectorize=False, discard_diagonal=False
|
| 539 |
+
)
|
| 540 |
+
|
| 541 |
+
fc_matrix = correlation_measure.fit_transform([time_series])[0]
|
| 542 |
+
|
| 543 |
+
triu_indices = np.triu_indices_from(fc_matrix, k=1)
|
| 544 |
+
fc_triu = fc_matrix[triu_indices]
|
| 545 |
+
|
| 546 |
+
fc_triu = np.arctanh(fc_triu) # Fisher z-transform
|
| 547 |
+
|
| 548 |
+
fc_matrices.append(fc_triu)
|
| 549 |
+
|
| 550 |
+
X = np.array(fc_matrices)
|
| 551 |
+
else:
|
| 552 |
+
raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 553 |
|
| 554 |
+
# Normalize the FC data
|
| 555 |
+
X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 556 |
|
| 557 |
+
return X, demo_data, demo_types
|
gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
main.py
CHANGED
|
@@ -1,490 +1,272 @@
|
|
| 1 |
import os
|
| 2 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
from pathlib import Path
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import pandas as pd
|
|
|
|
|
|
|
| 6 |
|
| 7 |
-
|
| 8 |
-
os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
|
| 9 |
-
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)
|
| 10 |
-
|
| 11 |
-
from data_preprocessing import load_and_preprocess_data
|
| 12 |
-
from vae_model import DemoVAE
|
| 13 |
-
from rcf_prediction import AphasiaTreatmentPredictor
|
| 14 |
-
from visualization import plot_fc_matrices, plot_learning_curves
|
| 15 |
-
from config import MODEL_CONFIG, PREDICTION_CONFIG
|
| 16 |
-
# Configure matplotlib for headless environment
|
| 17 |
-
import matplotlib
|
| 18 |
-
matplotlib.use('Agg') # Use non-interactive backend
|
| 19 |
-
import matplotlib.pyplot as plt
|
| 20 |
-
|
| 21 |
-
def run_analysis(data_dir="data",
|
| 22 |
-
demographic_file="demographics.csv",
|
| 23 |
-
treatment_file="treatment_outcomes.csv",
|
| 24 |
-
latent_dim=32,
|
| 25 |
-
nepochs=1000,
|
| 26 |
-
bsize=16,
|
| 27 |
-
save_model=True,
|
| 28 |
-
use_hf_dataset=False,
|
| 29 |
-
hf_dataset=None,
|
| 30 |
-
hf_nii_files=None,
|
| 31 |
-
hf_demo_data=None,
|
| 32 |
-
hf_demo_types=None,
|
| 33 |
-
return_data=False,
|
| 34 |
-
max_samples=None,
|
| 35 |
-
skip_treatment_prediction=False):
|
| 36 |
"""
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
Args:
|
| 40 |
-
data_dir: Directory containing data files or HuggingFace dataset name
|
| 41 |
-
demographic_file: Path to demographic CSV file (or None if using API)
|
| 42 |
-
treatment_file: Path to treatment outcomes CSV file
|
| 43 |
-
latent_dim: Dimension of VAE latent space
|
| 44 |
-
nepochs: Number of training epochs
|
| 45 |
-
bsize: Batch size for training
|
| 46 |
-
save_model: Whether to save trained models
|
| 47 |
-
use_hf_dataset: Whether to use HuggingFace dataset API
|
| 48 |
-
hf_dataset: Pre-loaded HuggingFace dataset object
|
| 49 |
-
hf_nii_files: List of NIfTI file paths from HuggingFace dataset
|
| 50 |
-
hf_demo_data: Demographic data from HuggingFace dataset API
|
| 51 |
-
hf_demo_types: Types of demographic variables (continuous/categorical)
|
| 52 |
-
return_data: Whether to return raw data for accuracy calculations
|
| 53 |
-
skip_treatment_prediction: Skip treatment prediction step (for FC analysis only)
|
| 54 |
"""
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
'latent_dim': latent_dim,
|
| 58 |
-
'nepochs': nepochs,
|
| 59 |
-
'bsize': bsize
|
| 60 |
-
})
|
| 61 |
-
|
| 62 |
-
# Create output directories
|
| 63 |
-
os.makedirs('models', exist_ok=True)
|
| 64 |
-
os.makedirs('results', exist_ok=True)
|
| 65 |
-
|
| 66 |
-
# Load and preprocess data based on source
|
| 67 |
-
print("Loading and preprocessing data...")
|
| 68 |
-
if use_hf_dataset:
|
| 69 |
-
# Use HuggingFace dataset
|
| 70 |
-
if hf_demo_data is not None and hf_demo_types is not None:
|
| 71 |
-
# If demographic data is provided directly from API
|
| 72 |
-
print(f"Using demographic data from HuggingFace API")
|
| 73 |
-
X, demo_data, demo_types = load_and_preprocess_data(
|
| 74 |
-
data_dir,
|
| 75 |
-
demographic_file,
|
| 76 |
-
use_hf_dataset=True,
|
| 77 |
-
hf_nii_files=hf_nii_files,
|
| 78 |
-
hf_demo_data=hf_demo_data,
|
| 79 |
-
hf_demo_types=hf_demo_types,
|
| 80 |
-
max_samples=max_samples
|
| 81 |
-
)
|
| 82 |
-
else:
|
| 83 |
-
# If demographic file is provided but still using HF for NIfTI
|
| 84 |
-
X, demo_data, demo_types = load_and_preprocess_data(
|
| 85 |
-
data_dir,
|
| 86 |
-
demographic_file,
|
| 87 |
-
use_hf_dataset=True,
|
| 88 |
-
hf_nii_files=hf_nii_files,
|
| 89 |
-
max_samples=max_samples
|
| 90 |
-
)
|
| 91 |
-
else:
|
| 92 |
-
# Standard local file loading
|
| 93 |
-
X, demo_data, demo_types = load_and_preprocess_data(data_dir, demographic_file, max_samples=max_samples)
|
| 94 |
-
|
| 95 |
-
# If we're doing treatment prediction, load treatment outcomes
|
| 96 |
-
treatment_outcomes = None
|
| 97 |
-
if not skip_treatment_prediction and treatment_file:
|
| 98 |
-
treatment_df = pd.read_csv(treatment_file)
|
| 99 |
-
treatment_outcomes = treatment_df['outcome_score'].values
|
| 100 |
-
|
| 101 |
-
# Ensure we have enough treatment outcomes based on input data
|
| 102 |
-
if len(treatment_outcomes) < len(X):
|
| 103 |
-
print(f"WARNING: Not enough treatment outcomes ({len(treatment_outcomes)}) for input data ({len(X)})")
|
| 104 |
-
# Generate synthetic outcomes to match input data size
|
| 105 |
-
synthetic_outcomes = np.random.normal(5, 2, size=(len(X) - len(treatment_outcomes)))
|
| 106 |
-
treatment_outcomes = np.concatenate([treatment_outcomes, synthetic_outcomes])
|
| 107 |
-
print(f"Added {len(synthetic_outcomes)} synthetic outcomes to match input data")
|
| 108 |
-
|
| 109 |
-
# Ensure we don't have too many treatment outcomes
|
| 110 |
-
if len(treatment_outcomes) > len(X):
|
| 111 |
-
print(f"WARNING: More treatment outcomes ({len(treatment_outcomes)}) than input data ({len(X)})")
|
| 112 |
-
treatment_outcomes = treatment_outcomes[:len(X)]
|
| 113 |
-
print(f"Trimmed treatment outcomes to match input data size: {len(X)}")
|
| 114 |
-
|
| 115 |
-
print(f"Using {len(treatment_outcomes)} treatment outcomes for {len(X)} input samples")
|
| 116 |
-
|
| 117 |
-
# Initialize and train VAE
|
| 118 |
-
print("Training VAE...")
|
| 119 |
-
vae = DemoVAE(**MODEL_CONFIG)
|
| 120 |
-
try:
|
| 121 |
-
train_losses, val_losses = vae.fit(X, demo_data, demo_types)
|
| 122 |
-
print(f"VAE training complete. Final train loss: {train_losses[-1]:.4f}, final validation loss: {val_losses[-1]:.4f}")
|
| 123 |
-
except Exception as e:
|
| 124 |
-
print(f"Error during VAE training: {e}")
|
| 125 |
-
print("Using empty lists for losses as fallback")
|
| 126 |
-
train_losses, val_losses = [], []
|
| 127 |
|
| 128 |
-
|
| 129 |
-
print("Extracting latent representations...")
|
| 130 |
-
latents = vae.get_latents(X)
|
| 131 |
|
| 132 |
-
#
|
| 133 |
-
|
|
|
|
|
|
|
| 134 |
|
| 135 |
-
#
|
| 136 |
-
|
|
|
|
|
|
|
|
|
|
| 137 |
|
| 138 |
-
#
|
| 139 |
-
|
| 140 |
-
|
|
|
|
| 141 |
|
| 142 |
-
#
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
|
|
|
|
|
|
|
|
|
| 149 |
|
| 150 |
-
|
| 151 |
-
|
| 152 |
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 158 |
|
| 159 |
-
print("
|
| 160 |
-
|
| 161 |
-
[d[:1] for d in demo_data],
|
| 162 |
-
demo_types)
|
| 163 |
-
print(f"Generated FC shape: {generated.shape}")
|
| 164 |
|
| 165 |
-
|
| 166 |
-
print("Saving FC matrices...")
|
| 167 |
-
np.save('results/reconstructed_fc.npy', reconstructed)
|
| 168 |
-
np.save('results/generated_fc.npy', generated)
|
| 169 |
|
| 170 |
-
#
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
|
|
|
|
|
|
|
|
|
| 178 |
|
| 179 |
-
print(f"
|
| 180 |
|
| 181 |
-
#
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
|
|
|
| 188 |
|
| 189 |
-
#
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
import traceback
|
| 195 |
-
print(f"Error creating FC visualization: {e}")
|
| 196 |
-
print(f"Detailed error: {traceback.format_exc()}")
|
| 197 |
-
fc_fig = plt.figure(figsize=(15, 5))
|
| 198 |
-
plt.text(0.5, 0.5, f"FC visualization unavailable: {str(e)}",
|
| 199 |
-
ha='center', va='center', transform=plt.gca().transAxes)
|
| 200 |
-
plt.tight_layout()
|
| 201 |
-
|
| 202 |
-
# Learning curves
|
| 203 |
-
try:
|
| 204 |
-
print("Creating learning curve visualization...")
|
| 205 |
-
|
| 206 |
-
# Check if losses are stored in the VAE object first (most reliable source)
|
| 207 |
-
train_data = []
|
| 208 |
-
val_data = []
|
| 209 |
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
print(f"Using {len(train_data)} real training loss points from fit return value")
|
| 217 |
-
else:
|
| 218 |
-
# Instead of synthetic data, provide empty list and warning
|
| 219 |
-
print("WARNING: No real training loss data found")
|
| 220 |
-
train_data = []
|
| 221 |
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
|
| 225 |
-
print(f"Found {len(val_data)} real validation loss points in VAE object")
|
| 226 |
-
elif val_losses and len(val_losses) > 0:
|
| 227 |
-
val_data = val_losses
|
| 228 |
-
print(f"Using {len(val_data)} real validation loss points from fit return value")
|
| 229 |
-
else:
|
| 230 |
-
# Instead of synthetic data, provide empty list and warning
|
| 231 |
-
print("WARNING: No real validation loss data found")
|
| 232 |
-
val_data = []
|
| 233 |
-
|
| 234 |
-
# If we get here, we have some training data (real or synthetic)
|
| 235 |
-
# Store the data in the VAE object for future use
|
| 236 |
-
if not hasattr(vae, 'train_losses') or len(getattr(vae, 'train_losses', [])) == 0:
|
| 237 |
-
print("Storing training loss data in VAE object")
|
| 238 |
-
vae.train_losses = train_data
|
| 239 |
|
| 240 |
-
|
| 241 |
-
|
| 242 |
-
|
|
|
|
|
|
|
| 243 |
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
learning_fig = plot_learning_curves(train_data, val_data)
|
| 247 |
-
except Exception as e:
|
| 248 |
-
import traceback
|
| 249 |
-
print(f"Error creating learning curve plot: {e}")
|
| 250 |
-
print(f"Traceback: {traceback.format_exc()}")
|
| 251 |
-
|
| 252 |
-
# Create a more informative error display
|
| 253 |
-
learning_fig = plt.figure(figsize=(10, 6))
|
| 254 |
-
plt.text(0.5, 0.5, f"Error creating learning curves: {str(e)}",
|
| 255 |
-
ha='center', va='center', transform=plt.gca().transAxes,
|
| 256 |
-
fontsize=12, color='darkred')
|
| 257 |
-
plt.axis('off')
|
| 258 |
-
plt.tight_layout()
|
| 259 |
-
|
| 260 |
-
# Check if we should use strict real data mode
|
| 261 |
-
use_strict_real_data = PREDICTION_CONFIG.get('strict_real_data', False)
|
| 262 |
-
no_mock_data = PREDICTION_CONFIG.get('no_mock_data', False)
|
| 263 |
-
|
| 264 |
-
if use_strict_real_data or no_mock_data:
|
| 265 |
-
print("Using strict real data mode - only including real data in results")
|
| 266 |
-
# Only include figures if they contain real data
|
| 267 |
-
figures = {}
|
| 268 |
-
if hasattr(vae, 'train_losses') and len(vae.train_losses) > 0:
|
| 269 |
-
figures['learning_curves'] = learning_fig
|
| 270 |
-
print("Including real learning curves")
|
| 271 |
else:
|
| 272 |
-
print("
|
|
|
|
| 273 |
|
| 274 |
-
# Only include FC analysis if it's based on real data
|
| 275 |
-
if len(np.array(X).shape) > 0 and len(X) > 0:
|
| 276 |
-
figures['vae'] = fc_fig
|
| 277 |
-
figures['fc_analysis'] = fc_fig
|
| 278 |
-
print("Including real FC analysis")
|
| 279 |
-
else:
|
| 280 |
-
print("WARNING: No real FC data available")
|
| 281 |
else:
|
| 282 |
-
#
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
-
#
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
})
|
| 304 |
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
predictor = AphasiaTreatmentPredictor(
|
| 310 |
-
n_estimators=PREDICTION_CONFIG.get('n_estimators', 100),
|
| 311 |
-
max_depth=PREDICTION_CONFIG.get('max_depth', None)
|
| 312 |
-
)
|
| 313 |
|
| 314 |
-
#
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
|
| 321 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
-
#
|
| 324 |
-
|
| 325 |
-
|
|
|
|
| 326 |
|
| 327 |
-
#
|
|
|
|
| 328 |
try:
|
| 329 |
-
|
| 330 |
-
|
| 331 |
-
|
| 332 |
-
|
| 333 |
-
|
| 334 |
-
|
| 335 |
-
|
|
|
|
|
|
|
| 336 |
except Exception as e:
|
| 337 |
-
print(f"Error
|
| 338 |
-
|
| 339 |
-
predictions = np.zeros(len(treatment_outcomes))
|
| 340 |
-
prediction_stds = np.zeros(len(treatment_outcomes))
|
| 341 |
|
| 342 |
-
#
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
if fold_metrics and "r2" in fold_metrics[0]:
|
| 346 |
-
cv_std = np.std([fold.get("r2", 0.0) for fold in fold_metrics])
|
| 347 |
-
else:
|
| 348 |
-
cv_std = 0.0
|
| 349 |
-
except Exception as e:
|
| 350 |
-
print(f"Error calculating CV metrics: {e}")
|
| 351 |
-
cv_mean, cv_std = 0.0, 0.0
|
| 352 |
|
| 353 |
-
#
|
| 354 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
-
#
|
| 357 |
-
|
| 358 |
-
importance_fig = predictor.plot_feature_importance()
|
| 359 |
-
except Exception as e:
|
| 360 |
-
print(f"Error creating feature importance plot: {e}")
|
| 361 |
-
importance_fig = plt.figure(figsize=(8, 6))
|
| 362 |
-
plt.text(0.5, 0.5, "Feature importance unavailable",
|
| 363 |
-
ha='center', va='center', transform=plt.gca().transAxes)
|
| 364 |
-
plt.tight_layout()
|
| 365 |
|
| 366 |
-
#
|
| 367 |
-
|
| 368 |
|
| 369 |
-
#
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
min_val = min(np.min(treatment_outcomes), np.min(predictions))
|
| 377 |
-
max_val = max(np.max(treatment_outcomes), np.max(predictions))
|
| 378 |
-
plt.plot([min_val, max_val], [min_val, max_val], 'r--')
|
| 379 |
-
|
| 380 |
-
# Confidence band
|
| 381 |
-
plt.fill_between(treatment_outcomes,
|
| 382 |
-
predictions - 2*prediction_stds,
|
| 383 |
-
predictions + 2*prediction_stds,
|
| 384 |
-
alpha=0.2, color='gray')
|
| 385 |
-
|
| 386 |
-
# Labels
|
| 387 |
-
plt.xlabel('Actual Outcome')
|
| 388 |
-
plt.ylabel('Predicted Outcome')
|
| 389 |
-
|
| 390 |
-
# Title with metrics
|
| 391 |
-
if predictor.prediction_type == "regression":
|
| 392 |
-
plt.title(f'Treatment Outcome Prediction\nR² = {cv_mean:.3f} ± {cv_std:.3f}')
|
| 393 |
-
else:
|
| 394 |
-
plt.title(f'Treatment Outcome Prediction\nAccuracy = {cv_mean:.3f} ± {cv_std:.3f}')
|
| 395 |
-
except Exception as e:
|
| 396 |
-
print(f"Error creating performance plot: {e}")
|
| 397 |
-
plt.text(0.5, 0.5, "Error creating plot",
|
| 398 |
-
ha='center', va='center', transform=plt.gca().transAxes)
|
| 399 |
-
else:
|
| 400 |
-
# Handle case with no data
|
| 401 |
-
plt.text(0.5, 0.5, "No prediction data available",
|
| 402 |
-
ha='center', va='center', transform=plt.gca().transAxes)
|
| 403 |
-
|
| 404 |
-
plt.tight_layout()
|
| 405 |
|
| 406 |
-
#
|
| 407 |
-
print("
|
| 408 |
-
|
| 409 |
-
|
| 410 |
|
| 411 |
-
|
| 412 |
-
|
| 413 |
-
|
| 414 |
-
|
| 415 |
-
|
| 416 |
-
|
| 417 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
-
#
|
| 420 |
-
|
| 421 |
-
|
| 422 |
-
'cv_scores': (cv_mean, cv_std),
|
| 423 |
-
'predictions': predictions,
|
| 424 |
-
'prediction_stds': prediction_stds,
|
| 425 |
-
'predictor_cv_results': predictor_cv_results,
|
| 426 |
-
})
|
| 427 |
|
| 428 |
-
|
| 429 |
-
results['figures'].update({
|
| 430 |
-
'importance': importance_fig,
|
| 431 |
-
'performance': performance_fig
|
| 432 |
-
})
|
| 433 |
-
|
| 434 |
-
# Save models if requested
|
| 435 |
-
if save_model:
|
| 436 |
-
print("Saving models...")
|
| 437 |
-
vae.save('models/vae_model.pt')
|
| 438 |
|
| 439 |
-
|
| 440 |
-
|
| 441 |
-
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
|
| 451 |
-
|
| 452 |
-
latent_dim=32,
|
| 453 |
-
nepochs=100,
|
| 454 |
-
bsize=16,
|
| 455 |
-
save_model=True,
|
| 456 |
-
use_hf_dataset=True,
|
| 457 |
-
return_data=True):
|
| 458 |
-
"""
|
| 459 |
-
Run only the FC analysis portion without prediction
|
| 460 |
-
|
| 461 |
-
This is a simplified version of run_analysis focused on VAE training
|
| 462 |
-
and FC matrix visualization for the Gradio interface.
|
| 463 |
-
"""
|
| 464 |
-
# Call the main function with skip_treatment_prediction=True
|
| 465 |
-
return run_analysis(
|
| 466 |
-
data_dir=data_dir,
|
| 467 |
-
demographic_file=demographic_file,
|
| 468 |
-
treatment_file=None, # No treatment file needed
|
| 469 |
-
latent_dim=latent_dim,
|
| 470 |
-
nepochs=nepochs,
|
| 471 |
-
bsize=bsize,
|
| 472 |
-
save_model=save_model,
|
| 473 |
-
use_hf_dataset=use_hf_dataset,
|
| 474 |
-
return_data=return_data,
|
| 475 |
-
skip_treatment_prediction=True
|
| 476 |
-
)
|
| 477 |
|
| 478 |
if __name__ == "__main__":
|
| 479 |
import argparse
|
| 480 |
|
| 481 |
-
parser = argparse.ArgumentParser(description='Run
|
| 482 |
-
parser.add_argument('--data_dir', type=str, default='
|
| 483 |
-
help='
|
| 484 |
-
parser.add_argument('--demographic_file', type=str, default='
|
| 485 |
help='Path to demographic data CSV file')
|
| 486 |
-
parser.add_argument('--treatment_file', type=str, default='treatment_outcomes.csv',
|
| 487 |
-
help='Path to treatment outcomes CSV file')
|
| 488 |
parser.add_argument('--latent_dim', type=int, default=32,
|
| 489 |
help='Dimension of latent space')
|
| 490 |
parser.add_argument('--nepochs', type=int, default=1000,
|
|
@@ -492,28 +274,20 @@ if __name__ == "__main__":
|
|
| 492 |
parser.add_argument('--bsize', type=int, default=16,
|
| 493 |
help='Batch size for training')
|
| 494 |
parser.add_argument('--no_save', action='store_false',
|
| 495 |
-
help='Do not save the
|
| 496 |
-
parser.add_argument('--
|
| 497 |
-
help='
|
| 498 |
|
| 499 |
args = parser.parse_args()
|
| 500 |
|
| 501 |
-
|
| 502 |
-
|
| 503 |
-
|
| 504 |
-
|
| 505 |
-
|
| 506 |
-
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
|
| 510 |
-
|
| 511 |
-
|
| 512 |
-
data_dir=args.data_dir,
|
| 513 |
-
demographic_file=args.demographic_file,
|
| 514 |
-
treatment_file=args.treatment_file,
|
| 515 |
-
latent_dim=args.latent_dim,
|
| 516 |
-
nepochs=args.nepochs,
|
| 517 |
-
bsize=args.bsize,
|
| 518 |
-
save_model=args.no_save
|
| 519 |
-
)
|
|
|
|
| 1 |
import os
|
| 2 |
+
import sys
|
| 3 |
+
# Add the src directory to the path so we can import from demovae
|
| 4 |
+
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
import torch
|
| 8 |
from pathlib import Path
|
| 9 |
+
import nibabel as nib
|
| 10 |
+
from data_preprocessing import preprocess_fmri_to_fc
|
| 11 |
+
from src.demovae.sklearn import DemoVAE
|
| 12 |
+
from analysis import analyze_fc_patterns
|
| 13 |
+
from visualization import visualize_fc_analysis
|
| 14 |
+
from config import MODEL_CONFIG, DATASET_CONFIG
|
| 15 |
import pandas as pd
|
| 16 |
+
import io
|
| 17 |
+
from typing import List, Dict, Union, Tuple, Any
|
| 18 |
|
| 19 |
+
def train_fc_vae(X, demo_data, demo_types, model_config):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
"""
|
| 21 |
+
Train a VAE model on functional connectivity matrices
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
"""
|
| 23 |
+
n_rois = 264
|
| 24 |
+
input_dim = (n_rois * (n_rois - 1)) // 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
|
| 26 |
+
print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
# Ensure X is a numpy array with correct data type
|
| 29 |
+
if not isinstance(X, np.ndarray):
|
| 30 |
+
print(f"Converting X from {type(X)} to numpy array")
|
| 31 |
+
X = np.array(X, dtype=np.float32)
|
| 32 |
|
| 33 |
+
# Ensure demo_data contains numpy arrays
|
| 34 |
+
for i, d in enumerate(demo_data):
|
| 35 |
+
if not isinstance(d, np.ndarray):
|
| 36 |
+
print(f"Converting demographic {i} from {type(d)} to numpy array")
|
| 37 |
+
demo_data[i] = np.array(d)
|
| 38 |
|
| 39 |
+
# Check for NaN or Inf values
|
| 40 |
+
if np.isnan(X).any() or np.isinf(X).any():
|
| 41 |
+
print("Warning: X contains NaN or Inf values. Replacing with zeros.")
|
| 42 |
+
X = np.nan_to_num(X)
|
| 43 |
|
| 44 |
+
# Create the VAE model
|
| 45 |
+
vae = DemoVAE(
|
| 46 |
+
latent_dim=model_config['latent_dim'],
|
| 47 |
+
nepochs=model_config['nepochs'],
|
| 48 |
+
bsize=model_config['bsize'],
|
| 49 |
+
loss_rec_mult=model_config.get('loss_rec_mult', 100),
|
| 50 |
+
loss_decor_mult=model_config.get('loss_decor_mult', 10),
|
| 51 |
+
lr=model_config.get('lr', 1e-4),
|
| 52 |
+
use_cuda=torch.cuda.is_available()
|
| 53 |
+
)
|
| 54 |
|
| 55 |
+
print("Fitting VAE model...")
|
| 56 |
+
vae.fit(X, demo_data, demo_types)
|
| 57 |
|
| 58 |
+
return vae, X, demo_data, demo_types
|
| 59 |
+
|
| 60 |
+
def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
|
| 61 |
+
"""
|
| 62 |
+
Load fMRI data and demographics from HuggingFace dataset or local files
|
| 63 |
+
"""
|
| 64 |
+
if use_hf_dataset:
|
| 65 |
+
# Load from HuggingFace Datasets
|
| 66 |
+
from datasets import load_dataset
|
| 67 |
|
| 68 |
+
print(f"Loading dataset from HuggingFace: {data_dir}")
|
| 69 |
+
dataset = load_dataset(data_dir)
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
+
print(f"Dataset columns: {dataset['train'].column_names}")
|
|
|
|
|
|
|
|
|
|
| 72 |
|
| 73 |
+
# Get demographics directly from the dataset
|
| 74 |
+
# Create a DataFrame from the dataset features
|
| 75 |
+
demo_df = pd.DataFrame({
|
| 76 |
+
'ID': dataset['train']['ID'],
|
| 77 |
+
'wab_aq': dataset['train']['wab_aq'],
|
| 78 |
+
'age': dataset['train']['age'],
|
| 79 |
+
'mpo': dataset['train']['mpo'],
|
| 80 |
+
'education': dataset['train']['education'],
|
| 81 |
+
'gender': dataset['train']['gender'],
|
| 82 |
+
'handedness': dataset['train']['handedness']
|
| 83 |
+
})
|
| 84 |
|
| 85 |
+
print(f"Loaded demographic data with {len(demo_df)} subjects")
|
| 86 |
|
| 87 |
+
# Extract demographic data matching our expected format
|
| 88 |
+
# Map the dataset columns to our expected format
|
| 89 |
+
demo_data = [
|
| 90 |
+
demo_df['age'].values, # age at stroke -> age
|
| 91 |
+
demo_df['gender'].values, # sex -> gender
|
| 92 |
+
demo_df['mpo'].values, # months post stroke -> mpo
|
| 93 |
+
demo_df['wab_aq'].values # wab score -> wab_aq
|
| 94 |
+
]
|
| 95 |
|
| 96 |
+
# Check for FC matrices in the dataset
|
| 97 |
+
fc_columns = []
|
| 98 |
+
for col in dataset['train'].column_names:
|
| 99 |
+
if col.startswith("fc_") or "_fc" in col:
|
| 100 |
+
fc_columns.append(col)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
+
if fc_columns:
|
| 103 |
+
print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
|
| 104 |
+
# Extract FC matrices
|
| 105 |
+
fc_matrices = []
|
| 106 |
+
for fc_col in fc_columns:
|
| 107 |
+
fc_matrices.append(dataset['train'][fc_col])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 108 |
|
| 109 |
+
# If we have FC matrices, return them directly
|
| 110 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 111 |
+
return fc_matrices, demo_data, demo_types
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
+
# If no FC matrices, look for .nii files
|
| 114 |
+
nii_files = []
|
| 115 |
+
for col in dataset['train'].column_names:
|
| 116 |
+
if col.endswith(".nii.gz") or col.endswith(".nii"):
|
| 117 |
+
nii_files.append(dataset['train'][col])
|
| 118 |
|
| 119 |
+
if nii_files:
|
| 120 |
+
print(f"Found {len(nii_files)} .nii files")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 121 |
else:
|
| 122 |
+
print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
|
| 123 |
+
# If no structured data is found, we can try to download raw files later
|
| 124 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
else:
|
| 126 |
+
# Original local file loading
|
| 127 |
+
# Load demographics
|
| 128 |
+
demo_df = pd.read_csv(demographic_file)
|
| 129 |
+
|
| 130 |
+
demo_data = [
|
| 131 |
+
demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
|
| 132 |
+
demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
|
| 133 |
+
demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
|
| 134 |
+
demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
|
| 135 |
+
]
|
| 136 |
+
|
| 137 |
+
# Load fMRI files
|
| 138 |
+
nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
|
| 139 |
|
| 140 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 141 |
+
return nii_files, demo_data, demo_types
|
| 142 |
+
|
| 143 |
+
def run_fc_analysis(data_dir="SreekarB/OSFData",
|
| 144 |
+
demographic_file=None,
|
| 145 |
+
latent_dim=32,
|
| 146 |
+
nepochs=1000,
|
| 147 |
+
bsize=16,
|
| 148 |
+
save_model=True,
|
| 149 |
+
use_hf_dataset=True):
|
| 150 |
|
| 151 |
+
# Update MODEL_CONFIG with user-specified parameters
|
| 152 |
+
MODEL_CONFIG.update({
|
| 153 |
+
'latent_dim': latent_dim,
|
| 154 |
+
'nepochs': nepochs,
|
| 155 |
+
'bsize': bsize
|
| 156 |
+
})
|
|
|
|
| 157 |
|
| 158 |
+
try:
|
| 159 |
+
# Load data
|
| 160 |
+
print("Loading data...")
|
| 161 |
+
nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 162 |
|
| 163 |
+
# For SreekarB/OSFData, directly generate synthetic FC matrices
|
| 164 |
+
if data_dir == "SreekarB/OSFData" and use_hf_dataset:
|
| 165 |
+
print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
|
| 166 |
+
X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
|
| 167 |
+
# Check if we got FC matrices directly
|
| 168 |
+
elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
|
| 169 |
+
print("Using pre-computed FC matrices...")
|
| 170 |
+
# Convert list of FC matrices to numpy array
|
| 171 |
+
X = np.stack([np.array(fc) for fc in nii_files])
|
| 172 |
+
else:
|
| 173 |
+
# Prepare data by converting fMRI to FC matrices
|
| 174 |
+
print("Converting fMRI data to FC matrices...")
|
| 175 |
+
X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
|
| 176 |
|
| 177 |
+
# Print shapes and data types
|
| 178 |
+
print(f"X shape: {X.shape}, type: {type(X)}")
|
| 179 |
+
for i, d in enumerate(demo_data):
|
| 180 |
+
print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
|
| 181 |
|
| 182 |
+
# Train VAE and get data
|
| 183 |
+
print("Training VAE...")
|
| 184 |
try:
|
| 185 |
+
# Use the proper DemoVAE implementation from src/demovae/sklearn.py
|
| 186 |
+
vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
|
| 187 |
+
|
| 188 |
+
if save_model:
|
| 189 |
+
print("Saving model...")
|
| 190 |
+
os.makedirs('models', exist_ok=True)
|
| 191 |
+
# Use the save method from DemoVAE
|
| 192 |
+
vae.save('models/vae_model.pth')
|
| 193 |
+
print("Model saved successfully.")
|
| 194 |
except Exception as e:
|
| 195 |
+
print(f"Error during VAE training: {e}")
|
| 196 |
+
raise
|
|
|
|
|
|
|
| 197 |
|
| 198 |
+
# Get latent representations
|
| 199 |
+
print("Getting latent representations...")
|
| 200 |
+
latents = vae.get_latents(X)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 201 |
|
| 202 |
+
# Analyze results
|
| 203 |
+
print("Analyzing demographic relationships...")
|
| 204 |
+
demographics = {
|
| 205 |
+
'age': demo_data[0],
|
| 206 |
+
'months_post_onset': demo_data[2],
|
| 207 |
+
'wab_aq': demo_data[3]
|
| 208 |
+
}
|
| 209 |
+
analysis_results = analyze_fc_patterns(latents, demographics)
|
| 210 |
|
| 211 |
+
# Generate new FC matrix
|
| 212 |
+
print("Generating new FC matrices...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 213 |
|
| 214 |
+
# Get data types from original demographic data for proper conversion
|
| 215 |
+
demo_dtypes = [type(d[0]) if len(d) > 0 else float for d in demo_data]
|
| 216 |
|
| 217 |
+
# Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
|
| 218 |
+
new_demographics = [
|
| 219 |
+
np.array([60.0], dtype=np.float64), # age
|
| 220 |
+
np.array(['M'], dtype=np.str_), # gender
|
| 221 |
+
np.array([12.0], dtype=np.float64), # months post onset
|
| 222 |
+
np.array([80.0], dtype=np.float64) # wab score
|
| 223 |
+
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 224 |
|
| 225 |
+
# Verify the demographic data arrays match the expected types
|
| 226 |
+
print("Demographic data types:")
|
| 227 |
+
for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
|
| 228 |
+
print(f" {name}: shape={data.shape}, dtype={data.dtype}")
|
| 229 |
|
| 230 |
+
print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
|
| 231 |
+
try:
|
| 232 |
+
generated_fc = vae.transform(1, new_demographics, demo_types)
|
| 233 |
+
except Exception as e:
|
| 234 |
+
print(f"Error generating new FC matrix: {e}")
|
| 235 |
+
# Try with a fallback approach
|
| 236 |
+
print("Trying alternative generation approach...")
|
| 237 |
+
# If specific gender is causing issues, try the first gender from training data
|
| 238 |
+
new_demographics[1] = np.array([demo_data[1][0]])
|
| 239 |
+
generated_fc = vae.transform(1, new_demographics, demo_types)
|
| 240 |
+
reconstructed_fc = vae.transform(X, demo_data, demo_types)
|
| 241 |
|
| 242 |
+
# Visualize results
|
| 243 |
+
print("Creating visualizations...")
|
| 244 |
+
fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 247 |
|
| 248 |
+
except Exception as e:
|
| 249 |
+
import traceback
|
| 250 |
+
print(f"Error in run_fc_analysis: {str(e)}")
|
| 251 |
+
print(traceback.format_exc())
|
| 252 |
+
|
| 253 |
+
# Create a dummy figure with error message
|
| 254 |
+
import matplotlib.pyplot as plt
|
| 255 |
+
fig = plt.figure(figsize=(10, 6))
|
| 256 |
+
plt.text(0.5, 0.5, f"Error: {str(e)}",
|
| 257 |
+
horizontalalignment='center', verticalalignment='center',
|
| 258 |
+
fontsize=12, color='red')
|
| 259 |
+
plt.axis('off')
|
| 260 |
+
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
if __name__ == "__main__":
|
| 263 |
import argparse
|
| 264 |
|
| 265 |
+
parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
|
| 266 |
+
parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
|
| 267 |
+
help='HuggingFace dataset ID or directory containing fMRI data')
|
| 268 |
+
parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
|
| 269 |
help='Path to demographic data CSV file')
|
|
|
|
|
|
|
| 270 |
parser.add_argument('--latent_dim', type=int, default=32,
|
| 271 |
help='Dimension of latent space')
|
| 272 |
parser.add_argument('--nepochs', type=int, default=1000,
|
|
|
|
| 274 |
parser.add_argument('--bsize', type=int, default=16,
|
| 275 |
help='Batch size for training')
|
| 276 |
parser.add_argument('--no_save', action='store_false',
|
| 277 |
+
help='Do not save the model')
|
| 278 |
+
parser.add_argument('--use_local', action='store_true',
|
| 279 |
+
help='Use local data instead of HuggingFace dataset')
|
| 280 |
|
| 281 |
args = parser.parse_args()
|
| 282 |
|
| 283 |
+
fig = run_fc_analysis(
|
| 284 |
+
data_dir=args.data_dir,
|
| 285 |
+
demographic_file=args.demographic_file,
|
| 286 |
+
latent_dim=args.latent_dim,
|
| 287 |
+
nepochs=args.nepochs,
|
| 288 |
+
bsize=args.bsize,
|
| 289 |
+
save_model=args.no_save,
|
| 290 |
+
use_hf_dataset=not args.use_local
|
| 291 |
+
)
|
| 292 |
+
fig.show()
|
| 293 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
requirements.txt
CHANGED
|
@@ -1,8 +1,12 @@
|
|
| 1 |
torch>=1.9.0
|
| 2 |
numpy>=1.19.2
|
| 3 |
pandas>=1.2.4
|
|
|
|
|
|
|
| 4 |
scikit-learn>=0.24.2
|
| 5 |
matplotlib>=3.4.2
|
| 6 |
-
gradio>=
|
| 7 |
-
|
|
|
|
|
|
|
| 8 |
|
|
|
|
| 1 |
torch>=1.9.0
|
| 2 |
numpy>=1.19.2
|
| 3 |
pandas>=1.2.4
|
| 4 |
+
nilearn>=0.8.1
|
| 5 |
+
nibabel>=3.2.1
|
| 6 |
scikit-learn>=0.24.2
|
| 7 |
matplotlib>=3.4.2
|
| 8 |
+
gradio>=2.0.0
|
| 9 |
+
datasets>=1.11.0
|
| 10 |
+
huggingface_hub>=0.15.0
|
| 11 |
+
transformers>=4.15.0
|
| 12 |
|
test_hf_download.py
CHANGED
|
@@ -6,7 +6,7 @@ from datasets import load_dataset
|
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
|
| 9 |
-
def test_huggingface_download(dataset_name="SreekarB/
|
| 10 |
"""
|
| 11 |
Test script to verify downloading NIfTI files from HuggingFace Datasets
|
| 12 |
"""
|
|
@@ -227,7 +227,7 @@ if __name__ == "__main__":
|
|
| 227 |
# Process command line arguments
|
| 228 |
import argparse
|
| 229 |
parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
|
| 230 |
-
parser.add_argument('--dataset', type=str, default="SreekarB/
|
| 231 |
parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
|
| 232 |
parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
|
| 233 |
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
import pandas as pd
|
| 8 |
|
| 9 |
+
def test_huggingface_download(dataset_name="SreekarB/OSFData", revision=None, auth_token=None):
|
| 10 |
"""
|
| 11 |
Test script to verify downloading NIfTI files from HuggingFace Datasets
|
| 12 |
"""
|
|
|
|
| 227 |
# Process command line arguments
|
| 228 |
import argparse
|
| 229 |
parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
|
| 230 |
+
parser.add_argument('--dataset', type=str, default="SreekarB/OSFData", help='HuggingFace dataset name')
|
| 231 |
parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
|
| 232 |
parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
|
| 233 |
|
utils.py
CHANGED
|
@@ -3,18 +3,23 @@ import numpy as np
|
|
| 3 |
from sklearn.linear_model import Ridge, LogisticRegression
|
| 4 |
|
| 5 |
def to_torch(x):
|
| 6 |
-
if not isinstance(x, np.ndarray):
|
| 7 |
-
x = np.array(x)
|
| 8 |
return torch.from_numpy(x).float()
|
| 9 |
|
| 10 |
def to_cuda(x, use_cuda):
|
| 11 |
-
|
| 12 |
-
return x.cuda()
|
| 13 |
-
return x
|
| 14 |
|
| 15 |
def to_numpy(x):
|
| 16 |
return x.detach().cpu().numpy()
|
| 17 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
def rmse(a, b, mean=torch.mean):
|
| 19 |
return mean((a-b)**2)**0.5
|
| 20 |
|
|
@@ -42,7 +47,6 @@ def decor_loss(z, demo, use_cuda=True):
|
|
| 42 |
ps.append(p)
|
| 43 |
losses = torch.stack(losses)
|
| 44 |
return losses, ps
|
| 45 |
-
|
| 46 |
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
|
| 47 |
demo_t = []
|
| 48 |
demo_idx = 0
|
|
@@ -66,33 +70,10 @@ def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
|
|
| 66 |
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
|
| 67 |
loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
|
| 68 |
loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
|
| 69 |
-
|
| 70 |
# Get linear predictors for demographics
|
| 71 |
pred_w = []
|
| 72 |
pred_i = []
|
| 73 |
pred_stats = []
|
| 74 |
-
train_losses = []
|
| 75 |
-
val_losses = []
|
| 76 |
-
|
| 77 |
-
# Check if sample sizes are consistent
|
| 78 |
-
n_samples = x.shape[0]
|
| 79 |
-
print(f"Sample sizes - X: {n_samples}, Demographics: {[len(d) for d in demo]}")
|
| 80 |
-
|
| 81 |
-
# Ensure all sample sizes match
|
| 82 |
-
if any(len(d) != n_samples for d in demo):
|
| 83 |
-
print("WARNING: Sample size mismatch detected! Fixing...")
|
| 84 |
-
|
| 85 |
-
# Trim to smallest size
|
| 86 |
-
min_samples = min(n_samples, *[len(d) for d in demo])
|
| 87 |
-
print(f"Adjusting to {min_samples} samples")
|
| 88 |
-
|
| 89 |
-
# Adjust x and demo
|
| 90 |
-
x = x[:min_samples]
|
| 91 |
-
demo = [d[:min_samples] for d in demo]
|
| 92 |
-
|
| 93 |
-
print(f"After adjustment - X: {x.shape[0]}, Demographics: {[len(d) for d in demo]}")
|
| 94 |
-
|
| 95 |
-
print(f"Using {x.shape[0]} samples for training")
|
| 96 |
|
| 97 |
for i, d, t in zip(range(len(demo)), demo, demo_types):
|
| 98 |
print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
|
|
@@ -133,21 +114,7 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
|
|
| 133 |
ce = torch.nn.CrossEntropyLoss()
|
| 134 |
optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
|
| 135 |
|
| 136 |
-
# Calculate initial validation loss
|
| 137 |
-
print("Calculating initial validation metrics...")
|
| 138 |
-
vae.eval()
|
| 139 |
-
with torch.no_grad():
|
| 140 |
-
z_val = vae.enc(x)
|
| 141 |
-
y_val = vae.dec(z_val, demo_t)
|
| 142 |
-
initial_val_loss = rmse(x, y_val).item()
|
| 143 |
-
val_losses.append(initial_val_loss)
|
| 144 |
-
print(f"Initial validation loss: {initial_val_loss:.4f}")
|
| 145 |
-
|
| 146 |
-
# Main training loop
|
| 147 |
for e in range(nepochs):
|
| 148 |
-
epoch_losses = []
|
| 149 |
-
vae.train()
|
| 150 |
-
|
| 151 |
for bs in range(0, len(x), bsize):
|
| 152 |
xb = x[bs:(bs+bsize)]
|
| 153 |
db = demo_t[bs:(bs+bsize)]
|
|
@@ -161,43 +128,59 @@ def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
|
|
| 161 |
loss_decor = sum(loss_decor)
|
| 162 |
loss_rec = rmse(xb, y)
|
| 163 |
|
| 164 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 165 |
total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
|
| 166 |
-
loss_rec_mult*loss_rec + loss_decor_mult*loss_decor
|
|
|
|
| 167 |
|
| 168 |
total_loss.backward()
|
| 169 |
optim.step()
|
| 170 |
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
# Print progress for every epoch
|
| 178 |
-
print(f'Epoch {e+1}/{nepochs} - Train Loss: {epoch_loss:.4f}')
|
| 179 |
-
|
| 180 |
-
# Validation step (perform at every epoch to have full data for plotting)
|
| 181 |
-
vae.eval()
|
| 182 |
-
with torch.no_grad():
|
| 183 |
-
z = vae.enc(x)
|
| 184 |
-
y = vae.dec(z, demo_t)
|
| 185 |
-
val_loss = rmse(x, y).item()
|
| 186 |
-
val_losses.append(val_loss)
|
| 187 |
-
|
| 188 |
-
# Only print detailed validation logs at pperiod intervals
|
| 189 |
-
if (e + 1) % pperiod == 0:
|
| 190 |
-
print(f' Validation - Val Loss: {val_loss:.4f}')
|
| 191 |
-
|
| 192 |
-
# Make sure losses are converted to regular Python lists (for serialization)
|
| 193 |
-
train_losses = [float(loss) for loss in train_losses]
|
| 194 |
-
val_losses = [float(loss) for loss in val_losses]
|
| 195 |
-
|
| 196 |
-
print(f"Training complete - Final train loss: {train_losses[-1]:.4f}, Val loss: {val_losses[-1]:.4f}")
|
| 197 |
-
print(f"Loss history recorded: {len(train_losses)} train points, {len(val_losses)} validation points")
|
| 198 |
-
|
| 199 |
-
# Store the losses in the return object for future reference
|
| 200 |
-
ret_obj.train_losses = train_losses
|
| 201 |
-
ret_obj.val_losses = val_losses
|
| 202 |
-
|
| 203 |
-
return train_losses, val_losses
|
|
|
|
| 3 |
from sklearn.linear_model import Ridge, LogisticRegression
|
| 4 |
|
| 5 |
def to_torch(x):
|
|
|
|
|
|
|
| 6 |
return torch.from_numpy(x).float()
|
| 7 |
|
| 8 |
def to_cuda(x, use_cuda):
|
| 9 |
+
return x.cuda() if use_cuda else x
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def to_numpy(x):
|
| 12 |
return x.detach().cpu().numpy()
|
| 13 |
|
| 14 |
+
def fc_matrix_from_triu(triu_values, n_rois=264):
|
| 15 |
+
fc_matrix = np.zeros((n_rois, n_rois))
|
| 16 |
+
triu_indices = np.triu_indices(n_rois, k=1)
|
| 17 |
+
triu_values = np.tanh(triu_values)
|
| 18 |
+
fc_matrix[triu_indices] = triu_values
|
| 19 |
+
fc_matrix = fc_matrix + fc_matrix.T
|
| 20 |
+
np.fill_diagonal(fc_matrix, 1)
|
| 21 |
+
return fc_matrix
|
| 22 |
+
|
| 23 |
def rmse(a, b, mean=torch.mean):
|
| 24 |
return mean((a-b)**2)**0.5
|
| 25 |
|
|
|
|
| 47 |
ps.append(p)
|
| 48 |
losses = torch.stack(losses)
|
| 49 |
return losses, ps
|
|
|
|
| 50 |
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
|
| 51 |
demo_t = []
|
| 52 |
demo_idx = 0
|
|
|
|
| 70 |
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
|
| 71 |
loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
|
| 72 |
loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
|
|
|
|
| 73 |
# Get linear predictors for demographics
|
| 74 |
pred_w = []
|
| 75 |
pred_i = []
|
| 76 |
pred_stats = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
for i, d, t in zip(range(len(demo)), demo, demo_types):
|
| 79 |
print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
|
|
|
|
| 114 |
ce = torch.nn.CrossEntropyLoss()
|
| 115 |
optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 117 |
for e in range(nepochs):
|
|
|
|
|
|
|
|
|
|
| 118 |
for bs in range(0, len(x), bsize):
|
| 119 |
xb = x[bs:(bs+bsize)]
|
| 120 |
db = demo_t[bs:(bs+bsize)]
|
|
|
|
| 128 |
loss_decor = sum(loss_decor)
|
| 129 |
loss_rec = rmse(xb, y)
|
| 130 |
|
| 131 |
+
# Sample demographics
|
| 132 |
+
demo_gen = []
|
| 133 |
+
for s, t in zip(pred_stats, demo_types):
|
| 134 |
+
if t == 'continuous':
|
| 135 |
+
mu, std = s
|
| 136 |
+
dd = torch.randn(100).float()
|
| 137 |
+
dd = dd*std+mu
|
| 138 |
+
dd = to_cuda(dd, vae.use_cuda)
|
| 139 |
+
demo_gen.append(dd)
|
| 140 |
+
elif t == 'categorical':
|
| 141 |
+
idx = np.random.randint(0, len(s))
|
| 142 |
+
for i in range(len(s)):
|
| 143 |
+
dd = torch.ones(100).float() if idx == i else torch.zeros(100).float()
|
| 144 |
+
dd = to_cuda(dd, vae.use_cuda)
|
| 145 |
+
demo_gen.append(dd)
|
| 146 |
+
|
| 147 |
+
demo_gen = torch.stack(demo_gen).permute(1,0)
|
| 148 |
+
|
| 149 |
+
# Generate
|
| 150 |
+
z = vae.gen(100)
|
| 151 |
+
y = vae.dec(z, demo_gen)
|
| 152 |
+
|
| 153 |
+
# Regressor/classifier guidance loss
|
| 154 |
+
losses_pred = []
|
| 155 |
+
idcs = []
|
| 156 |
+
dg_idx = 0
|
| 157 |
+
|
| 158 |
+
for s, t in zip(pred_stats, demo_types):
|
| 159 |
+
if t == 'continuous':
|
| 160 |
+
yy = y@pred_w[dg_idx]+pred_i[dg_idx]
|
| 161 |
+
loss = rmse(demo_gen[:,dg_idx], yy)
|
| 162 |
+
losses_pred.append(loss)
|
| 163 |
+
idcs.append(float(demo_gen[0,dg_idx]))
|
| 164 |
+
dg_idx += 1
|
| 165 |
+
elif t == 'categorical':
|
| 166 |
+
loss = 0
|
| 167 |
+
for i in range(len(s)):
|
| 168 |
+
yy = y@pred_w[dg_idx]+pred_i[dg_idx]
|
| 169 |
+
loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
|
| 170 |
+
idcs.append(int(demo_gen[0,dg_idx]))
|
| 171 |
+
dg_idx += 1
|
| 172 |
+
losses_pred.append(loss)
|
| 173 |
+
|
| 174 |
total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
|
| 175 |
+
loss_rec_mult*loss_rec + loss_decor_mult*loss_decor +
|
| 176 |
+
loss_pred_mult*sum(losses_pred))
|
| 177 |
|
| 178 |
total_loss.backward()
|
| 179 |
optim.step()
|
| 180 |
|
| 181 |
+
if e%pperiod == 0 or e == nepochs-1:
|
| 182 |
+
print(f'Epoch {e} ReconLoss {loss_rec:.4f} CovarianceLoss {loss_C:.4f} '
|
| 183 |
+
f'MeanLoss {loss_mu:.4f} DecorLoss {loss_decor:.4f}')
|
| 184 |
+
print(f'GuidanceTargets {idcs}')
|
| 185 |
+
print(f'GuidanceLosses {[f"{loss:.4f}" for loss in losses_pred]}')
|
| 186 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
vae_model.py
CHANGED
|
@@ -1,355 +1,150 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Simplified VAE implementation with explicit loss tracking.
|
| 3 |
-
"""
|
| 4 |
import torch
|
| 5 |
import torch.nn as nn
|
| 6 |
import torch.nn.functional as F
|
| 7 |
import numpy as np
|
| 8 |
-
import
|
| 9 |
-
import matplotlib.pyplot as plt
|
| 10 |
from sklearn.base import BaseEstimator
|
| 11 |
|
| 12 |
-
class
|
| 13 |
-
def __init__(self, input_dim, latent_dim, demo_dim):
|
| 14 |
-
super(
|
| 15 |
-
# Store dimensions
|
| 16 |
self.input_dim = input_dim
|
| 17 |
self.latent_dim = latent_dim
|
| 18 |
self.demo_dim = demo_dim
|
|
|
|
| 19 |
|
| 20 |
-
# Encoder
|
| 21 |
-
self.enc1 = nn.Linear(input_dim,
|
| 22 |
-
self.enc2 = nn.Linear(
|
| 23 |
|
| 24 |
-
# Decoder
|
| 25 |
-
self.dec1 = nn.Linear(latent_dim
|
| 26 |
-
self.dec2 = nn.Linear(
|
| 27 |
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
return self.enc2(h)
|
| 32 |
-
|
| 33 |
-
def decode(self, z, demo):
|
| 34 |
-
"""Decode from latent space to FC reconstruction"""
|
| 35 |
-
# Combine latent with demographics
|
| 36 |
-
z_combined = torch.cat([z, demo], dim=1)
|
| 37 |
-
h = F.relu(self.dec1(z_combined))
|
| 38 |
-
return self.dec2(h)
|
| 39 |
-
|
| 40 |
-
def forward(self, x, demo):
|
| 41 |
-
"""Full forward pass"""
|
| 42 |
-
z = self.encode(x)
|
| 43 |
-
return self.decode(z, demo)
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
self.
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
self.train_losses = []
|
| 123 |
-
self.val_losses = []
|
| 124 |
-
|
| 125 |
-
# Initial validation loss
|
| 126 |
-
self.vae.eval()
|
| 127 |
-
with torch.no_grad():
|
| 128 |
-
reconstructed = self.vae(X_tensor, demo_tensor)
|
| 129 |
-
init_val_loss = F.mse_loss(reconstructed, X_tensor).item()
|
| 130 |
-
self.val_losses.append(init_val_loss)
|
| 131 |
-
print(f"Initial validation loss: {init_val_loss:.4f}")
|
| 132 |
-
|
| 133 |
-
# Main training loop
|
| 134 |
-
for epoch in range(self.nepochs):
|
| 135 |
-
epoch_losses = []
|
| 136 |
-
self.vae.train()
|
| 137 |
-
|
| 138 |
-
# Process in batches
|
| 139 |
-
for i in range(0, n_samples, batch_size):
|
| 140 |
-
# Get batch
|
| 141 |
-
end = min(i + batch_size, n_samples)
|
| 142 |
-
x_batch = X_tensor[i:end]
|
| 143 |
-
demo_batch = demo_tensor[i:end]
|
| 144 |
-
|
| 145 |
-
# Forward pass
|
| 146 |
-
optimizer.zero_grad()
|
| 147 |
-
reconstructed = self.vae(x_batch, demo_batch)
|
| 148 |
-
|
| 149 |
-
# Calculate loss
|
| 150 |
-
loss = F.mse_loss(reconstructed, x_batch)
|
| 151 |
-
|
| 152 |
-
# Backward pass
|
| 153 |
-
loss.backward()
|
| 154 |
-
optimizer.step()
|
| 155 |
-
|
| 156 |
-
# Record loss
|
| 157 |
-
epoch_losses.append(loss.item())
|
| 158 |
-
|
| 159 |
-
# End of epoch
|
| 160 |
-
avg_loss = np.mean(epoch_losses)
|
| 161 |
-
self.train_losses.append(avg_loss)
|
| 162 |
-
|
| 163 |
-
# Validation
|
| 164 |
-
self.vae.eval()
|
| 165 |
-
with torch.no_grad():
|
| 166 |
-
reconstructed = self.vae(X_tensor, demo_tensor)
|
| 167 |
-
val_loss = F.mse_loss(reconstructed, X_tensor).item()
|
| 168 |
-
self.val_losses.append(val_loss)
|
| 169 |
-
|
| 170 |
-
# Print progress every few epochs
|
| 171 |
-
if (epoch + 1) % 5 == 0 or epoch == 0:
|
| 172 |
-
print(f"Epoch {epoch+1}/{self.nepochs} - "
|
| 173 |
-
f"Train loss: {avg_loss:.4f}, Val loss: {val_loss:.4f}")
|
| 174 |
-
|
| 175 |
-
print(f"Training complete! Final loss: {self.train_losses[-1]:.4f}")
|
| 176 |
-
print(f"Loss history: {len(self.train_losses)} train, {len(self.val_losses)} validation")
|
| 177 |
-
|
| 178 |
-
return self.train_losses, self.val_losses
|
| 179 |
-
|
| 180 |
-
def transform(self, X, demo_data, demo_types):
|
| 181 |
-
"""Generate reconstructions or synthetic samples"""
|
| 182 |
-
# Check if model is available
|
| 183 |
-
if self.vae is None:
|
| 184 |
-
raise ValueError("Model not trained or loaded yet")
|
| 185 |
-
|
| 186 |
-
# Set model to evaluation mode
|
| 187 |
-
self.vae.eval()
|
| 188 |
-
|
| 189 |
-
# Check if we're generating or reconstructing
|
| 190 |
-
if isinstance(X, int):
|
| 191 |
-
# Generating n random samples
|
| 192 |
-
n_samples = X
|
| 193 |
-
|
| 194 |
-
# Process demo data (repeat single values if needed)
|
| 195 |
-
demo_list = []
|
| 196 |
-
for d in demo_data:
|
| 197 |
-
if not isinstance(d, (list, np.ndarray)):
|
| 198 |
-
# Single value, repeat for all samples
|
| 199 |
-
demo_list.append([d] * n_samples)
|
| 200 |
-
else:
|
| 201 |
-
demo_list.append(d)
|
| 202 |
-
|
| 203 |
-
print(f"Generating {n_samples} samples with demo data: {demo_list}")
|
| 204 |
-
|
| 205 |
-
# Process demographics
|
| 206 |
-
demo_tensor, demo_dim = self.preprocess_demo(demo_list, demo_types, n_samples)
|
| 207 |
-
|
| 208 |
-
# Generate random latent vectors
|
| 209 |
-
z = torch.randn(n_samples, self.latent_dim).to(self.device)
|
| 210 |
-
|
| 211 |
else:
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
| 215 |
-
|
| 216 |
-
|
| 217 |
-
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
demo_list.append([d] * n_samples)
|
| 222 |
-
else:
|
| 223 |
-
demo_list.append(d)
|
| 224 |
-
|
| 225 |
-
# Process demographics
|
| 226 |
-
demo_tensor, demo_dim = self.preprocess_demo(demo_list, demo_types)
|
| 227 |
-
|
| 228 |
-
# Encode input data
|
| 229 |
-
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
| 230 |
-
z = self.vae.encode(X_tensor)
|
| 231 |
-
|
| 232 |
-
# Print shapes for debugging
|
| 233 |
-
print(f"Latent shape: {z.shape}, Demo tensor shape: {demo_tensor.shape}")
|
| 234 |
-
|
| 235 |
-
# Decode to get output
|
| 236 |
-
demo_tensor = demo_tensor.to(self.device)
|
| 237 |
-
with torch.no_grad():
|
| 238 |
-
# Make sure demo_tensor has the right dimensions
|
| 239 |
-
if demo_tensor.shape[1] != self.vae.demo_dim:
|
| 240 |
-
print(f"WARNING: Demo dimension mismatch. Expected {self.vae.demo_dim}, got {demo_tensor.shape[1]}")
|
| 241 |
-
# Use demographic dimension from the model
|
| 242 |
-
if demo_tensor.shape[1] > self.vae.demo_dim:
|
| 243 |
-
# Trim extra dimensions
|
| 244 |
-
demo_tensor = demo_tensor[:, :self.vae.demo_dim]
|
| 245 |
-
else:
|
| 246 |
-
# Pad with zeros
|
| 247 |
-
padding = torch.zeros(demo_tensor.shape[0], self.vae.demo_dim - demo_tensor.shape[1]).to(self.device)
|
| 248 |
-
demo_tensor = torch.cat([demo_tensor, padding], dim=1)
|
| 249 |
-
print(f"Adjusted demo tensor shape: {demo_tensor.shape}")
|
| 250 |
-
|
| 251 |
-
output = self.vae.decode(z, demo_tensor)
|
| 252 |
-
|
| 253 |
-
# Convert to numpy
|
| 254 |
-
return output.cpu().numpy()
|
| 255 |
-
|
| 256 |
-
def get_latents(self, X):
|
| 257 |
-
"""Encode data to latent representations"""
|
| 258 |
-
X = np.array(X)
|
| 259 |
-
X_tensor = torch.tensor(X, dtype=torch.float32).to(self.device)
|
| 260 |
-
|
| 261 |
-
with torch.no_grad():
|
| 262 |
-
z = self.vae.encode(X_tensor)
|
| 263 |
-
|
| 264 |
-
return z.cpu().numpy()
|
| 265 |
-
|
| 266 |
def save(self, path):
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
'latent_dim': self.latent_dim,
|
| 276 |
-
'demo_dim': self.vae.demo_dim,
|
| 277 |
-
'train_losses': self.train_losses,
|
| 278 |
-
'val_losses': self.val_losses,
|
| 279 |
-
'nepochs': self.nepochs,
|
| 280 |
-
'batch_size': self.batch_size,
|
| 281 |
-
'lr': self.lr
|
| 282 |
-
}
|
| 283 |
-
|
| 284 |
-
# Save the model
|
| 285 |
-
torch.save(state, path)
|
| 286 |
-
print(f"Model saved to {path}")
|
| 287 |
-
|
| 288 |
-
# Print info about saved losses
|
| 289 |
-
print(f"Saved loss data: {len(self.train_losses)} train, {len(self.val_losses)} validation")
|
| 290 |
-
|
| 291 |
def load(self, path):
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
self.lr = state.get('lr', 1e-3)
|
| 304 |
-
self.train_losses = state.get('train_losses', [])
|
| 305 |
-
self.val_losses = state.get('val_losses', [])
|
| 306 |
-
|
| 307 |
-
# Create model
|
| 308 |
-
self.vae = SimpleVAE(
|
| 309 |
-
input_dim=state['input_dim'],
|
| 310 |
-
latent_dim=self.latent_dim,
|
| 311 |
-
demo_dim=state['demo_dim']
|
| 312 |
-
)
|
| 313 |
-
|
| 314 |
-
# Load weights
|
| 315 |
-
self.vae.load_state_dict(state['vae_state'])
|
| 316 |
-
self.vae.to(self.device)
|
| 317 |
-
|
| 318 |
-
print(f"Model loaded from {path}")
|
| 319 |
-
print(f"Loaded loss data: {len(self.train_losses)} train, {len(self.val_losses)} validation")
|
| 320 |
-
|
| 321 |
-
def plot_learning_curves(train_losses, val_losses):
|
| 322 |
-
"""Plot training and validation loss curves"""
|
| 323 |
-
# Create figure
|
| 324 |
-
plt.figure(figsize=(10, 6))
|
| 325 |
-
|
| 326 |
-
# Check if we have loss data
|
| 327 |
-
if not train_losses:
|
| 328 |
-
plt.text(0.5, 0.5, "No training loss data available",
|
| 329 |
-
ha='center', va='center', transform=plt.gca().transAxes,
|
| 330 |
-
fontsize=14, color='red')
|
| 331 |
-
plt.axis('off')
|
| 332 |
-
return plt.gcf()
|
| 333 |
-
|
| 334 |
-
# Plot losses
|
| 335 |
-
epochs = range(1, len(train_losses) + 1)
|
| 336 |
-
plt.plot(epochs, train_losses, 'b-', label='Training loss')
|
| 337 |
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
|
| 348 |
-
|
| 349 |
-
plt.title('VAE Training and Validation Loss')
|
| 350 |
-
plt.xlabel('Epoch')
|
| 351 |
-
plt.ylabel('Loss')
|
| 352 |
-
plt.legend()
|
| 353 |
-
plt.grid(True, alpha=0.3)
|
| 354 |
|
| 355 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn as nn
|
| 3 |
import torch.nn.functional as F
|
| 4 |
import numpy as np
|
| 5 |
+
from utils import to_torch, to_cuda, to_numpy, demo_to_torch
|
|
|
|
| 6 |
from sklearn.base import BaseEstimator
|
| 7 |
|
| 8 |
+
class VAE(nn.Module):
|
| 9 |
+
def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
|
| 10 |
+
super(VAE, self).__init__()
|
|
|
|
| 11 |
self.input_dim = input_dim
|
| 12 |
self.latent_dim = latent_dim
|
| 13 |
self.demo_dim = demo_dim
|
| 14 |
+
self.use_cuda = use_cuda
|
| 15 |
|
| 16 |
+
# Encoder
|
| 17 |
+
self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
|
| 18 |
+
self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
|
| 19 |
|
| 20 |
+
# Decoder
|
| 21 |
+
self.dec1 = to_cuda(nn.Linear(latent_dim+demo_dim, 1000).float(), use_cuda)
|
| 22 |
+
self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)
|
| 23 |
|
| 24 |
+
# Batch normalization layers
|
| 25 |
+
self.bn1 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
|
| 26 |
+
self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
+
def enc(self, x):
|
| 29 |
+
x = self.bn1(F.relu(self.enc1(x)))
|
| 30 |
+
z = self.enc2(x)
|
| 31 |
+
return z
|
| 32 |
+
|
| 33 |
+
def gen(self, n):
|
| 34 |
+
return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)
|
| 35 |
+
|
| 36 |
+
def dec(self, z, demo):
|
| 37 |
+
z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
|
| 38 |
+
x = self.bn2(F.relu(self.dec1(z)))
|
| 39 |
+
x = self.dec2(x)
|
| 40 |
+
return x
|
| 41 |
+
|
| 42 |
+
class DemoVAE(BaseEstimator):
|
| 43 |
+
def __init__(self, **params):
|
| 44 |
+
self.set_params(**params)
|
| 45 |
+
|
| 46 |
+
@staticmethod
|
| 47 |
+
def get_default_params():
|
| 48 |
+
return dict(
|
| 49 |
+
latent_dim=32,
|
| 50 |
+
use_cuda=True,
|
| 51 |
+
nepochs=1000,
|
| 52 |
+
pperiod=100,
|
| 53 |
+
bsize=16,
|
| 54 |
+
loss_C_mult=1,
|
| 55 |
+
loss_mu_mult=1,
|
| 56 |
+
loss_rec_mult=100,
|
| 57 |
+
loss_decor_mult=10,
|
| 58 |
+
loss_pred_mult=0.001,
|
| 59 |
+
alpha=100,
|
| 60 |
+
LR_C=100,
|
| 61 |
+
lr=1e-4,
|
| 62 |
+
weight_decay=0
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
def get_params(self, deep=True):
|
| 66 |
+
return {k: getattr(self, k) for k in self.get_default_params().keys()}
|
| 67 |
+
|
| 68 |
+
def set_params(self, **params):
|
| 69 |
+
for k, v in self.get_default_params().items():
|
| 70 |
+
setattr(self, k, params.get(k, v))
|
| 71 |
+
return self
|
| 72 |
+
|
| 73 |
+
def fit(self, x, demo, demo_types):
|
| 74 |
+
from utils import train_vae
|
| 75 |
+
|
| 76 |
+
# Calculate demo_dim
|
| 77 |
+
demo_dim = 0
|
| 78 |
+
for d, t in zip(demo, demo_types):
|
| 79 |
+
if t == 'continuous':
|
| 80 |
+
demo_dim += 1
|
| 81 |
+
elif t == 'categorical':
|
| 82 |
+
demo_dim += len(set(d))
|
| 83 |
+
else:
|
| 84 |
+
raise ValueError(f'Demographic type "{t}" not supported')
|
| 85 |
+
|
| 86 |
+
# Initialize VAE
|
| 87 |
+
self.input_dim = x.shape[1]
|
| 88 |
+
self.demo_dim = demo_dim
|
| 89 |
+
self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)
|
| 90 |
+
|
| 91 |
+
# Train VAE
|
| 92 |
+
train_vae(
|
| 93 |
+
self.vae, x, demo, demo_types,
|
| 94 |
+
self.nepochs, self.pperiod, self.bsize,
|
| 95 |
+
self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
|
| 96 |
+
self.loss_decor_mult, self.loss_pred_mult,
|
| 97 |
+
self.lr, self.weight_decay, self.alpha, self.LR_C,
|
| 98 |
+
self
|
| 99 |
+
)
|
| 100 |
+
return self
|
| 101 |
+
|
| 102 |
+
def transform(self, x, demo, demo_types):
|
| 103 |
+
if isinstance(x, int):
|
| 104 |
+
z = self.vae.gen(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 105 |
else:
|
| 106 |
+
z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
|
| 107 |
+
demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
|
| 108 |
+
y = self.vae.dec(z, demo_t)
|
| 109 |
+
return to_numpy(y)
|
| 110 |
+
|
| 111 |
+
def get_latents(self, x):
|
| 112 |
+
z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
|
| 113 |
+
return to_numpy(z)
|
| 114 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
def save(self, path):
|
| 116 |
+
torch.save({
|
| 117 |
+
'model_state_dict': self.vae.state_dict(),
|
| 118 |
+
'params': self.get_params(),
|
| 119 |
+
'pred_stats': self.pred_stats,
|
| 120 |
+
'input_dim': self.input_dim,
|
| 121 |
+
'demo_dim': self.demo_dim
|
| 122 |
+
}, path)
|
| 123 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 124 |
def load(self, path):
|
| 125 |
+
checkpoint = torch.load(path)
|
| 126 |
+
self.set_params(**checkpoint['params'])
|
| 127 |
+
self.pred_stats = checkpoint['pred_stats']
|
| 128 |
+
self.input_dim = checkpoint['input_dim']
|
| 129 |
+
self.demo_dim = checkpoint['demo_dim']
|
| 130 |
+
self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
|
| 131 |
+
self.vae.load_state_dict(checkpoint['model_state_dict'])
|
| 132 |
+
|
| 133 |
+
def train_fc_vae(X, demo_data, demo_types, model_config):
|
| 134 |
+
n_rois = 264
|
| 135 |
+
input_dim = (n_rois * (n_rois - 1)) // 2
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 136 |
|
| 137 |
+
vae = DemoVAE(
|
| 138 |
+
latent_dim=model_config['latent_dim'],
|
| 139 |
+
nepochs=model_config['nepochs'],
|
| 140 |
+
bsize=model_config['bsize'],
|
| 141 |
+
loss_rec_mult=model_config['loss_rec_mult'],
|
| 142 |
+
loss_decor_mult=model_config['loss_decor_mult'],
|
| 143 |
+
lr=model_config['lr'],
|
| 144 |
+
use_cuda=torch.cuda.is_available()
|
| 145 |
+
)
|
| 146 |
|
| 147 |
+
vae.fit(X, demo_data, demo_types)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 148 |
|
| 149 |
+
return vae, X, demo_data, demo_types
|
| 150 |
+
|
visualization.py
CHANGED
|
@@ -1,521 +1,44 @@
|
|
| 1 |
-
# Configure matplotlib for headless environment
|
| 2 |
-
import matplotlib
|
| 3 |
-
matplotlib.use('Agg') # Use non-interactive backend
|
| 4 |
import matplotlib.pyplot as plt
|
| 5 |
import numpy as np
|
|
|
|
| 6 |
|
| 7 |
-
def
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
if not isinstance(vector, np.ndarray):
|
| 11 |
-
try:
|
| 12 |
-
vector = np.array(vector)
|
| 13 |
-
print(f"Converted input to numpy array, shape: {vector.shape}")
|
| 14 |
-
except Exception as e:
|
| 15 |
-
print(f"Error converting input to numpy array: {e}")
|
| 16 |
-
# Create a fallback empty matrix
|
| 17 |
-
n = 264 # Standard size for brain atlas
|
| 18 |
-
matrix = np.zeros((n, n))
|
| 19 |
-
np.fill_diagonal(matrix, 1.0)
|
| 20 |
-
return matrix
|
| 21 |
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
return vector
|
| 26 |
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
print("Detected standard FC vector with 34716 elements (264x264 matrix)")
|
| 34 |
-
n = 264
|
| 35 |
-
else:
|
| 36 |
-
# For other sized vectors, calculate matrix size
|
| 37 |
-
# For a matrix of size n×n, the number of elements in the upper triangular part (excl. diagonal) is n(n-1)/2
|
| 38 |
-
n = int(np.sqrt(2 * len(vector) + 0.25) + 0.5)
|
| 39 |
-
print(f"Calculated matrix size: {n}x{n}")
|
| 40 |
-
|
| 41 |
-
# Validate calculation
|
| 42 |
-
expected_elements = int(n * (n-1) / 2)
|
| 43 |
-
if expected_elements != len(vector):
|
| 44 |
-
print(f"WARNING: Vector length {len(vector)} doesn't match expected length {expected_elements} for {n}x{n} matrix")
|
| 45 |
-
if len(vector) < expected_elements:
|
| 46 |
-
print(f"Padding vector with {expected_elements - len(vector)} zeros")
|
| 47 |
-
vector = np.pad(vector, (0, expected_elements - len(vector)))
|
| 48 |
-
else:
|
| 49 |
-
print(f"Trimming vector to {expected_elements} elements")
|
| 50 |
-
vector = vector[:expected_elements]
|
| 51 |
-
|
| 52 |
-
# Create empty matrix
|
| 53 |
-
matrix = np.zeros((n, n))
|
| 54 |
-
|
| 55 |
-
# Get indices for upper triangle
|
| 56 |
-
triu_indices = np.triu_indices_from(matrix, k=1)
|
| 57 |
-
|
| 58 |
-
# Convert from Fisher z-transform if needed (check if values exceed [-1,1])
|
| 59 |
-
if np.any(np.abs(vector) > 1):
|
| 60 |
-
print("Vector contains values >1, applying inverse Fisher z-transform")
|
| 61 |
-
values = np.tanh(vector) # Inverse Fisher z-transform
|
| 62 |
-
else:
|
| 63 |
-
values = vector
|
| 64 |
-
|
| 65 |
-
# Check for NaN or Inf values
|
| 66 |
-
if np.any(np.isnan(values)) or np.any(np.isinf(values)):
|
| 67 |
-
print("WARNING: Vector contains NaN or Inf values, replacing with zeros")
|
| 68 |
-
values = np.nan_to_num(values)
|
| 69 |
-
|
| 70 |
-
# Set upper triangle values
|
| 71 |
-
matrix[triu_indices] = values
|
| 72 |
-
|
| 73 |
-
# Make symmetric
|
| 74 |
-
matrix = matrix + matrix.T
|
| 75 |
-
|
| 76 |
-
# Set diagonal to 1 (perfect correlation)
|
| 77 |
-
np.fill_diagonal(matrix, 1.0)
|
| 78 |
-
|
| 79 |
-
print(f"Successfully converted to matrix with shape {matrix.shape}")
|
| 80 |
-
return matrix
|
| 81 |
-
|
| 82 |
-
except Exception as e:
|
| 83 |
-
import traceback
|
| 84 |
-
print(f"Error in vector_to_matrix: {e}")
|
| 85 |
-
print(f"Traceback: {traceback.format_exc()}")
|
| 86 |
-
print(f"Vector stats: min={np.min(vector) if len(vector) > 0 else 'N/A'}, "
|
| 87 |
-
f"max={np.max(vector) if len(vector) > 0 else 'N/A'}, "
|
| 88 |
-
f"mean={np.mean(vector) if len(vector) > 0 else 'N/A'}")
|
| 89 |
-
|
| 90 |
-
# Fallback 1 - check if it's already a matrix that was flattened
|
| 91 |
-
if len(vector) > 0 and np.sqrt(len(vector)) == int(np.sqrt(len(vector))):
|
| 92 |
-
n = int(np.sqrt(len(vector)))
|
| 93 |
-
print(f"Trying fallback reshape to {n}x{n}")
|
| 94 |
-
return vector.reshape(n, n)
|
| 95 |
-
|
| 96 |
-
# Fallback 2 - try standard FC matrix size
|
| 97 |
-
elif len(vector) > 30000 and len(vector) < 40000: # Close to 34716
|
| 98 |
-
print(f"Vector length {len(vector)} is close to 34716, trying 264x264 matrix")
|
| 99 |
-
n = 264
|
| 100 |
-
matrix = np.zeros((n, n))
|
| 101 |
-
np.fill_diagonal(matrix, 1.0)
|
| 102 |
-
|
| 103 |
-
# Try to fill as much as possible
|
| 104 |
-
triu_indices = np.triu_indices_from(matrix, k=1)
|
| 105 |
-
max_idx = min(len(vector), len(triu_indices[0]))
|
| 106 |
-
|
| 107 |
-
# Convert from Fisher z-transform if needed
|
| 108 |
-
if np.any(np.abs(vector[:max_idx]) > 1):
|
| 109 |
-
values = np.tanh(vector[:max_idx])
|
| 110 |
-
else:
|
| 111 |
-
values = vector[:max_idx]
|
| 112 |
-
|
| 113 |
-
# Fill the upper triangle with as many values as we can
|
| 114 |
-
for i in range(max_idx):
|
| 115 |
-
matrix[triu_indices[0][i], triu_indices[1][i]] = values[i]
|
| 116 |
-
|
| 117 |
-
# Make symmetric
|
| 118 |
-
matrix = matrix + matrix.T
|
| 119 |
-
np.fill_diagonal(matrix, 1.0)
|
| 120 |
-
|
| 121 |
-
print(f"Created partial matrix with shape {matrix.shape}")
|
| 122 |
-
return matrix
|
| 123 |
-
|
| 124 |
-
# Fallback 3 - create a dummy identity matrix as last resort
|
| 125 |
-
else:
|
| 126 |
-
print("Creating fallback identity matrix")
|
| 127 |
-
n = 264 # Standard size for brain atlas
|
| 128 |
-
matrix = np.zeros((n, n))
|
| 129 |
-
np.fill_diagonal(matrix, 1.0)
|
| 130 |
-
return matrix
|
| 131 |
-
|
| 132 |
-
def plot_fc_matrices(original, reconstructed, generated):
    """Plot FC matrices comparison with enhanced visualization of brain region connections.

    Builds a 3x3 figure: row 1 shows the original, reconstructed and generated
    FC matrices; row 2 shows the reconstruction-difference matrix and the 10
    strongest connections; row 3 shows per-region error, the error histogram
    and a metrics table (MSE, RMSE, R², correlation, max/mean abs error).

    Args:
        original: Original FC matrix (n x n) or upper-triangle vector; vectors
            are expanded via vector_to_matrix.
        reconstructed: VAE-reconstructed FC, same accepted forms.
        generated: VAE-generated FC, same accepted forms.

    Returns:
        A matplotlib Figure; on any exception, a fallback figure containing
        the error message (this function never raises).
    """
    try:
        print("Starting FC matrix visualization...")
        print(f"Input shapes - Original: {original.shape}, Reconstructed: {reconstructed.shape}, Generated: {generated.shape}")

        # Use a larger figure for more detailed visualization
        fig = plt.figure(figsize=(20, 12))

        # Create a grid layout with 3 rows
        gs = plt.GridSpec(3, 3, height_ratios=[1, 0.7, 0.7], figure=fig)

        # First row: Original matrices
        ax1 = fig.add_subplot(gs[0, 0])
        ax2 = fig.add_subplot(gs[0, 1])
        ax3 = fig.add_subplot(gs[0, 2])

        # Second row: Difference matrix and top connections
        ax_diff = fig.add_subplot(gs[1, 0:2])
        ax_top = fig.add_subplot(gs[1, 2])

        # Third row: Region-specific analysis and histogram
        ax_region = fig.add_subplot(gs[2, 0])
        ax_hist = fig.add_subplot(gs[2, 1])
        ax_metrics = fig.add_subplot(gs[2, 2])

        # Correlation matrices live in [-1, 1]; fixed range keeps panels comparable
        vmin, vmax = -1, 1

        # Convert from vector to matrix if needed (1-D input = upper-triangle vector)
        print("Converting inputs to matrices if needed...")
        if len(original.shape) == 1:
            print("Converting original FC from vector to matrix...")
            original = vector_to_matrix(original)
        if len(reconstructed.shape) == 1:
            print("Converting reconstructed FC from vector to matrix...")
            reconstructed = vector_to_matrix(reconstructed)
        if len(generated.shape) == 1:
            print("Converting generated FC from vector to matrix...")
            generated = vector_to_matrix(generated)

        print(f"Matrix shapes after conversion - Original: {original.shape}, Reconstructed: {reconstructed.shape}, Generated: {generated.shape}")

        # Check for NaN or Inf values and handle them
        for name, matrix in [("Original", original), ("Reconstructed", reconstructed), ("Generated", generated)]:
            if np.any(np.isnan(matrix)) or np.any(np.isinf(matrix)):
                print(f"WARNING: {name} matrix contains NaN or Inf values, replacing with zeros")
                if name == "Original":
                    original = np.nan_to_num(matrix)
                elif name == "Reconstructed":
                    reconstructed = np.nan_to_num(matrix)
                else:
                    generated = np.nan_to_num(matrix)

        # Ensure matrices have consistent dimensions
        print("Checking matrix dimensions...")
        dimensions = [original.shape[0], reconstructed.shape[0], generated.shape[0]]
        if len(set(dimensions)) > 1:
            print(f"WARNING: Matrices have inconsistent dimensions: {dimensions}")
            # Use smallest dimension
            n = min(dimensions)
            print(f"Resizing matrices to consistent dimension: {n}x{n}")
            if original.shape[0] > n:
                original = original[:n, :n]
            if reconstructed.shape[0] > n:
                reconstructed = reconstructed[:n, :n]
            if generated.shape[0] > n:
                generated = generated[:n, :n]

        # Calculate key metrics for reconstruction quality
        print("Calculating reconstruction quality metrics...")
        # NOTE(review): function-local import — confirm scikit-learn is in requirements.txt
        from sklearn.metrics import mean_squared_error, r2_score

        # Flatten matrices for metric calculation (excluding diagonal, which is always 1)
        mask = ~np.eye(original.shape[0], dtype=bool)
        orig_flat = original[mask]
        recon_flat = reconstructed[mask]

        # Calculate metrics
        mse = mean_squared_error(orig_flat, recon_flat)
        rmse = np.sqrt(mse)
        r2 = r2_score(orig_flat, recon_flat)
        corr = np.corrcoef(orig_flat, recon_flat)[0, 1]

        # Calculate difference matrix (signed: positive = reconstruction overshoots)
        diff_matrix = reconstructed - original
        max_diff = np.max(np.abs(diff_matrix))
        mean_abs_diff = np.mean(np.abs(diff_matrix))

        # Plot original matrices
        print("Creating matrix plots...")
        im1 = ax1.imshow(original, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        ax1.set_title('Original FC', fontsize=12, fontweight='bold')

        im2 = ax2.imshow(reconstructed, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        ax2.set_title('Reconstructed FC', fontsize=12, fontweight='bold')

        im3 = ax3.imshow(generated, cmap='RdBu_r', vmin=vmin, vmax=vmax)
        ax3.set_title('Generated FC', fontsize=12, fontweight='bold')

        # Add colorbars
        for ax, im in zip([ax1, ax2, ax3], [im1, im2, im3]):
            plt.colorbar(im, ax=ax)
            # Remove axis ticks for cleaner visualization
            ax.set_xticks([])
            ax.set_yticks([])

        # Plot difference matrix
        print("Creating difference matrix visualization...")
        diff_vmax = max(0.5, min(1.0, max_diff))  # Adaptive range
        im_diff = ax_diff.imshow(diff_matrix, cmap='RdBu_r',
                                 vmin=-diff_vmax, vmax=diff_vmax)
        ax_diff.set_title(f'Reconstruction Difference (Mean Abs Diff: {mean_abs_diff:.3f})', fontsize=12)
        plt.colorbar(im_diff, ax=ax_diff)

        # Add axis labels to indicate this represents brain regions
        ax_diff.set_xlabel('Brain Region Index', fontsize=10)
        ax_diff.set_ylabel('Brain Region Index', fontsize=10)

        # Find top connections (strongest positive correlations in original)
        print("Finding top connections...")
        n_regions = original.shape[0]
        top_connections = []

        # Extract top 10 connections (excluding diagonal)
        # NOTE(review): this enumerates all n(n-1)/2 pairs in Python — O(n^2)
        # tuples for a 264-region atlas (~35k); acceptable but not free.
        for i in range(n_regions):
            for j in range(i+1, n_regions):  # upper triangle only
                top_connections.append((i, j, original[i, j], reconstructed[i, j]))

        # Sort by strength of original connection (descending)
        top_connections.sort(key=lambda x: abs(x[2]), reverse=True)
        top_connections = top_connections[:10]  # Keep top 10

        # Plot top connections
        print("Creating top connections chart...")
        ax_top.set_title('Top 10 Strongest Region Connections', fontsize=12)
        ax_top.set_xlim([-1.1, 1.1])  # Range for correlation values

        # Create table data
        y_pos = np.arange(len(top_connections))
        labels = [f"R{i+1}-R{j+1}" for i, j, _, _ in top_connections]
        orig_vals = [orig for _, _, orig, _ in top_connections]
        recon_vals = [recon for _, _, _, recon in top_connections]

        # Plot horizontal bars (original vs reconstructed, offset to share a row)
        ax_top.barh(y_pos + 0.2, orig_vals, height=0.4, color='blue', alpha=0.6, label='Original')
        ax_top.barh(y_pos - 0.2, recon_vals, height=0.4, color='red', alpha=0.6, label='Reconstructed')

        # Add zero line
        ax_top.axvline(x=0, color='black', linestyle='-', alpha=0.3)

        # Add labels and legend
        ax_top.set_yticks(y_pos)
        ax_top.set_yticklabels(labels)
        ax_top.set_xlabel('Correlation Strength')
        ax_top.legend()

        # Add grid for easier reading
        ax_top.grid(True, axis='x', alpha=0.3)

        # Find largest errors per region (mean abs error across each row)
        print("Analyzing regional errors...")
        region_errors = np.mean(np.abs(diff_matrix), axis=1)
        worst_regions = np.argsort(region_errors)[-10:]  # 10 worst regions

        # Plot region-specific error analysis
        region_indices = np.arange(len(worst_regions))
        ax_region.barh(region_indices, region_errors[worst_regions], color='red', alpha=0.7)
        ax_region.set_yticks(region_indices)
        ax_region.set_yticklabels([f"Region {r+1}" for r in worst_regions])
        ax_region.set_title("Regions with Highest Error", fontsize=12)
        ax_region.set_xlabel("Mean Absolute Error")
        ax_region.grid(True, axis='x', alpha=0.3)

        # Create histogram of differences
        print("Creating error distribution histogram...")
        ax_hist.hist(diff_matrix.flatten(), bins=50, alpha=0.7, color='purple')
        ax_hist.set_title("Error Distribution", fontsize=12)
        ax_hist.set_xlabel("Reconstruction Error")
        ax_hist.set_ylabel("Count")

        # Add vertical lines for mean, median
        mean_err = np.mean(diff_matrix)
        median_err = np.median(diff_matrix)
        ax_hist.axvline(mean_err, color='red', linestyle='--', label=f'Mean: {mean_err:.3f}')
        ax_hist.axvline(median_err, color='green', linestyle='--', label=f'Median: {median_err:.3f}')
        ax_hist.legend()

        # Display metrics as a table
        print("Creating metrics table...")
        ax_metrics.axis('tight')
        ax_metrics.axis('off')
        metrics_data = [
            ["MSE", f"{mse:.6f}"],
            ["RMSE", f"{rmse:.6f}"],
            ["R²", f"{r2:.6f}"],
            ["Correlation", f"{corr:.6f}"],
            ["Max Error", f"{max_diff:.6f}"],
            ["Mean Abs Error", f"{mean_abs_diff:.6f}"]
        ]
        table = ax_metrics.table(cellText=metrics_data, loc='center',
                                 cellLoc='left', colWidths=[0.4, 0.6])
        table.auto_set_font_size(False)
        table.set_fontsize(10)
        table.scale(1, 1.5)
        for (row, col), cell in table.get_celld().items():
            if row == 0 or col == 0:
                # NOTE(review): uses the bare `matplotlib` name — confirm the file
                # imports `matplotlib` (and matplotlib.font_manager), not only pyplot
                cell.set_text_props(fontproperties=matplotlib.font_manager.FontProperties(weight='bold'))
        ax_metrics.set_title("Reconstruction Quality Metrics", fontsize=12)

        # Overall quality score (weighted average of metrics)
        quality_score = (0.4 * r2 + 0.4 * corr + 0.2 * (1-rmse/2))  # Scale between 0-1
        quality_percent = max(0, min(100, quality_score * 100))  # Clamp to 0-100%

        # Add overall quality score
        plt.figtext(0.5, 0.01, f"Overall Reconstruction Quality: {quality_percent:.1f}%",
                    ha="center", fontsize=14, fontweight='bold',
                    bbox={"facecolor":"lightblue", "alpha":0.5, "pad":5})

        plt.tight_layout(rect=[0, 0.03, 1, 0.97])  # Adjust layout to make room for the quality score
        print("FC matrix visualization completed successfully")
        return fig

    except Exception as e:
        import traceback
        print(f"Error in plot_fc_matrices: {e}")
        print(f"Traceback: {traceback.format_exc()}")

        # Create a simple error figure
        fig = plt.figure(figsize=(15, 5))
        plt.text(0.5, 0.5, f"FC visualization error: {str(e)}",
                 ha='center', va='center', transform=plt.gca().transAxes,
                 fontsize=12, color='red')
        plt.axis('off')
        plt.tight_layout()
        return fig
|
| 367 |
-
|
| 368 |
-
def plot_treatment_trajectory(current_score, predicted_score, months_post_stroke, prediction_std=None):
    """Plot predicted treatment trajectory.

    Draws the patient's current WAB score at the treatment start, the
    model-predicted outcome at `months_post_stroke`, a dashed trajectory
    between the two, and — when `prediction_std` is given — a shaded 95%
    prediction interval (±2 SD) around the predicted point.

    Args:
        current_score: Current (observed) score at the start of treatment.
        predicted_score: Model-predicted score at `months_post_stroke`.
        months_post_stroke: Time of the prediction, in months post-stroke.
        prediction_std: Optional standard deviation of the prediction; when
            None, no interval is drawn.

    Returns:
        The matplotlib Figure containing the trajectory plot.

    NOTE(review): the source for this function was corrupted/truncated; the
    body was reconstructed around the surviving calls (scatter of the
    predicted outcome, 'g--' trajectory line, ±2*std fill_between) — verify
    against the original intent.
    """
    fig = plt.figure(figsize=(10, 6))

    # Observed starting point of the trajectory (time 0 = treatment start).
    plt.scatter([0], [current_score],
                label='Current Score', color='blue', s=100)

    # Model-predicted outcome at the requested time point.
    plt.scatter([months_post_stroke], [predicted_score],
                label='Predicted Outcome', color='red', s=100)

    # Dashed line connecting current state to the predicted outcome.
    plt.plot([0, months_post_stroke], [current_score, predicted_score],
             'g--', label='Predicted Trajectory')

    # Shade ±2 standard deviations (~95% interval) around the prediction.
    if prediction_std is not None:
        plt.fill_between([months_post_stroke],
                         [predicted_score - 2*prediction_std],
                         [predicted_score + 2*prediction_std],
                         color='red', alpha=0.2,
                         label='95% Prediction Interval')

    plt.xlabel('Months Post-Stroke')
    plt.ylabel('Score')
    plt.title('Predicted Treatment Trajectory')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    return fig
|
| 396 |
|
| 397 |
-
def plot_learning_curves(train_losses, val_losses):
|
| 398 |
-
"""Plot VAE learning curves with enhanced visualization"""
|
| 399 |
-
try:
|
| 400 |
-
# Handle empty or None inputs - only use real data
|
| 401 |
-
if not train_losses or train_losses is None or len(train_losses) == 0:
|
| 402 |
-
print("WARNING: No real training loss data provided")
|
| 403 |
-
# Create placeholder figure with warning message
|
| 404 |
-
fig = plt.figure(figsize=(10, 6))
|
| 405 |
-
plt.text(0.5, 0.5, "No real training data available",
|
| 406 |
-
ha='center', va='center', transform=plt.gca().transAxes,
|
| 407 |
-
fontsize=14, color='darkred')
|
| 408 |
-
plt.axis('off')
|
| 409 |
-
plt.tight_layout()
|
| 410 |
-
return fig
|
| 411 |
-
|
| 412 |
-
if not val_losses or val_losses is None or len(val_losses) == 0:
|
| 413 |
-
print("WARNING: No real validation loss data provided. Using training data only.")
|
| 414 |
-
# Use training data for both
|
| 415 |
-
val_losses = train_losses
|
| 416 |
-
|
| 417 |
-
# Convert to numpy arrays for safe handling
|
| 418 |
-
train_np = np.array(train_losses)
|
| 419 |
-
val_np = np.array(val_losses)
|
| 420 |
-
|
| 421 |
-
# Check for NaN values
|
| 422 |
-
if np.any(np.isnan(train_np)) or np.any(np.isnan(val_np)):
|
| 423 |
-
print("WARNING: Learning curves contain NaN values, replacing with zeros")
|
| 424 |
-
train_np = np.nan_to_num(train_np)
|
| 425 |
-
val_np = np.nan_to_num(val_np)
|
| 426 |
-
|
| 427 |
-
# Create figure
|
| 428 |
-
fig = plt.figure(figsize=(12, 6))
|
| 429 |
-
|
| 430 |
-
# Add improved styling
|
| 431 |
-
plt.rcParams['font.size'] = 12
|
| 432 |
-
|
| 433 |
-
# Check if train and val lengths match
|
| 434 |
-
if len(train_np) != len(val_np):
|
| 435 |
-
print(f"Training and validation loss lengths don't match: {len(train_np)} vs {len(val_np)}")
|
| 436 |
-
if len(train_np) > len(val_np):
|
| 437 |
-
# Validation might be evaluated less frequently
|
| 438 |
-
# Create epoch indices for each
|
| 439 |
-
train_epochs = np.arange(len(train_np))
|
| 440 |
-
val_factor = len(train_np) / len(val_np)
|
| 441 |
-
val_epochs = np.arange(0, len(train_np), val_factor)[:len(val_np)]
|
| 442 |
-
|
| 443 |
-
plt.plot(train_epochs, train_np, 'b-', linewidth=2, label='Training Loss')
|
| 444 |
-
plt.plot(val_epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
|
| 445 |
-
else:
|
| 446 |
-
# This is unusual, but handle it anyway
|
| 447 |
-
plt.plot(train_np, 'b-', linewidth=2, label='Training Loss')
|
| 448 |
-
plt.plot(val_np[:len(train_np)], 'r-', linewidth=2, label='Validation Loss')
|
| 449 |
-
else:
|
| 450 |
-
# Standard case - equal length arrays
|
| 451 |
-
epochs = np.arange(len(train_np))
|
| 452 |
-
plt.plot(epochs, train_np, 'b-', linewidth=2, label='Training Loss')
|
| 453 |
-
plt.plot(epochs, val_np, 'r-', linewidth=2, label='Validation Loss')
|
| 454 |
-
|
| 455 |
-
# Add shaded confidence region
|
| 456 |
-
if len(train_np) > 5: # Only if we have enough points
|
| 457 |
-
# Calculate moving average for smoother trend lines
|
| 458 |
-
window_size = min(5, len(train_np) // 5)
|
| 459 |
-
if window_size > 1:
|
| 460 |
-
avg_train = np.convolve(train_np, np.ones(window_size)/window_size, mode='valid')
|
| 461 |
-
avg_val = np.convolve(val_np, np.ones(window_size)/window_size, mode='valid')
|
| 462 |
-
avg_epochs = epochs[window_size-1:]
|
| 463 |
-
plt.plot(avg_epochs, avg_train, 'b--', linewidth=1, alpha=0.6)
|
| 464 |
-
plt.plot(avg_epochs, avg_val, 'r--', linewidth=1, alpha=0.6)
|
| 465 |
-
|
| 466 |
-
# Calculate improvement from start to end
|
| 467 |
-
if len(train_np) > 1:
|
| 468 |
-
train_improvement = ((train_np[0] - train_np[-1]) / train_np[0]) * 100
|
| 469 |
-
if len(val_np) > 1:
|
| 470 |
-
val_improvement = ((val_np[0] - val_np[-1]) / val_np[0]) * 100
|
| 471 |
-
plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement, Validation: {val_improvement:.1f}% improvement')
|
| 472 |
-
else:
|
| 473 |
-
plt.title(f'VAE Learning Curves\nTraining: {train_improvement:.1f}% improvement')
|
| 474 |
-
else:
|
| 475 |
-
plt.title('VAE Learning Curves')
|
| 476 |
-
|
| 477 |
-
# Add min/max annotations
|
| 478 |
-
if len(train_np) > 0:
|
| 479 |
-
min_train = np.min(train_np)
|
| 480 |
-
min_train_epoch = np.argmin(train_np)
|
| 481 |
-
plt.annotate(f'Min: {min_train:.4f}', xy=(min_train_epoch, min_train),
|
| 482 |
-
xytext=(min_train_epoch+5, min_train+0.05),
|
| 483 |
-
arrowprops=dict(facecolor='blue', shrink=0.05, alpha=0.5),
|
| 484 |
-
color='blue', fontsize=10)
|
| 485 |
-
|
| 486 |
-
if len(val_np) > 0:
|
| 487 |
-
min_val = np.min(val_np)
|
| 488 |
-
min_val_epoch = np.argmin(val_np)
|
| 489 |
-
plt.annotate(f'Min: {min_val:.4f}', xy=(min_val_epoch, min_val),
|
| 490 |
-
xytext=(min_val_epoch+5, min_val+0.05),
|
| 491 |
-
arrowprops=dict(facecolor='red', shrink=0.05, alpha=0.5),
|
| 492 |
-
color='red', fontsize=10)
|
| 493 |
-
|
| 494 |
-
# Styling
|
| 495 |
-
plt.xlabel('Epoch')
|
| 496 |
-
plt.ylabel('Loss')
|
| 497 |
-
plt.legend(loc='upper right')
|
| 498 |
-
plt.grid(True, alpha=0.3)
|
| 499 |
-
|
| 500 |
-
# Set reasonable y-axis limits
|
| 501 |
-
all_losses = np.concatenate([train_np, val_np])
|
| 502 |
-
y_min = max(0, np.min(all_losses) * 0.9) # Don't go below zero
|
| 503 |
-
y_max = np.percentile(all_losses, 95) * 1.1 # Exclude outliers
|
| 504 |
-
plt.ylim(y_min, y_max)
|
| 505 |
-
|
| 506 |
-
plt.tight_layout()
|
| 507 |
-
return fig
|
| 508 |
-
|
| 509 |
-
except Exception as e:
|
| 510 |
-
import traceback
|
| 511 |
-
print(f"Error in plot_learning_curves: {e}")
|
| 512 |
-
print(f"Traceback: {traceback.format_exc()}")
|
| 513 |
-
|
| 514 |
-
# Create a simple error figure
|
| 515 |
-
fig = plt.figure(figsize=(10, 6))
|
| 516 |
-
plt.text(0.5, 0.5, f"Learning curves error: {str(e)}",
|
| 517 |
-
ha='center', va='center', transform=plt.gca().transAxes,
|
| 518 |
-
fontsize=12, color='red')
|
| 519 |
-
plt.axis('off')
|
| 520 |
-
plt.tight_layout()
|
| 521 |
-
return fig
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import matplotlib.pyplot as plt
|
| 2 |
import numpy as np
|
| 3 |
+
from utils import fc_matrix_from_triu
|
| 4 |
|
| 5 |
+
def visualize_fc_analysis(original_triu, reconstructed_triu, generated_triu, analysis_results=None):
    """Visualize FC matrices and optional latent-dimension correlation analysis.

    Top row: original, reconstructed and generated FC matrices (expanded from
    their upper-triangle vectors). Bottom row (only when `analysis_results`
    is given): per-demographic correlation strength across latent dimensions,
    with the count of dimensions significant at p < 0.05 in each legend entry.

    Args:
        original_triu: Upper-triangle vector of the original FC matrix.
        reconstructed_triu: Upper-triangle vector of the reconstructed FC.
        generated_triu: Upper-triangle vector of the generated FC.
        analysis_results: Optional mapping of demographic name to a dict with
            'correlations' and 'p_values' sequences over latent dimensions.

    Returns:
        The assembled matplotlib Figure.
    """
    fig = plt.figure(figsize=(15, 10))
    grid = plt.GridSpec(2, 3)

    # One panel per FC matrix, laid out across the top row.
    panels = [
        (original_triu, 'Original FC'),
        (reconstructed_triu, 'Reconstructed FC'),
        (generated_triu, 'Generated FC'),
    ]
    for col, (triu_vec, title) in enumerate(panels):
        panel_ax = fig.add_subplot(grid[0, col])
        heatmap = panel_ax.imshow(fc_matrix_from_triu(triu_vec),
                                  cmap='RdBu_r', vmin=-1, vmax=1)
        panel_ax.set_title(title)
        plt.colorbar(heatmap, ax=panel_ax)

    # Optional bottom panel: demographic correlations with latent dimensions.
    if analysis_results is not None:
        corr_ax = fig.add_subplot(grid[1, :])
        for demo_name, results in analysis_results.items():
            p_vals = np.array(results['p_values'])
            significant_dims = np.where(p_vals < 0.05)[0]
            corr_ax.plot(np.array(results['correlations']),
                         label=f'{demo_name} (sig. dims: {len(significant_dims)})')

        corr_ax.set_xlabel('Latent Dimension')
        corr_ax.set_ylabel('Correlation Strength')
        corr_ax.set_title('Demographic Correlations with Latent Dimensions')
        corr_ax.legend()

    plt.tight_layout()
    return fig
|
| 44 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|