Spaces:
Sleeping
Sleeping
Upload 3 files
Browse files
README.md
CHANGED
|
@@ -4,7 +4,7 @@ emoji: 🧠
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
|
@@ -26,11 +26,14 @@ This application implements a VAE model that:
|
|
| 26 |
This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
|
| 27 |
|
| 28 |
- Functional connectivity matrices from fMRI data
|
| 29 |
-
- Demographic information in
|
| 30 |
-
-
|
| 31 |
-
-
|
| 32 |
-
-
|
| 33 |
-
-
|
|
|
|
|
|
|
|
|
|
| 34 |
|
| 35 |
## How to Use
|
| 36 |
|
|
|
|
| 4 |
colorFrom: blue
|
| 5 |
colorTo: pink
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 3.36.1
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
---
|
|
|
|
| 26 |
This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
|
| 27 |
|
| 28 |
- Functional connectivity matrices from fMRI data
|
| 29 |
+
- Demographic information, stored directly as columns of the dataset:
|
| 30 |
+
- ID: Subject identifier
|
| 31 |
+
- wab_aq: Western Aphasia Battery Aphasia Quotient (WAB-AQ), a measure of aphasia severity
|
| 32 |
+
- age: Subject age
|
| 33 |
+
- mpo: Months post-onset (time elapsed since the stroke)
|
| 34 |
+
- education: Years of education
|
| 35 |
+
- gender: Subject gender
|
| 36 |
+
- handedness: Subject handedness (ignored in this analysis)
|
| 37 |
|
| 38 |
## How to Use
|
| 39 |
|
app.py
CHANGED
|
@@ -2,10 +2,10 @@ import gradio as gr
|
|
| 2 |
from main import run_fc_analysis
|
| 3 |
import os
|
| 4 |
|
| 5 |
-
def gradio_fc_analysis(data_source,
|
| 6 |
fig = run_fc_analysis(
|
| 7 |
data_dir=data_source,
|
| 8 |
-
demographic_file=
|
| 9 |
latent_dim=latent_dim,
|
| 10 |
nepochs=nepochs,
|
| 11 |
bsize=bsize,
|
|
@@ -20,12 +20,10 @@ def create_interface():
|
|
| 20 |
inputs=[
|
| 21 |
gr.Textbox(label="Data Source (HF Dataset ID or Local Directory)",
|
| 22 |
value="SreekarB/OSFData"),
|
| 23 |
-
gr.Textbox(label="Demographic File",
|
| 24 |
-
value="FC_graph_covariate_data.csv"),
|
| 25 |
gr.Slider(minimum=8, maximum=64, step=8,
|
| 26 |
label="Latent Dimensions", value=32),
|
| 27 |
gr.Slider(minimum=100, maximum=5000, step=100,
|
| 28 |
-
label="Number of Epochs", value=
|
| 29 |
gr.Slider(minimum=8, maximum=64, step=8,
|
| 30 |
label="Batch Size", value=16),
|
| 31 |
gr.Checkbox(label="Use HuggingFace Dataset",
|
|
@@ -37,10 +35,17 @@ def create_interface():
|
|
| 37 |
Analysis pipeline: fMRI → FC matrices → VAE → Analysis
|
| 38 |
|
| 39 |
This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
|
| 40 |
-
The
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
""",
|
| 42 |
examples=[
|
| 43 |
-
["SreekarB/OSFData",
|
| 44 |
],
|
| 45 |
cache_examples=False,
|
| 46 |
)
|
|
|
|
| 2 |
from main import run_fc_analysis
|
| 3 |
import os
|
| 4 |
|
| 5 |
+
def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
|
| 6 |
fig = run_fc_analysis(
|
| 7 |
data_dir=data_source,
|
| 8 |
+
demographic_file=None, # We're now getting demographics directly from the dataset
|
| 9 |
latent_dim=latent_dim,
|
| 10 |
nepochs=nepochs,
|
| 11 |
bsize=bsize,
|
|
|
|
| 20 |
inputs=[
|
| 21 |
gr.Textbox(label="Data Source (HF Dataset ID or Local Directory)",
|
| 22 |
value="SreekarB/OSFData"),
|
|
|
|
|
|
|
| 23 |
gr.Slider(minimum=8, maximum=64, step=8,
|
| 24 |
label="Latent Dimensions", value=32),
|
| 25 |
gr.Slider(minimum=100, maximum=5000, step=100,
|
| 26 |
+
label="Number of Epochs", value=500), # Reduced for faster demos
|
| 27 |
gr.Slider(minimum=8, maximum=64, step=8,
|
| 28 |
label="Batch Size", value=16),
|
| 29 |
gr.Checkbox(label="Use HuggingFace Dataset",
|
|
|
|
| 35 |
Analysis pipeline: fMRI → FC matrices → VAE → Analysis
|
| 36 |
|
| 37 |
This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
|
| 38 |
+
The dataset contains the following columns:
|
| 39 |
+
- ID: Subject identifier
|
| 40 |
+
- wab_aq: Aphasia severity score
|
| 41 |
+
- age: Age of the subject
|
| 42 |
+
- mpo: Months post onset
|
| 43 |
+
- education: Years of education
|
| 44 |
+
- gender: Subject gender
|
| 45 |
+
- handedness: Subject handedness (ignored in the analysis)
|
| 46 |
""",
|
| 47 |
examples=[
|
| 48 |
+
["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
|
| 49 |
],
|
| 50 |
cache_examples=False,
|
| 51 |
)
|
main.py
CHANGED
|
@@ -12,7 +12,7 @@ import pandas as pd
|
|
| 12 |
import io
|
| 13 |
from typing import List, Dict, Union, Tuple, Any
|
| 14 |
|
| 15 |
-
def load_data(data_dir="SreekarB/OSFData", demographic_file=
|
| 16 |
"""
|
| 17 |
Load fMRI data and demographics from HuggingFace dataset or local files
|
| 18 |
"""
|
|
@@ -23,56 +23,70 @@ def load_data(data_dir="SreekarB/OSFData", demographic_file="FC_graph_covariate_
|
|
| 23 |
print(f"Loading dataset from HuggingFace: {data_dir}")
|
| 24 |
dataset = load_dataset(data_dir)
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
|
|
|
| 40 |
|
| 41 |
-
# Extract demographic data
|
|
|
|
| 42 |
demo_data = [
|
| 43 |
-
demo_df['
|
| 44 |
-
demo_df['
|
| 45 |
-
demo_df['
|
| 46 |
-
demo_df['
|
| 47 |
]
|
| 48 |
|
| 49 |
-
#
|
| 50 |
-
|
| 51 |
-
for
|
| 52 |
-
if
|
| 53 |
-
|
| 54 |
|
| 55 |
-
if
|
| 56 |
-
print("
|
| 57 |
-
#
|
| 58 |
fc_matrices = []
|
| 59 |
-
for
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
|
| 63 |
-
if fc_matrices:
|
| 64 |
-
print(f"Found {len(fc_matrices)} FC matrices in dataset")
|
| 65 |
-
return fc_matrices, demo_data, demo_types
|
| 66 |
else:
|
| 67 |
# Original local file loading
|
| 68 |
# Load demographics
|
| 69 |
demo_df = pd.read_csv(demographic_file)
|
| 70 |
|
| 71 |
demo_data = [
|
| 72 |
-
demo_df['age_at_stroke'].values,
|
| 73 |
-
demo_df['sex'].values,
|
| 74 |
-
demo_df['months_post_stroke'].values,
|
| 75 |
-
demo_df['wab_score'].values
|
| 76 |
]
|
| 77 |
|
| 78 |
# Load fMRI files
|
|
@@ -82,7 +96,7 @@ def load_data(data_dir="SreekarB/OSFData", demographic_file="FC_graph_covariate_
|
|
| 82 |
return nii_files, demo_data, demo_types
|
| 83 |
|
| 84 |
def run_fc_analysis(data_dir="SreekarB/OSFData",
|
| 85 |
-
demographic_file=
|
| 86 |
latent_dim=32,
|
| 87 |
nepochs=1000,
|
| 88 |
bsize=16,
|
|
@@ -100,18 +114,21 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
|
|
| 100 |
print("Loading data...")
|
| 101 |
nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
|
| 102 |
|
| 103 |
-
# Add import for io module if it's missing
|
| 104 |
-
import io
|
| 105 |
-
|
| 106 |
# Check if we got FC matrices directly
|
| 107 |
-
if isinstance(nii_files, list) and
|
| 108 |
print("Using pre-computed FC matrices...")
|
| 109 |
-
|
|
|
|
| 110 |
else:
|
| 111 |
# Prepare data by converting fMRI to FC matrices
|
| 112 |
print("Converting fMRI data to FC matrices...")
|
| 113 |
X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
|
| 114 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 115 |
# Train VAE and get data
|
| 116 |
print("Training VAE...")
|
| 117 |
vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
|
|
@@ -128,18 +145,18 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
|
|
| 128 |
# Analyze results
|
| 129 |
print("Analyzing demographic relationships...")
|
| 130 |
demographics = {
|
| 131 |
-
'
|
| 132 |
-
'
|
| 133 |
-
'
|
| 134 |
}
|
| 135 |
analysis_results = analyze_fc_patterns(latents, demographics)
|
| 136 |
|
| 137 |
# Generate new FC matrix
|
| 138 |
print("Generating new FC matrices...")
|
| 139 |
new_demographics = [
|
| 140 |
-
[60.0], # age
|
| 141 |
-
['M'], #
|
| 142 |
-
[12.0], # months post
|
| 143 |
[80.0] # wab score
|
| 144 |
]
|
| 145 |
generated_fc = vae.transform(1, new_demographics, demo_types)
|
|
|
|
| 12 |
import io
|
| 13 |
from typing import List, Dict, Union, Tuple, Any
|
| 14 |
|
| 15 |
+
def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
|
| 16 |
"""
|
| 17 |
Load fMRI data and demographics from HuggingFace dataset or local files
|
| 18 |
"""
|
|
|
|
| 23 |
print(f"Loading dataset from HuggingFace: {data_dir}")
|
| 24 |
dataset = load_dataset(data_dir)
|
| 25 |
|
| 26 |
+
print(f"Dataset columns: {dataset['train'].column_names}")
|
| 27 |
+
|
| 28 |
+
# Get demographics directly from the dataset
|
| 29 |
+
# Create a DataFrame from the dataset features
|
| 30 |
+
demo_df = pd.DataFrame({
|
| 31 |
+
'ID': dataset['train']['ID'],
|
| 32 |
+
'wab_aq': dataset['train']['wab_aq'],
|
| 33 |
+
'age': dataset['train']['age'],
|
| 34 |
+
'mpo': dataset['train']['mpo'],
|
| 35 |
+
'education': dataset['train']['education'],
|
| 36 |
+
'gender': dataset['train']['gender'],
|
| 37 |
+
'handedness': dataset['train']['handedness']
|
| 38 |
+
})
|
| 39 |
+
|
| 40 |
+
print(f"Loaded demographic data with {len(demo_df)} subjects")
|
| 41 |
|
| 42 |
+
# Extract demographic data matching our expected format
|
| 43 |
+
# Map the dataset columns to our expected format
|
| 44 |
demo_data = [
|
| 45 |
+
demo_df['age'].values, # age at stroke -> age
|
| 46 |
+
demo_df['gender'].values, # sex -> gender
|
| 47 |
+
demo_df['mpo'].values, # months post stroke -> mpo
|
| 48 |
+
demo_df['wab_aq'].values # wab score -> wab_aq
|
| 49 |
]
|
| 50 |
|
| 51 |
+
# Check for FC matrices in the dataset
|
| 52 |
+
fc_columns = []
|
| 53 |
+
for col in dataset['train'].column_names:
|
| 54 |
+
if col.startswith("fc_") or "_fc" in col:
|
| 55 |
+
fc_columns.append(col)
|
| 56 |
|
| 57 |
+
if fc_columns:
|
| 58 |
+
print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
|
| 59 |
+
# Extract FC matrices
|
| 60 |
fc_matrices = []
|
| 61 |
+
for fc_col in fc_columns:
|
| 62 |
+
fc_matrices.append(dataset['train'][fc_col])
|
| 63 |
+
|
| 64 |
+
# If we have FC matrices, return them directly
|
| 65 |
+
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
|
| 66 |
+
return fc_matrices, demo_data, demo_types
|
| 67 |
+
|
| 68 |
+
# If no FC matrices, look for .nii files
|
| 69 |
+
nii_files = []
|
| 70 |
+
for col in dataset['train'].column_names:
|
| 71 |
+
if col.endswith(".nii.gz") or col.endswith(".nii"):
|
| 72 |
+
nii_files.append(dataset['train'][col])
|
| 73 |
+
|
| 74 |
+
if nii_files:
|
| 75 |
+
print(f"Found {len(nii_files)} .nii files")
|
| 76 |
+
else:
|
| 77 |
+
print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
|
| 78 |
+
# If no structured data is found, we can try to download raw files later
|
| 79 |
|
|
|
|
|
|
|
|
|
|
| 80 |
else:
|
| 81 |
# Original local file loading
|
| 82 |
# Load demographics
|
| 83 |
demo_df = pd.read_csv(demographic_file)
|
| 84 |
|
| 85 |
demo_data = [
|
| 86 |
+
demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
|
| 87 |
+
demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
|
| 88 |
+
demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
|
| 89 |
+
demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
|
| 90 |
]
|
| 91 |
|
| 92 |
# Load fMRI files
|
|
|
|
| 96 |
return nii_files, demo_data, demo_types
|
| 97 |
|
| 98 |
def run_fc_analysis(data_dir="SreekarB/OSFData",
|
| 99 |
+
demographic_file=None,
|
| 100 |
latent_dim=32,
|
| 101 |
nepochs=1000,
|
| 102 |
bsize=16,
|
|
|
|
| 114 |
print("Loading data...")
|
| 115 |
nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
|
| 116 |
|
|
|
|
|
|
|
|
|
|
| 117 |
# Check if we got FC matrices directly
|
| 118 |
+
if isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
|
| 119 |
print("Using pre-computed FC matrices...")
|
| 120 |
+
# Convert list of FC matrices to numpy array
|
| 121 |
+
X = np.stack([np.array(fc) for fc in nii_files])
|
| 122 |
else:
|
| 123 |
# Prepare data by converting fMRI to FC matrices
|
| 124 |
print("Converting fMRI data to FC matrices...")
|
| 125 |
X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
|
| 126 |
|
| 127 |
+
# Print shapes and data types
|
| 128 |
+
print(f"X shape: {X.shape}, type: {type(X)}")
|
| 129 |
+
for i, d in enumerate(demo_data):
|
| 130 |
+
print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
|
| 131 |
+
|
| 132 |
# Train VAE and get data
|
| 133 |
print("Training VAE...")
|
| 134 |
vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
|
|
|
|
| 145 |
# Analyze results
|
| 146 |
print("Analyzing demographic relationships...")
|
| 147 |
demographics = {
|
| 148 |
+
'age': demo_data[0],
|
| 149 |
+
'months_post_onset': demo_data[2],
|
| 150 |
+
'wab_aq': demo_data[3]
|
| 151 |
}
|
| 152 |
analysis_results = analyze_fc_patterns(latents, demographics)
|
| 153 |
|
| 154 |
# Generate new FC matrix
|
| 155 |
print("Generating new FC matrices...")
|
| 156 |
new_demographics = [
|
| 157 |
+
[60.0], # age
|
| 158 |
+
['M'], # gender
|
| 159 |
+
[12.0], # months post onset
|
| 160 |
[80.0] # wab score
|
| 161 |
]
|
| 162 |
generated_fc = vae.transform(1, new_demographics, demo_types)
|