Spaces:
Sleeping
Sleeping
File size: 13,003 Bytes
ef677f1 1c47445 b32645b ef677f1 1c47445 a7f7808 1c47445 ef677f1 1c47445 763369a 1c47445 c775b23 1c47445 c775b23 1c47445 9135a28 1c47445 67303f6 1c47445 ef677f1 1c47445 55c1385 1c47445 55c1385 1c47445 b32645b 55c1385 1c47445 9641510 1c47445 dbe81c1 1c47445 dbe81c1 1c47445 dbe81c1 1c47445 b32645b 1c47445 a7f7808 1c47445 37a1b01 1c47445 37a1b01 1c47445 50c6714 ef677f1 1c47445 ef677f1 1c47445 ef677f1 1c47445 b32645b 1c47445 b32645b 1c47445 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 |
import os
import sys
# Add the src directory to the path so we can import from demovae
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
import numpy as np
import torch
from pathlib import Path
import nibabel as nib
from data_preprocessing import preprocess_fmri_to_fc
from src.demovae.sklearn import DemoVAE
from analysis import analyze_fc_patterns
from visualization import plot_fc_matrices
from config import MODEL_CONFIG, DATASET_CONFIG
import pandas as pd
import io
from typing import List, Dict, Union, Tuple, Any
def train_fc_vae(X, demo_data, demo_types, model_config):
    """
    Train a DemoVAE model on functional connectivity matrices.

    Parameters
    ----------
    X : array-like
        Subject-by-feature matrix of vectorized FC values. Converted to a
        float32 numpy array if it is not already an ndarray; NaN/Inf entries
        are replaced with finite values before fitting.
    demo_data : list of array-like
        One array per demographic variable. Non-ndarray entries are converted
        in place, so the caller's list is updated too.
    demo_types : list of str
        'continuous' or 'categorical' label for each entry of demo_data.
    model_config : dict
        Must contain 'latent_dim', 'nepochs' and 'bsize'; 'loss_rec_mult',
        'loss_decor_mult' and 'lr' are optional with defaults below.

    Returns
    -------
    tuple
        (fitted DemoVAE instance, cleaned X, demo_data, demo_types).
    """
    print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
    # Ensure X is a numpy array with correct data type
    if not isinstance(X, np.ndarray):
        print(f"Converting X from {type(X)} to numpy array")
        X = np.array(X, dtype=np.float32)
    # Ensure demo_data contains numpy arrays (converted in place)
    for i, d in enumerate(demo_data):
        if not isinstance(d, np.ndarray):
            print(f"Converting demographic {i} from {type(d)} to numpy array")
            demo_data[i] = np.array(d)
    # Sanitize non-finite values in a single pass. Note nan_to_num maps NaN
    # to 0 but Inf to large finite values, not zero.
    if not np.isfinite(X).all():
        print("Warning: X contains NaN or Inf values. Replacing with zeros.")
        X = np.nan_to_num(X)
    # Create the VAE model; falls back to CPU when CUDA is unavailable
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config.get('loss_rec_mult', 100),
        loss_decor_mult=model_config.get('loss_decor_mult', 10),
        lr=model_config.get('lr', 1e-4),
        use_cuda=torch.cuda.is_available()
    )
    print("Fitting VAE model...")
    vae.fit(X, demo_data, demo_types)
    return vae, X, demo_data, demo_types
def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
    """
    Load fMRI data and demographics from a HuggingFace dataset or local files.

    Parameters
    ----------
    data_dir : str
        HuggingFace dataset ID (when use_hf_dataset is True) or a local
        directory containing NIfTI files.
    demographic_file : str or None
        Path to a demographics CSV; required when use_hf_dataset is False.
    use_hf_dataset : bool
        Load from the HuggingFace hub instead of local files.

    Returns
    -------
    tuple
        (fc_matrices_or_nii_files, demo_data, demo_types) where demo_data is
        [age, gender, months-post-onset, WAB-AQ] arrays and demo_types labels
        each as 'continuous' or 'categorical'.

    Raises
    ------
    ValueError
        If use_hf_dataset is False and demographic_file is not provided.
    """
    if use_hf_dataset:
        # Imported lazily so local-only runs do not need the `datasets` package
        from datasets import load_dataset
        print(f"Loading dataset from HuggingFace: {data_dir}")
        dataset = load_dataset(data_dir)
        print(f"Dataset columns: {dataset['train'].column_names}")
        # Build a demographics DataFrame directly from the dataset columns
        demo_df = pd.DataFrame({
            'ID': dataset['train']['ID'],
            'wab_aq': dataset['train']['wab_aq'],
            'age': dataset['train']['age'],
            'mpo': dataset['train']['mpo'],
            'education': dataset['train']['education'],
            'gender': dataset['train']['gender'],
            'handedness': dataset['train']['handedness']
        })
        print(f"Loaded demographic data with {len(demo_df)} subjects")
        # Map the dataset columns to the [age, gender, mpo, wab_aq] order the
        # rest of the pipeline expects
        demo_data = [
            demo_df['age'].values,
            demo_df['gender'].values,
            demo_df['mpo'].values,
            demo_df['wab_aq'].values
        ]
        # Check for pre-computed FC matrix columns in the dataset
        fc_columns = [col for col in dataset['train'].column_names
                      if col.startswith("fc_") or "_fc" in col]
        if fc_columns:
            print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
            # Return the FC matrices directly; callers detect this case by
            # the entries having a .shape attribute
            fc_matrices = [dataset['train'][fc_col] for fc_col in fc_columns]
            demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
            return fc_matrices, demo_data, demo_types
        # No FC matrices: look for columns holding raw NIfTI files instead
        nii_files = [dataset['train'][col] for col in dataset['train'].column_names
                     if col.endswith(".nii.gz") or col.endswith(".nii")]
        if nii_files:
            print(f"Found {len(nii_files)} .nii files")
        else:
            print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
            # If no structured data is found, we can try to download raw files later
    else:
        # Local file loading: fail fast with a clear message instead of
        # letting pd.read_csv(None) raise a cryptic error
        if demographic_file is None:
            raise ValueError("demographic_file is required when use_hf_dataset is False")
        demo_df = pd.read_csv(demographic_file)
        # Accept either naming convention for each demographic column
        demo_data = [
            demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
            demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
            demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
            demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
        ]
        # Collect fMRI volumes; include both .nii.gz and plain .nii for
        # consistency with the HF-dataset branch above
        nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')) +
                           list(Path(data_dir).glob('*.nii')))
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
    return nii_files, demo_data, demo_types
def run_fc_analysis(data_dir="SreekarB/OSFData",
                    demographic_file=None,
                    latent_dim=32,
                    nepochs=1000,
                    bsize=16,
                    save_model=True,
                    use_hf_dataset=True,
                    return_data=False):
    """
    End-to-end pipeline: load data, train a DemoVAE on FC matrices, analyze
    latent/demographic relationships, generate a synthetic FC matrix, and
    return a comparison figure.

    Parameters
    ----------
    data_dir : str
        HuggingFace dataset ID or local directory with fMRI data.
    demographic_file : str or None
        CSV of demographics (only used when use_hf_dataset is False).
    latent_dim, nepochs, bsize : int
        VAE hyperparameters; these overwrite the shared MODEL_CONFIG dict.
    save_model : bool
        If True, save the trained model to models/vae_model.pth.
    use_hf_dataset : bool
        Load from the HuggingFace hub instead of local files.
    return_data : bool
        If True, also return a dict of intermediate results.

    Returns
    -------
    matplotlib figure, or (figure, results dict) when return_data is True.
    On any error a figure containing the error text is returned instead
    (with None as the results dict when return_data is True).
    """
    # NOTE: this mutates the shared MODEL_CONFIG imported from config
    MODEL_CONFIG.update({
        'latent_dim': latent_dim,
        'nepochs': nepochs,
        'bsize': bsize
    })
    try:
        # Load data
        print("Loading data...")
        nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
        # For SreekarB/OSFData, directly generate synthetic FC matrices
        if data_dir == "SreekarB/OSFData" and use_hf_dataset:
            print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
        # load_data returns array-like FC matrices (entries with .shape)
        # when the dataset ships pre-computed FC columns
        elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
            print("Using pre-computed FC matrices...")
            # Convert list of FC matrices to numpy array
            X = np.stack([np.array(fc) for fc in nii_files])
        else:
            # Prepare data by converting fMRI volumes to FC matrices
            print("Converting fMRI data to FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
        # Print shapes and data types for debugging
        print(f"X shape: {X.shape}, type: {type(X)}")
        for i, d in enumerate(demo_data):
            print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
        # Train VAE and get data
        print("Training VAE...")
        try:
            # Use the proper DemoVAE implementation from src/demovae/sklearn.py
            vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
            if save_model:
                print("Saving model...")
                os.makedirs('models', exist_ok=True)
                # Use the save method from DemoVAE
                vae.save('models/vae_model.pth')
                print("Model saved successfully.")
        except Exception as e:
            print(f"Error during VAE training: {e}")
            raise
        # Get latent representations
        print("Getting latent representations...")
        latents = vae.get_latents(X)
        # Analyze latent/demographic relationships
        print("Analyzing demographic relationships...")
        demographics = {
            'age': demo_data[0],
            'months_post_onset': demo_data[2],
            'wab_aq': demo_data[3]
        }
        analysis_results = analyze_fc_patterns(latents, demographics)
        # Generate a new FC matrix for a fixed demographic profile
        print("Generating new FC matrices...")
        # Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
        new_demographics = [
            np.array([60.0], dtype=np.float64),  # age
            np.array(['M'], dtype=np.str_),      # gender
            np.array([12.0], dtype=np.float64),  # months post onset
            np.array([80.0], dtype=np.float64)   # wab score
        ]
        # Verify the demographic data arrays match the expected types
        print("Demographic data types:")
        for name, data in zip(['age', 'gender', 'mpo', 'wab'], new_demographics):
            print(f"  {name}: shape={data.shape}, dtype={data.dtype}")
        print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
        try:
            # NOTE(review): transform(1, ...) presumably samples one synthetic
            # subject — confirm against DemoVAE.transform's API
            generated_fc = vae.transform(1, new_demographics, demo_types)
        except Exception as e:
            print(f"Error generating new FC matrix: {e}")
            # Try with a fallback approach
            print("Trying alternative generation approach...")
            # If specific gender is causing issues, try the first gender from training data
            new_demographics[1] = np.array([demo_data[1][0]])
            generated_fc = vae.transform(1, new_demographics, demo_types)
        reconstructed_fc = vae.transform(X, demo_data, demo_types)
        # Visualize original vs reconstructed vs generated for subject 0
        print("Creating visualizations...")
        fig = plot_fc_matrices(X[0], reconstructed_fc[0], generated_fc[0])
        # If requested, return additional data for accuracy calculations
        if return_data:
            # Create a structured outcome measures dictionary
            outcome_measures = {
                'wab_aq': demo_data[3],  # WAB-AQ scores
                # Could add other outcome measures here
            }
            results = {
                'vae': vae,
                'X': X,
                'latents': latents,
                'demographics': demographics,
                'reconstructed_fc': reconstructed_fc,
                'generated_fc': generated_fc,
                'analysis_results': analysis_results,
                'outcome_measures': outcome_measures
            }
            return fig, results
        return fig
    except Exception as e:
        import traceback
        print(f"Error in run_fc_analysis: {str(e)}")
        print(traceback.format_exc())
        # Create a dummy figure with error message so the UI always gets a figure
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}",
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=12, color='red')
        plt.axis('off')
        # Return the error figure and empty results if requested
        if return_data:
            return fig, None
        return fig
def _main():
    """Command-line entry point: parse arguments and run the FC analysis."""
    import argparse
    parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
    parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
                        help='HuggingFace dataset ID or directory containing fMRI data')
    parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
                        help='Path to demographic data CSV file')
    parser.add_argument('--latent_dim', type=int, default=32,
                        help='Dimension of latent space')
    parser.add_argument('--nepochs', type=int, default=1000,
                        help='Number of training epochs')
    parser.add_argument('--bsize', type=int, default=16,
                        help='Batch size for training')
    # --no_save flips the flag off; store it under a truthful name so the
    # call below reads naturally (args.save_model is True unless --no_save)
    parser.add_argument('--no_save', action='store_false', dest='save_model',
                        help='Do not save the model')
    parser.add_argument('--use_local', action='store_true',
                        help='Use local data instead of HuggingFace dataset')
    args = parser.parse_args()
    fig = run_fc_analysis(
        data_dir=args.data_dir,
        demographic_file=args.demographic_file,
        latent_dim=args.latent_dim,
        nepochs=args.nepochs,
        bsize=args.bsize,
        save_model=args.save_model,
        use_hf_dataset=not args.use_local
    )
    fig.show()


if __name__ == "__main__":
    _main()
|