SreekarB commited on
Commit
f91cacf
·
verified ·
1 Parent(s): 7cf1145

Upload 3 files

Browse files
Files changed (3) hide show
  1. README.md +9 -6
  2. app.py +12 -7
  3. main.py +67 -50
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🧠
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.20.1
8
  app_file: app.py
9
  pinned: false
10
  ---
@@ -26,11 +26,14 @@ This application implements a VAE model that:
26
  This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
27
 
28
  - Functional connectivity matrices from fMRI data
29
- - Demographic information in `FC_graph_covariate_data.csv` including:
30
- - Age at stroke
31
- - Sex
32
- - Months post-stroke
33
- - WAB scores (aphasia severity)
 
 
 
34
 
35
  ## How to Use
36
 
 
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
26
  This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
27
 
28
  - Functional connectivity matrices from fMRI data
29
+ - Demographic information directly in the dataset:
30
+ - ID: Subject identifier
31
+ - wab_aq: Aphasia quotient score (severity measure)
32
+ - age: Subject age
33
+ - mpo: Months post onset
34
+ - education: Years of education
35
+ - gender: Subject gender
36
+ - handedness: Subject handedness (ignored in this analysis)
37
 
38
  ## How to Use
39
 
app.py CHANGED
@@ -2,10 +2,10 @@ import gradio as gr
2
  from main import run_fc_analysis
3
  import os
4
 
5
- def gradio_fc_analysis(data_source, demographic_file, latent_dim, nepochs, bsize, use_hf_dataset):
6
  fig = run_fc_analysis(
7
  data_dir=data_source,
8
- demographic_file=demographic_file,
9
  latent_dim=latent_dim,
10
  nepochs=nepochs,
11
  bsize=bsize,
@@ -20,12 +20,10 @@ def create_interface():
20
  inputs=[
21
  gr.Textbox(label="Data Source (HF Dataset ID or Local Directory)",
22
  value="SreekarB/OSFData"),
23
- gr.Textbox(label="Demographic File",
24
- value="FC_graph_covariate_data.csv"),
25
  gr.Slider(minimum=8, maximum=64, step=8,
26
  label="Latent Dimensions", value=32),
27
  gr.Slider(minimum=100, maximum=5000, step=100,
28
- label="Number of Epochs", value=1000),
29
  gr.Slider(minimum=8, maximum=64, step=8,
30
  label="Batch Size", value=16),
31
  gr.Checkbox(label="Use HuggingFace Dataset",
@@ -37,10 +35,17 @@ def create_interface():
37
  Analysis pipeline: fMRI → FC matrices → VAE → Analysis
38
 
39
  This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
40
- The demographic file FC_graph_covariate_data.csv contains age_at_stroke, sex, months_post_stroke, and wab_score.
 
 
 
 
 
 
 
41
  """,
42
  examples=[
43
- ["SreekarB/OSFData", "FC_graph_covariate_data.csv", 32, 500, 16, True],
44
  ],
45
  cache_examples=False,
46
  )
 
2
  from main import run_fc_analysis
3
  import os
4
 
5
+ def gradio_fc_analysis(data_source, latent_dim, nepochs, bsize, use_hf_dataset):
6
  fig = run_fc_analysis(
7
  data_dir=data_source,
8
+ demographic_file=None, # We're now getting demographics directly from the dataset
9
  latent_dim=latent_dim,
10
  nepochs=nepochs,
11
  bsize=bsize,
 
20
  inputs=[
21
  gr.Textbox(label="Data Source (HF Dataset ID or Local Directory)",
22
  value="SreekarB/OSFData"),
 
 
23
  gr.Slider(minimum=8, maximum=64, step=8,
24
  label="Latent Dimensions", value=32),
25
  gr.Slider(minimum=100, maximum=5000, step=100,
26
+ label="Number of Epochs", value=500), # Reduced for faster demos
27
  gr.Slider(minimum=8, maximum=64, step=8,
28
  label="Batch Size", value=16),
29
  gr.Checkbox(label="Use HuggingFace Dataset",
 
35
  Analysis pipeline: fMRI → FC matrices → VAE → Analysis
36
 
37
  This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
38
+ The dataset contains the following columns:
39
+ - ID: Subject identifier
40
+ - wab_aq: Aphasia severity score
41
+ - age: Age of the subject
42
+ - mpo: Months post onset
43
+ - education: Years of education
44
+ - gender: Subject gender
45
+ - handedness: Subject handedness (ignored in the analysis)
46
  """,
47
  examples=[
48
+ ["SreekarB/OSFData", 32, 200, 16, True], # Fewer epochs for faster demo
49
  ],
50
  cache_examples=False,
51
  )
main.py CHANGED
@@ -12,7 +12,7 @@ import pandas as pd
12
  import io
13
  from typing import List, Dict, Union, Tuple, Any
14
 
15
- def load_data(data_dir="SreekarB/OSFData", demographic_file="FC_graph_covariate_data.csv", use_hf_dataset=True):
16
  """
17
  Load fMRI data and demographics from HuggingFace dataset or local files
18
  """
@@ -23,56 +23,70 @@ def load_data(data_dir="SreekarB/OSFData", demographic_file="FC_graph_covariate_
23
  print(f"Loading dataset from HuggingFace: {data_dir}")
24
  dataset = load_dataset(data_dir)
25
 
26
- # Load demographics from the dataset
27
- if demographic_file in dataset["train"].features:
28
- demo_df = pd.DataFrame(dataset["train"][demographic_file])
29
- else:
30
- # Try to load from the dataset files
31
- try:
32
- demo_content = dataset["train"][demographic_file][0]
33
- demo_df = pd.read_csv(io.StringIO(demo_content))
34
- except Exception as e:
35
- print(f"Error loading demographics from dataset: {e}")
36
- # Download the CSV from the dataset repo
37
- import huggingface_hub
38
- csv_path = huggingface_hub.hf_hub_download(repo_id=data_dir, filename=demographic_file)
39
- demo_df = pd.read_csv(csv_path)
 
40
 
41
- # Extract demographic data
 
42
  demo_data = [
43
- demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else np.array([]),
44
- demo_df['sex'].values if 'sex' in demo_df.columns else np.array([]),
45
- demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else np.array([]),
46
- demo_df['wab_score'].values if 'wab_score' in demo_df.columns else np.array([])
47
  ]
48
 
49
- # Get fMRI/FC files from dataset
50
- nii_files = []
51
- for f in dataset["train"].features:
52
- if f.endswith(".nii.gz") or f.endswith(".nii"):
53
- nii_files.append(f)
54
 
55
- if not nii_files:
56
- print("No .nii/.nii.gz files found in dataset, checking for FC matrices")
57
- # Try to find FC matrices directly
58
  fc_matrices = []
59
- for f in dataset["train"].features:
60
- if f.startswith("fc_") or f.endswith("_fc"):
61
- fc_matrices.append(dataset["train"][f])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
63
- if fc_matrices:
64
- print(f"Found {len(fc_matrices)} FC matrices in dataset")
65
- return fc_matrices, demo_data, demo_types
66
  else:
67
  # Original local file loading
68
  # Load demographics
69
  demo_df = pd.read_csv(demographic_file)
70
 
71
  demo_data = [
72
- demo_df['age_at_stroke'].values,
73
- demo_df['sex'].values,
74
- demo_df['months_post_stroke'].values,
75
- demo_df['wab_score'].values
76
  ]
77
 
78
  # Load fMRI files
@@ -82,7 +96,7 @@ def load_data(data_dir="SreekarB/OSFData", demographic_file="FC_graph_covariate_
82
  return nii_files, demo_data, demo_types
83
 
84
  def run_fc_analysis(data_dir="SreekarB/OSFData",
85
- demographic_file="FC_graph_covariate_data.csv",
86
  latent_dim=32,
87
  nepochs=1000,
88
  bsize=16,
@@ -100,18 +114,21 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
100
  print("Loading data...")
101
  nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
102
 
103
- # Add import for io module if it's missing
104
- import io
105
-
106
  # Check if we got FC matrices directly
107
- if isinstance(nii_files, list) and all(isinstance(item, np.ndarray) for item in nii_files):
108
  print("Using pre-computed FC matrices...")
109
- X = np.stack(nii_files)
 
110
  else:
111
  # Prepare data by converting fMRI to FC matrices
112
  print("Converting fMRI data to FC matrices...")
113
  X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
114
 
 
 
 
 
 
115
  # Train VAE and get data
116
  print("Training VAE...")
117
  vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
@@ -128,18 +145,18 @@ def run_fc_analysis(data_dir="SreekarB/OSFData",
128
  # Analyze results
129
  print("Analyzing demographic relationships...")
130
  demographics = {
131
- 'age_at_stroke': demo_data[0] if len(demo_data[0]) > 0 else np.zeros(len(X)),
132
- 'months_post_stroke': demo_data[2] if len(demo_data[2]) > 0 else np.zeros(len(X)),
133
- 'wab_score': demo_data[3] if len(demo_data[3]) > 0 else np.zeros(len(X))
134
  }
135
  analysis_results = analyze_fc_patterns(latents, demographics)
136
 
137
  # Generate new FC matrix
138
  print("Generating new FC matrices...")
139
  new_demographics = [
140
- [60.0], # age at stroke
141
- ['M'], # sex
142
- [12.0], # months post stroke
143
  [80.0] # wab score
144
  ]
145
  generated_fc = vae.transform(1, new_demographics, demo_types)
 
12
  import io
13
  from typing import List, Dict, Union, Tuple, Any
14
 
15
+ def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
16
  """
17
  Load fMRI data and demographics from HuggingFace dataset or local files
18
  """
 
23
  print(f"Loading dataset from HuggingFace: {data_dir}")
24
  dataset = load_dataset(data_dir)
25
 
26
+ print(f"Dataset columns: {dataset['train'].column_names}")
27
+
28
+ # Get demographics directly from the dataset
29
+ # Create a DataFrame from the dataset features
30
+ demo_df = pd.DataFrame({
31
+ 'ID': dataset['train']['ID'],
32
+ 'wab_aq': dataset['train']['wab_aq'],
33
+ 'age': dataset['train']['age'],
34
+ 'mpo': dataset['train']['mpo'],
35
+ 'education': dataset['train']['education'],
36
+ 'gender': dataset['train']['gender'],
37
+ 'handedness': dataset['train']['handedness']
38
+ })
39
+
40
+ print(f"Loaded demographic data with {len(demo_df)} subjects")
41
 
42
+ # Extract demographic data matching our expected format
43
+ # Map the dataset columns to our expected format
44
  demo_data = [
45
+ demo_df['age'].values, # age at stroke -> age
46
+ demo_df['gender'].values, # sex -> gender
47
+ demo_df['mpo'].values, # months post stroke -> mpo
48
+ demo_df['wab_aq'].values # wab score -> wab_aq
49
  ]
50
 
51
+ # Check for FC matrices in the dataset
52
+ fc_columns = []
53
+ for col in dataset['train'].column_names:
54
+ if col.startswith("fc_") or "_fc" in col:
55
+ fc_columns.append(col)
56
 
57
+ if fc_columns:
58
+ print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
59
+ # Extract FC matrices
60
  fc_matrices = []
61
+ for fc_col in fc_columns:
62
+ fc_matrices.append(dataset['train'][fc_col])
63
+
64
+ # If we have FC matrices, return them directly
65
+ demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
66
+ return fc_matrices, demo_data, demo_types
67
+
68
+ # If no FC matrices, look for .nii files
69
+ nii_files = []
70
+ for col in dataset['train'].column_names:
71
+ if col.endswith(".nii.gz") or col.endswith(".nii"):
72
+ nii_files.append(dataset['train'][col])
73
+
74
+ if nii_files:
75
+ print(f"Found {len(nii_files)} .nii files")
76
+ else:
77
+ print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
78
+ # If no structured data is found, we can try to download raw files later
79
 
 
 
 
80
  else:
81
  # Original local file loading
82
  # Load demographics
83
  demo_df = pd.read_csv(demographic_file)
84
 
85
  demo_data = [
86
+ demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
87
+ demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
88
+ demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
89
+ demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
90
  ]
91
 
92
  # Load fMRI files
 
96
  return nii_files, demo_data, demo_types
97
 
98
  def run_fc_analysis(data_dir="SreekarB/OSFData",
99
+ demographic_file=None,
100
  latent_dim=32,
101
  nepochs=1000,
102
  bsize=16,
 
114
  print("Loading data...")
115
  nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
116
 
 
 
 
117
  # Check if we got FC matrices directly
118
+ if isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
119
  print("Using pre-computed FC matrices...")
120
+ # Convert list of FC matrices to numpy array
121
+ X = np.stack([np.array(fc) for fc in nii_files])
122
  else:
123
  # Prepare data by converting fMRI to FC matrices
124
  print("Converting fMRI data to FC matrices...")
125
  X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
126
 
127
+ # Print shapes and data types
128
+ print(f"X shape: {X.shape}, type: {type(X)}")
129
+ for i, d in enumerate(demo_data):
130
+ print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
131
+
132
  # Train VAE and get data
133
  print("Training VAE...")
134
  vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
 
145
  # Analyze results
146
  print("Analyzing demographic relationships...")
147
  demographics = {
148
+ 'age': demo_data[0],
149
+ 'months_post_onset': demo_data[2],
150
+ 'wab_aq': demo_data[3]
151
  }
152
  analysis_results = analyze_fc_patterns(latents, demographics)
153
 
154
  # Generate new FC matrix
155
  print("Generating new FC matrices...")
156
  new_demographics = [
157
+ [60.0], # age
158
+ ['M'], # gender
159
+ [12.0], # months post onset
160
  [80.0] # wab score
161
  ]
162
  generated_fc = vae.transform(1, new_demographics, demo_types)