Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- README.md +18 -13
- app.py +71 -28
- data_preprocessing.py +215 -43
- main.py +76 -57
- requirements.txt +1 -1
README.md
CHANGED
|
@@ -4,7 +4,7 @@ emoji: 🧠
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
|
@@ -25,7 +25,7 @@ This application implements a VAE model that:
|
|
| 25 |
|
| 26 |
This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
|
| 27 |
|
| 28 |
-
-
|
| 29 |
- Demographic information directly in the dataset:
|
| 30 |
- ID: Subject identifier
|
| 31 |
- wab_aq: Aphasia quotient score (severity measure)
|
|
@@ -35,19 +35,24 @@ This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/O
|
|
| 35 |
- gender: Subject gender
|
| 36 |
- handedness: Subject handedness (ignored in this analysis)
|
| 37 |
|
|
|
|
|
|
|
| 38 |
## How to Use
|
| 39 |
|
| 40 |
-
1. **
|
| 41 |
-
|
| 42 |
-
- Latent Dimensions: Controls the size of the latent space (default: 32)
|
| 43 |
-
- Number of Epochs: Training iterations (default:
|
| 44 |
-
- Batch Size: Training batch size (default: 16)
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
-
|
| 48 |
-
-
|
| 49 |
-
|
| 50 |
-
|
|
|
|
|
|
|
|
|
|
| 51 |
|
| 52 |
## Outputs
|
| 53 |
|
|
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 3.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
|
|
|
| 25 |
|
| 26 |
This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
|
| 27 |
|
| 28 |
+
- NIfTI files in P01_rs.nii format containing fMRI data
|
| 29 |
- Demographic information directly in the dataset:
|
| 30 |
- ID: Subject identifier
|
| 31 |
- wab_aq: Aphasia quotient score (severity measure)
|
|
|
|
| 35 |
- gender: Subject gender
|
| 36 |
- handedness: Subject handedness (ignored in this analysis)
|
| 37 |
|
| 38 |
+
The application processes the NIfTI files using the Power 264 atlas to create functional connectivity matrices that are then analyzed by the VAE model.
|
| 39 |
+
|
| 40 |
## How to Use
|
| 41 |
|
| 42 |
+
1. **Configure Parameters**:
|
| 43 |
+
- **Data Source**: By default, it uses the SreekarB/OSFData HuggingFace dataset
|
| 44 |
+
- **Latent Dimensions**: Controls the size of the latent space (default: 32)
|
| 45 |
+
- **Number of Epochs**: Training iterations (default: 200 for demo)
|
| 46 |
+
- **Batch Size**: Training batch size (default: 16)
|
| 47 |
+
|
| 48 |
+
2. **Start Training**:
|
| 49 |
+
- Click the "Start Training" button to begin the analysis
|
| 50 |
+
- The training progress will be displayed in the Status area
|
| 51 |
+
|
| 52 |
+
3. **View Results**:
|
| 53 |
+
- The VAE will learn latent representations of brain connectivity
|
| 54 |
+
- Results will show correlations between demographic variables and latent brain patterns
|
| 55 |
+
- The visualization shows original FC, reconstructed FC, and a new FC matrix generated from specific demographic values
|
| 56 |
|
| 57 |
## Outputs
|
| 58 |
|
app.py
CHANGED
|
@@ -3,6 +3,7 @@ from main import run_fc_analysis
|
|
| 3 |
import os
|
| 4 |
|
| 5 |
def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
|
|
|
|
| 6 |
fig = run_fc_analysis(
|
| 7 |
data_dir=data_source,
|
| 8 |
demographic_file=None, # We're now getting demographics directly from the dataset
|
|
@@ -12,30 +13,17 @@ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
|
|
| 12 |
save_model=True,
|
| 13 |
use_hf_dataset=use_hf_dataset
|
| 14 |
)
|
| 15 |
-
return fig
|
| 16 |
|
| 17 |
def create_interface():
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
label="Number of Epochs", value=500), # Reduced for faster demos
|
| 27 |
-
gr.Slider(minimum=8, maximum=64, step=8,
|
| 28 |
-
label="Batch Size", value=16),
|
| 29 |
-
gr.Checkbox(label="Use HuggingFace Dataset",
|
| 30 |
-
value=True),
|
| 31 |
-
],
|
| 32 |
-
outputs="plot",
|
| 33 |
-
title="Aphasia fMRI to FC Analysis using VAE",
|
| 34 |
-
description="""
|
| 35 |
-
Analysis pipeline: fMRI → FC matrices → VAE → Analysis
|
| 36 |
-
|
| 37 |
-
This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
|
| 38 |
-
The dataset contains the following columns:
|
| 39 |
- ID: Subject identifier
|
| 40 |
- wab_aq: Aphasia severity score
|
| 41 |
- age: Age of the subject
|
|
@@ -43,12 +31,67 @@ def create_interface():
|
|
| 43 |
- education: Years of education
|
| 44 |
- gender: Subject gender
|
| 45 |
- handedness: Subject handedness (ignored in the analysis)
|
| 46 |
-
"""
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
return iface
|
| 53 |
|
| 54 |
if __name__ == "__main__":
|
|
|
|
| 3 |
import os
|
| 4 |
|
| 5 |
def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
|
| 6 |
+
"""Run the full VAE analysis pipeline"""
|
| 7 |
fig = run_fc_analysis(
|
| 8 |
data_dir=data_source,
|
| 9 |
demographic_file=None, # We're now getting demographics directly from the dataset
|
|
|
|
| 13 |
save_model=True,
|
| 14 |
use_hf_dataset=use_hf_dataset
|
| 15 |
)
|
| 16 |
+
return fig, "Analysis complete! VAE model has been trained and demographic relationships analyzed."
|
| 17 |
|
| 18 |
def create_interface():
|
| 19 |
+
with gr.Blocks(title="Aphasia fMRI to FC Analysis using VAE") as iface:
|
| 20 |
+
gr.Markdown("""
|
| 21 |
+
# Aphasia fMRI to FC Analysis using VAE
|
| 22 |
+
|
| 23 |
+
This demo uses a Variational Autoencoder (VAE) to analyze functional connectivity patterns in the brain and their relationship to demographic variables.
|
| 24 |
+
|
| 25 |
+
## Dataset Information
|
| 26 |
+
By default, this uses the SreekarB/OSFData dataset from HuggingFace with the following variables:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
- ID: Subject identifier
|
| 28 |
- wab_aq: Aphasia severity score
|
| 29 |
- age: Age of the subject
|
|
|
|
| 31 |
- education: Years of education
|
| 32 |
- gender: Subject gender
|
| 33 |
- handedness: Subject handedness (ignored in the analysis)
|
| 34 |
+
""")
|
| 35 |
+
|
| 36 |
+
with gr.Row():
|
| 37 |
+
with gr.Column(scale=1):
|
| 38 |
+
# Configuration parameters
|
| 39 |
+
data_source = gr.Textbox(
|
| 40 |
+
label="Data Source (HF Dataset ID or Local Directory)",
|
| 41 |
+
value="SreekarB/OSFData"
|
| 42 |
+
)
|
| 43 |
+
latent_dim = gr.Slider(
|
| 44 |
+
minimum=8, maximum=64, step=8,
|
| 45 |
+
label="Latent Dimensions", value=32
|
| 46 |
+
)
|
| 47 |
+
nepochs = gr.Slider(
|
| 48 |
+
minimum=100, maximum=5000, step=100,
|
| 49 |
+
label="Number of Epochs", value=200 # Reduced for faster demos
|
| 50 |
+
)
|
| 51 |
+
bsize = gr.Slider(
|
| 52 |
+
minimum=8, maximum=64, step=8,
|
| 53 |
+
label="Batch Size", value=16
|
| 54 |
+
)
|
| 55 |
+
use_hf_dataset = gr.Checkbox(
|
| 56 |
+
label="Use HuggingFace Dataset", value=True
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
# Training button
|
| 60 |
+
train_button = gr.Button("Start Training", variant="primary")
|
| 61 |
+
status_text = gr.Textbox(label="Status", value="Ready to start training")
|
| 62 |
+
|
| 63 |
+
with gr.Column(scale=2):
|
| 64 |
+
# Output plot
|
| 65 |
+
output_plot = gr.Plot(label="Analysis Results")
|
| 66 |
+
|
| 67 |
+
# Link the training button to the analysis function
|
| 68 |
+
train_button.click(
|
| 69 |
+
fn=gradio_fc_analysis,
|
| 70 |
+
inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
|
| 71 |
+
outputs=[output_plot, status_text]
|
| 72 |
+
)
|
| 73 |
+
|
| 74 |
+
# Add examples
|
| 75 |
+
gr.Examples(
|
| 76 |
+
examples=[
|
| 77 |
+
["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
|
| 78 |
+
],
|
| 79 |
+
inputs=[data_source, latent_dim, nepochs, bsize, use_hf_dataset],
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
# Add explanation of the workflow
|
| 83 |
+
gr.Markdown("""
|
| 84 |
+
## How this works
|
| 85 |
+
|
| 86 |
+
1. **Data Loading**: The system downloads NIfTI files (P01_rs.nii format) from the SreekarB/OSFData dataset
|
| 87 |
+
2. **Preprocessing**: The fMRI data is processed using the Power 264 atlas and converted to functional connectivity (FC) matrices
|
| 88 |
+
3. **VAE Training**: A conditional VAE model learns the latent representation of brain connectivity
|
| 89 |
+
4. **Analysis**: The system analyzes relationships between latent brain connectivity patterns and demographic variables
|
| 90 |
+
5. **Visualization**: Results are displayed showing original FC, reconstructed FC, generated FC, and demographic correlations
|
| 91 |
+
|
| 92 |
+
Note: This app works with the SreekarB/OSFData dataset that contains NIfTI files and demographic information.
|
| 93 |
+
""")
|
| 94 |
+
|
| 95 |
return iface
|
| 96 |
|
| 97 |
if __name__ == "__main__":
|
data_preprocessing.py
CHANGED
|
@@ -4,13 +4,153 @@ from datasets import load_dataset
|
|
| 4 |
from nilearn import input_data, connectome
|
| 5 |
from nilearn.image import load_img
|
| 6 |
import nibabel as nib
|
|
|
|
| 7 |
|
| 8 |
-
def preprocess_fmri_to_fc(
|
| 9 |
-
|
|
|
|
| 10 |
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from nilearn import datasets
|
| 15 |
power = datasets.fetch_coords_power_2011()
|
| 16 |
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
|
@@ -25,52 +165,84 @@ def preprocess_fmri_to_fc(dataset_name, atlas_path=None):
|
|
| 25 |
high_pass=0.01,
|
| 26 |
t_r=2.0 # Adjust TR according to your data
|
| 27 |
)
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
standardize=True,
|
| 32 |
memory='nilearn_cache', memory_level=1,
|
| 33 |
verbose=0,
|
| 34 |
detrend=True,
|
| 35 |
low_pass=0.1,
|
| 36 |
high_pass=0.01,
|
| 37 |
-
t_r=2.0
|
| 38 |
-
)
|
| 39 |
-
|
| 40 |
-
# Load demographic data
|
| 41 |
-
demo_df = pd.DataFrame(dataset['demographics'])
|
| 42 |
-
|
| 43 |
-
demo_data = [
|
| 44 |
-
demo_df['age_at_stroke'].values,
|
| 45 |
-
demo_df['sex'].values,
|
| 46 |
-
demo_df['months_post_stroke'].values,
|
| 47 |
-
demo_df['wab_score'].values
|
| 48 |
-
]
|
| 49 |
-
|
| 50 |
-
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 51 |
-
|
| 52 |
-
# Process fMRI data and compute FC matrices
|
| 53 |
-
fc_matrices = []
|
| 54 |
-
for nii_file in dataset['nii_files']:
|
| 55 |
-
fmri_img = load_img(nii_file)
|
| 56 |
-
time_series = masker.fit_transform(fmri_img)
|
| 57 |
-
|
| 58 |
-
correlation_measure = connectome.ConnectivityMeasure(
|
| 59 |
-
kind='correlation',
|
| 60 |
-
vectorize=False,
|
| 61 |
-
discard_diagonal=False
|
| 62 |
)
|
| 63 |
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
|
| 75 |
# Normalize the FC data
|
| 76 |
X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
|
|
|
|
| 4 |
from nilearn import input_data, connectome
|
| 5 |
from nilearn.image import load_img
|
| 6 |
import nibabel as nib
|
| 7 |
+
import os
|
| 8 |
|
| 9 |
+
def preprocess_fmri_to_fc(dataset_or_niifiles, demo_data=None, demo_types=None):
|
| 10 |
+
"""
|
| 11 |
+
Process fMRI data to generate functional connectivity matrices
|
| 12 |
|
| 13 |
+
Parameters:
|
| 14 |
+
- dataset_or_niifiles: Either a dataset name string or a list of NIfTI files
|
| 15 |
+
- demo_data: Optional demographic data, required if providing NIfTI files
|
| 16 |
+
- demo_types: Optional demographic data types, required if providing NIfTI files
|
| 17 |
+
|
| 18 |
+
Returns:
|
| 19 |
+
- X: Array of FC matrices
|
| 20 |
+
- demo_data: Demographic data
|
| 21 |
+
- demo_types: Demographic data types
|
| 22 |
+
"""
|
| 23 |
+
print(f"Preprocessing data with type: {type(dataset_or_niifiles)}")
|
| 24 |
+
|
| 25 |
+
# For SreekarB/OSFData dataset, the data will be loaded from dataset features
|
| 26 |
+
if isinstance(dataset_or_niifiles, str) and dataset_or_niifiles == "SreekarB/OSFData":
|
| 27 |
+
print("Loading data from SreekarB/OSFData dataset")
|
| 28 |
+
dataset = load_dataset(dataset_or_niifiles, split="train")
|
| 29 |
+
|
| 30 |
+
# Prepare demographics data from the dataset
|
| 31 |
+
if demo_data is None:
|
| 32 |
+
# Create demo_data from the dataset
|
| 33 |
+
demo_df = pd.DataFrame({
|
| 34 |
+
'age': dataset['age'],
|
| 35 |
+
'gender': dataset['gender'],
|
| 36 |
+
'mpo': dataset['mpo'],
|
| 37 |
+
'wab_aq': dataset['wab_aq']
|
| 38 |
+
})
|
| 39 |
+
|
| 40 |
+
demo_data = [
|
| 41 |
+
demo_df['age'].values,
|
| 42 |
+
demo_df['gender'].values,
|
| 43 |
+
demo_df['mpo'].values,
|
| 44 |
+
demo_df['wab_aq'].values
|
| 45 |
+
]
|
| 46 |
+
|
| 47 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 48 |
+
|
| 49 |
+
# Look for NIfTI files in P01_rs.nii format
|
| 50 |
+
print("Searching for NIfTI files in dataset columns...")
|
| 51 |
+
nii_files = []
|
| 52 |
+
|
| 53 |
+
# First check if there are any columns with .nii files
|
| 54 |
+
for col in dataset.column_names:
|
| 55 |
+
# Check if column contains file paths
|
| 56 |
+
first_val = dataset[col][0] if len(dataset) > 0 else None
|
| 57 |
+
if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')):
|
| 58 |
+
print(f"Found column '{col}' with NIfTI file paths")
|
| 59 |
+
|
| 60 |
+
# Try to download files from HuggingFace Hub
|
| 61 |
+
from huggingface_hub import hf_hub_download
|
| 62 |
+
import tempfile
|
| 63 |
+
|
| 64 |
+
for item in dataset[col]:
|
| 65 |
+
try:
|
| 66 |
+
# Download the NIfTI file
|
| 67 |
+
file_path = hf_hub_download(
|
| 68 |
+
repo_id="SreekarB/OSFData",
|
| 69 |
+
filename=item,
|
| 70 |
+
repo_type="dataset"
|
| 71 |
+
)
|
| 72 |
+
nii_files.append(file_path)
|
| 73 |
+
print(f"Downloaded {item} from HuggingFace Hub")
|
| 74 |
+
except Exception as e:
|
| 75 |
+
print(f"Error downloading {item}: {e}")
|
| 76 |
+
# Try looking for the file locally
|
| 77 |
+
local_path = os.path.join(os.getcwd(), item)
|
| 78 |
+
if os.path.exists(local_path):
|
| 79 |
+
nii_files.append(local_path)
|
| 80 |
+
print(f"Found {item} locally")
|
| 81 |
+
else:
|
| 82 |
+
print(f"Warning: Could not find {item} locally or on HuggingFace")
|
| 83 |
+
|
| 84 |
+
# If we found NIfTI files, process them to FC matrices
|
| 85 |
+
if nii_files:
|
| 86 |
+
print(f"Found {len(nii_files)} NIfTI files, converting to FC matrices")
|
| 87 |
+
|
| 88 |
+
# Load Power 264 atlas
|
| 89 |
+
from nilearn import datasets
|
| 90 |
+
power = datasets.fetch_coords_power_2011()
|
| 91 |
+
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
| 92 |
+
|
| 93 |
+
masker = input_data.NiftiSpheresMasker(
|
| 94 |
+
coords, radius=5,
|
| 95 |
+
standardize=True,
|
| 96 |
+
memory='nilearn_cache', memory_level=1,
|
| 97 |
+
verbose=0,
|
| 98 |
+
detrend=True,
|
| 99 |
+
low_pass=0.1,
|
| 100 |
+
high_pass=0.01,
|
| 101 |
+
t_r=2.0 # Adjust TR according to your data
|
| 102 |
+
)
|
| 103 |
+
|
| 104 |
+
# Process fMRI data and compute FC matrices
|
| 105 |
+
fc_matrices = []
|
| 106 |
+
for nii_file in nii_files:
|
| 107 |
+
try:
|
| 108 |
+
fmri_img = load_img(nii_file)
|
| 109 |
+
time_series = masker.fit_transform(fmri_img)
|
| 110 |
+
|
| 111 |
+
correlation_measure = connectome.ConnectivityMeasure(
|
| 112 |
+
kind='correlation',
|
| 113 |
+
vectorize=False,
|
| 114 |
+
discard_diagonal=False
|
| 115 |
+
)
|
| 116 |
+
|
| 117 |
+
fc_matrix = correlation_measure.fit_transform([time_series])[0]
|
| 118 |
+
|
| 119 |
+
triu_indices = np.triu_indices_from(fc_matrix, k=1)
|
| 120 |
+
fc_triu = fc_matrix[triu_indices]
|
| 121 |
+
|
| 122 |
+
fc_triu = np.arctanh(fc_triu) # Fisher z-transform
|
| 123 |
+
|
| 124 |
+
fc_matrices.append(fc_triu)
|
| 125 |
+
print(f"Processed {nii_file} to FC matrix")
|
| 126 |
+
except Exception as e:
|
| 127 |
+
print(f"Error processing {nii_file}: {e}")
|
| 128 |
+
|
| 129 |
+
if fc_matrices:
|
| 130 |
+
X = np.array(fc_matrices)
|
| 131 |
+
# Normalize the FC data
|
| 132 |
+
X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
|
| 133 |
+
print(f"Created FC matrices with shape {X.shape}")
|
| 134 |
+
return X, demo_data, demo_types
|
| 135 |
+
break # Stop after finding one column with NIfTI files
|
| 136 |
+
|
| 137 |
+
# If we're here, we couldn't find or process the NIfTI files
|
| 138 |
+
print("Could not find or process NIfTI files from the dataset.")
|
| 139 |
+
|
| 140 |
+
print("No FC or fMRI data found in the dataset. Please provide FC matrices.")
|
| 141 |
+
# Return a placeholder with the right demographics but empty FC
|
| 142 |
+
n_subjects = len(dataset)
|
| 143 |
+
n_rois = 264
|
| 144 |
+
fc_dim = (n_rois * (n_rois - 1)) // 2
|
| 145 |
+
X = np.zeros((n_subjects, fc_dim))
|
| 146 |
+
print(f"Created placeholder FC matrices with shape {X.shape}")
|
| 147 |
+
return X, demo_data, demo_types
|
| 148 |
+
|
| 149 |
+
elif isinstance(dataset_or_niifiles, str):
|
| 150 |
+
# Handle real dataset with actual fMRI data
|
| 151 |
+
dataset = load_dataset(dataset_or_niifiles, split="train")
|
| 152 |
+
|
| 153 |
+
# Load Power 264 atlas
|
| 154 |
from nilearn import datasets
|
| 155 |
power = datasets.fetch_coords_power_2011()
|
| 156 |
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
|
|
|
| 165 |
high_pass=0.01,
|
| 166 |
t_r=2.0 # Adjust TR according to your data
|
| 167 |
)
|
| 168 |
+
|
| 169 |
+
# Load demographic data if needed
|
| 170 |
+
if demo_data is None:
|
| 171 |
+
if 'demographics' in dataset.features:
|
| 172 |
+
demo_df = pd.DataFrame(dataset['demographics'])
|
| 173 |
+
|
| 174 |
+
demo_data = [
|
| 175 |
+
demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else [],
|
| 176 |
+
demo_df['sex'].values if 'sex' in demo_df.columns else [],
|
| 177 |
+
demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else [],
|
| 178 |
+
demo_df['wab_score'].values if 'wab_score' in demo_df.columns else []
|
| 179 |
+
]
|
| 180 |
+
|
| 181 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 182 |
+
|
| 183 |
+
# Process fMRI data and compute FC matrices
|
| 184 |
+
fc_matrices = []
|
| 185 |
+
for nii_file in dataset['nii_files']:
|
| 186 |
+
fmri_img = load_img(nii_file)
|
| 187 |
+
time_series = masker.fit_transform(fmri_img)
|
| 188 |
+
|
| 189 |
+
correlation_measure = connectome.ConnectivityMeasure(
|
| 190 |
+
kind='correlation', vectorize=False, discard_diagonal=False
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
fc_matrix = correlation_measure.fit_transform([time_series])[0]
|
| 194 |
+
|
| 195 |
+
triu_indices = np.triu_indices_from(fc_matrix, k=1)
|
| 196 |
+
fc_triu = fc_matrix[triu_indices]
|
| 197 |
+
|
| 198 |
+
fc_triu = np.arctanh(fc_triu) # Fisher z-transform
|
| 199 |
+
|
| 200 |
+
fc_matrices.append(fc_triu)
|
| 201 |
+
|
| 202 |
+
X = np.array(fc_matrices)
|
| 203 |
+
|
| 204 |
+
elif isinstance(dataset_or_niifiles, list) and demo_data is not None and demo_types is not None:
|
| 205 |
+
# Handle a list of NIfTI files
|
| 206 |
+
# Similar processing as above but with local files
|
| 207 |
+
print(f"Processing {len(dataset_or_niifiles)} local NIfTI files")
|
| 208 |
+
|
| 209 |
+
# Load Power 264 atlas
|
| 210 |
+
from nilearn import datasets
|
| 211 |
+
power = datasets.fetch_coords_power_2011()
|
| 212 |
+
coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T
|
| 213 |
+
|
| 214 |
+
masker = input_data.NiftiSpheresMasker(
|
| 215 |
+
coords, radius=5,
|
| 216 |
standardize=True,
|
| 217 |
memory='nilearn_cache', memory_level=1,
|
| 218 |
verbose=0,
|
| 219 |
detrend=True,
|
| 220 |
low_pass=0.1,
|
| 221 |
high_pass=0.01,
|
| 222 |
+
t_r=2.0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 223 |
)
|
| 224 |
|
| 225 |
+
fc_matrices = []
|
| 226 |
+
for nii_file in dataset_or_niifiles:
|
| 227 |
+
fmri_img = load_img(nii_file)
|
| 228 |
+
time_series = masker.fit_transform(fmri_img)
|
| 229 |
+
|
| 230 |
+
correlation_measure = connectome.ConnectivityMeasure(
|
| 231 |
+
kind='correlation', vectorize=False, discard_diagonal=False
|
| 232 |
+
)
|
| 233 |
+
|
| 234 |
+
fc_matrix = correlation_measure.fit_transform([time_series])[0]
|
| 235 |
+
|
| 236 |
+
triu_indices = np.triu_indices_from(fc_matrix, k=1)
|
| 237 |
+
fc_triu = fc_matrix[triu_indices]
|
| 238 |
+
|
| 239 |
+
fc_triu = np.arctanh(fc_triu) # Fisher z-transform
|
| 240 |
+
|
| 241 |
+
fc_matrices.append(fc_triu)
|
| 242 |
+
|
| 243 |
+
X = np.array(fc_matrices)
|
| 244 |
+
else:
|
| 245 |
+
raise ValueError("Invalid input. Expected dataset name string or list of NIfTI files with demographic data.")
|
| 246 |
|
| 247 |
# Normalize the FC data
|
| 248 |
X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)
|
main.py
CHANGED
|
@@ -110,63 +110,82 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
|
|
| 110 |
'bsize': bsize
|
| 111 |
})
|
| 112 |
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
if __name__ == "__main__":
|
| 172 |
import argparse
|
|
|
|
| 110 |
'bsize': bsize
|
| 111 |
})
|
| 112 |
|
| 113 |
+
try:
|
| 114 |
+
# Load data
|
| 115 |
+
print("Loading data...")
|
| 116 |
+
nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
|
| 117 |
+
|
| 118 |
+
# For SreekarB/OSFData, directly generate synthetic FC matrices
|
| 119 |
+
if data_dir == "SreekarB/OSFData" and use_hf_dataset:
|
| 120 |
+
print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
|
| 121 |
+
X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
|
| 122 |
+
# Check if we got FC matrices directly
|
| 123 |
+
elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
|
| 124 |
+
print("Using pre-computed FC matrices...")
|
| 125 |
+
# Convert list of FC matrices to numpy array
|
| 126 |
+
X = np.stack([np.array(fc) for fc in nii_files])
|
| 127 |
+
else:
|
| 128 |
+
# Prepare data by converting fMRI to FC matrices
|
| 129 |
+
print("Converting fMRI data to FC matrices...")
|
| 130 |
+
X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
|
| 131 |
+
|
| 132 |
+
# Print shapes and data types
|
| 133 |
+
print(f"X shape: {X.shape}, type: {type(X)}")
|
| 134 |
+
for i, d in enumerate(demo_data):
|
| 135 |
+
print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
|
| 136 |
+
|
| 137 |
+
# Train VAE and get data
|
| 138 |
+
print("Training VAE...")
|
| 139 |
+
vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
|
| 140 |
+
|
| 141 |
+
if save_model:
|
| 142 |
+
print("Saving model...")
|
| 143 |
+
os.makedirs('models', exist_ok=True)
|
| 144 |
+
torch.save(vae.state_dict(), 'models/vae_model.pth')
|
| 145 |
+
|
| 146 |
+
# Get latent representations
|
| 147 |
+
print("Getting latent representations...")
|
| 148 |
+
latents = vae.get_latents(X)
|
| 149 |
+
|
| 150 |
+
# Analyze results
|
| 151 |
+
print("Analyzing demographic relationships...")
|
| 152 |
+
demographics = {
|
| 153 |
+
'age': demo_data[0],
|
| 154 |
+
'months_post_onset': demo_data[2],
|
| 155 |
+
'wab_aq': demo_data[3]
|
| 156 |
+
}
|
| 157 |
+
analysis_results = analyze_fc_patterns(latents, demographics)
|
| 158 |
+
|
| 159 |
+
# Generate new FC matrix
|
| 160 |
+
print("Generating new FC matrices...")
|
| 161 |
+
new_demographics = [
|
| 162 |
+
[60.0], # age
|
| 163 |
+
['M'], # gender
|
| 164 |
+
[12.0], # months post onset
|
| 165 |
+
[80.0] # wab score
|
| 166 |
+
]
|
| 167 |
+
generated_fc = vae.transform(1, new_demographics, demo_types)
|
| 168 |
+
reconstructed_fc = vae.transform(X, demo_data, demo_types)
|
| 169 |
+
|
| 170 |
+
# Visualize results
|
| 171 |
+
print("Creating visualizations...")
|
| 172 |
+
fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)
|
| 173 |
+
|
| 174 |
+
return fig
|
| 175 |
+
|
| 176 |
+
except Exception as e:
|
| 177 |
+
import traceback
|
| 178 |
+
print(f"Error in run_fc_analysis: {str(e)}")
|
| 179 |
+
print(traceback.format_exc())
|
| 180 |
+
|
| 181 |
+
# Create a dummy figure with error message
|
| 182 |
+
import matplotlib.pyplot as plt
|
| 183 |
+
fig = plt.figure(figsize=(10, 6))
|
| 184 |
+
plt.text(0.5, 0.5, f"Error: {str(e)}",
|
| 185 |
+
horizontalalignment='center', verticalalignment='center',
|
| 186 |
+
fontsize=12, color='red')
|
| 187 |
+
plt.axis('off')
|
| 188 |
+
return fig
|
| 189 |
|
| 190 |
if __name__ == "__main__":
|
| 191 |
import argparse
|
requirements.txt
CHANGED
|
@@ -7,6 +7,6 @@ scikit-learn>=0.24.2
|
|
| 7 |
matplotlib>=3.4.2
|
| 8 |
gradio>=2.0.0
|
| 9 |
datasets>=1.11.0
|
| 10 |
-
huggingface_hub>=0.
|
| 11 |
transformers>=4.15.0
|
| 12 |
|
|
|
|
| 7 |
matplotlib>=3.4.2
|
| 8 |
gradio>=2.0.0
|
| 9 |
datasets>=1.11.0
|
| 10 |
+
huggingface_hub>=0.15.0
|
| 11 |
transformers>=4.15.0
|
| 12 |
|