# AphasiaPred / main.py
# Uploaded by SreekarB ("Upload 13 files", commit 37a1b01, verified)
import os
import sys
# Add the src directory to the path so we can import from demovae
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))
import numpy as np
import torch
from pathlib import Path
import nibabel as nib
from data_preprocessing import preprocess_fmri_to_fc
from src.demovae.sklearn import DemoVAE
from analysis import analyze_fc_patterns
from visualization import plot_fc_matrices
from config import MODEL_CONFIG, DATASET_CONFIG
import pandas as pd
import io
from typing import List, Dict, Union, Tuple, Any
def train_fc_vae(X, demo_data, demo_types, model_config):
    """
    Train a DemoVAE model on functional connectivity (FC) matrices.

    Parameters
    ----------
    X : array-like, shape (n_subjects, n_features)
        Vectorized FC matrices, one row per subject. Converted to a
        float32 numpy array if not already an ndarray; NaN/Inf entries
        are replaced with finite values before fitting.
    demo_data : list of array-like
        One array per demographic variable (order: age, gender, months
        post onset, WAB-AQ — see load_data). Elements are converted to
        numpy arrays IN PLACE if needed.
    demo_types : list of str
        'continuous' or 'categorical' for each entry of demo_data.
    model_config : dict
        Must contain 'latent_dim', 'nepochs', 'bsize'; optionally
        'loss_rec_mult' (default 100), 'loss_decor_mult' (default 10),
        and 'lr' (default 1e-4).

    Returns
    -------
    (vae, X, demo_data, demo_types)
        The fitted DemoVAE and the (possibly sanitized) inputs.
    """
    # Expected feature count for a 264-ROI FC matrix: the upper triangle
    # excluding the diagonal. Used only for a sanity warning below.
    n_rois = 264
    input_dim = (n_rois * (n_rois - 1)) // 2

    print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")

    # Ensure X is a numpy array with correct data type
    if not isinstance(X, np.ndarray):
        print(f"Converting X from {type(X)} to numpy array")
        X = np.array(X, dtype=np.float32)

    # Warn (but do not fail) if the feature dimension does not match the
    # expected 264-ROI upper-triangle size — presumably the Power atlas,
    # TODO confirm with the data owners.
    if X.ndim == 2 and X.shape[1] != input_dim:
        print(f"Warning: X has {X.shape[1]} features; expected {input_dim} for {n_rois} ROIs")

    # Ensure demo_data contains numpy arrays
    for i, d in enumerate(demo_data):
        if not isinstance(d, np.ndarray):
            print(f"Converting demographic {i} from {type(d)} to numpy array")
            demo_data[i] = np.array(d)

    # Check for NaN or Inf values
    if np.isnan(X).any() or np.isinf(X).any():
        print("Warning: X contains NaN or Inf values. Replacing with zeros.")
        X = np.nan_to_num(X)

    # Create the VAE model
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config.get('loss_rec_mult', 100),
        loss_decor_mult=model_config.get('loss_decor_mult', 10),
        lr=model_config.get('lr', 1e-4),
        use_cuda=torch.cuda.is_available()
    )

    print("Fitting VAE model...")
    vae.fit(X, demo_data, demo_types)
    return vae, X, demo_data, demo_types
def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
    """
    Load fMRI data and demographics from a HuggingFace dataset or local files.

    Parameters
    ----------
    data_dir : str
        HuggingFace dataset ID (when use_hf_dataset is True) or a local
        directory containing *.nii.gz files.
    demographic_file : str or None
        Path to a demographics CSV. Required when use_hf_dataset is False;
        ignored otherwise.
    use_hf_dataset : bool
        If True, load via `datasets.load_dataset(data_dir)`.

    Returns
    -------
    (data, demo_data, demo_types)
        `data` is a list of pre-computed FC matrices (if the dataset has
        FC columns), otherwise a list of NIfTI entries/paths (possibly
        empty). `demo_data` is [age, gender, mpo, wab_aq] arrays and
        `demo_types` the matching type labels.

    Raises
    ------
    ValueError
        If use_hf_dataset is False and demographic_file is None.
    """
    # Order matches what train_fc_vae / run_fc_analysis expect.
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    if use_hf_dataset:
        # Deferred import: only needed for the HuggingFace path.
        from datasets import load_dataset

        print(f"Loading dataset from HuggingFace: {data_dir}")
        dataset = load_dataset(data_dir)
        train = dataset['train']
        print(f"Dataset columns: {train.column_names}")

        # Build the demographics table directly from the dataset features.
        demo_df = pd.DataFrame({
            'ID': train['ID'],
            'wab_aq': train['wab_aq'],
            'age': train['age'],
            'mpo': train['mpo'],
            'education': train['education'],
            'gender': train['gender'],
            'handedness': train['handedness']
        })
        print(f"Loaded demographic data with {len(demo_df)} subjects")

        # Extract demographic data in the expected order.
        demo_data = [
            demo_df['age'].values,      # age at stroke -> age
            demo_df['gender'].values,   # sex -> gender
            demo_df['mpo'].values,      # months post stroke -> mpo
            demo_df['wab_aq'].values    # wab score -> wab_aq
        ]

        # Prefer pre-computed FC matrix columns if the dataset has them.
        fc_columns = [col for col in train.column_names
                      if col.startswith("fc_") or "_fc" in col]
        if fc_columns:
            print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
            fc_matrices = [train[fc_col] for fc_col in fc_columns]
            return fc_matrices, demo_data, demo_types

        # Otherwise, look for raw NIfTI columns (may legitimately be empty;
        # the caller then constructs FC matrices itself).
        nii_files = [train[col] for col in train.column_names
                     if col.endswith(".nii.gz") or col.endswith(".nii")]
        if nii_files:
            print(f"Found {len(nii_files)} .nii files")
        else:
            print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
    else:
        # Local loading requires an explicit demographics CSV; fail with a
        # clear message instead of pandas' opaque read_csv(None) error.
        if demographic_file is None:
            raise ValueError("demographic_file is required when use_hf_dataset is False")
        demo_df = pd.read_csv(demographic_file)
        # Accept either the original study column names or the dataset's.
        demo_data = [
            demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
            demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
            demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
            demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
        ]
        # Load fMRI files from the local directory.
        nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))

    return nii_files, demo_data, demo_types
def run_fc_analysis(data_dir="SreekarB/OSFData",
                    demographic_file=None,
                    latent_dim=32,
                    nepochs=1000,
                    bsize=16,
                    save_model=True,
                    use_hf_dataset=True,
                    return_data=False):
    """
    End-to-end FC analysis pipeline: load data, obtain/construct FC
    matrices, train a DemoVAE, analyze latents against demographics,
    generate and reconstruct FC matrices, and plot the results.

    Parameters
    ----------
    data_dir : str
        HuggingFace dataset ID or local directory with fMRI data.
    demographic_file : str or None
        Demographics CSV (only used for local loading).
    latent_dim, nepochs, bsize : int
        VAE hyperparameters; written into the shared MODEL_CONFIG.
    save_model : bool
        If True, save the trained model to models/vae_model.pth.
    use_hf_dataset : bool
        Load from HuggingFace instead of local files.
    return_data : bool
        If True, return (fig, results-dict); otherwise just the figure.

    Returns
    -------
    matplotlib figure, or (figure, results) when return_data is True.
    On any error a figure containing the error message is returned
    (with None for results when return_data is True) — this function
    never raises to the caller.
    """
    # NOTE: this mutates the shared MODEL_CONFIG in place, so any other
    # module reading it after this call sees the updated values.
    MODEL_CONFIG.update({
        'latent_dim': latent_dim,
        'nepochs': nepochs,
        'bsize': bsize
    })
    try:
        # Load data
        print("Loading data...")
        nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)

        # For SreekarB/OSFData, directly generate synthetic FC matrices
        if data_dir == "SreekarB/OSFData" and use_hf_dataset:
            print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
        # Pre-computed FC matrices: load_data returned array-likes
        elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
            print("Using pre-computed FC matrices...")
            # Convert list of FC matrices to one (n_subjects, ...) array
            X = np.stack([np.array(fc) for fc in nii_files])
        else:
            # Prepare data by converting fMRI to FC matrices
            print("Converting fMRI data to FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)

        # Print shapes and data types for debugging
        print(f"X shape: {X.shape}, type: {type(X)}")
        for i, d in enumerate(demo_data):
            print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")

        # Train VAE and get data
        print("Training VAE...")
        try:
            # Use the proper DemoVAE implementation from src/demovae/sklearn.py
            vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
            if save_model:
                print("Saving model...")
                os.makedirs('models', exist_ok=True)
                # Use the save method from DemoVAE
                vae.save('models/vae_model.pth')
                print("Model saved successfully.")
        except Exception as e:
            print(f"Error during VAE training: {e}")
            raise

        # Get latent representations
        print("Getting latent representations...")
        latents = vae.get_latents(X)

        # Analyze results; demo_data order is [age, gender, mpo, wab_aq]
        print("Analyzing demographic relationships...")
        demographics = {
            'age': demo_data[0],
            'months_post_onset': demo_data[2],
            'wab_aq': demo_data[3]
        }
        analysis_results = analyze_fc_patterns(latents, demographics)

        # Generate new FC matrix for a fixed demographic profile.
        print("Generating new FC matrices...")
        # Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
        new_demographics = [
            np.array([60.0], dtype=np.float64),  # age
            np.array(['M'], dtype=np.str_),      # gender
            np.array([12.0], dtype=np.float64),  # months post onset
            np.array([80.0], dtype=np.float64)   # wab score
        ]
        # Verify the demographic data arrays match the expected types
        print("Demographic data types:")
        for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
            print(f"  {name}: shape={data.shape}, dtype={data.dtype}")
        print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
        try:
            generated_fc = vae.transform(1, new_demographics, demo_types)
        except Exception as e:
            print(f"Error generating new FC matrix: {e}")
            # Fallback: if the hard-coded gender label is not in the model's
            # vocabulary, retry with the first gender seen in training data.
            print("Trying alternative generation approach...")
            new_demographics[1] = np.array([demo_data[1][0]])
            generated_fc = vae.transform(1, new_demographics, demo_types)

        # Reconstruct the training FC matrices through the VAE.
        reconstructed_fc = vae.transform(X, demo_data, demo_types)

        # Visualize results (original vs reconstructed vs generated)
        print("Creating visualizations...")
        fig = plot_fc_matrices(X[0], reconstructed_fc[0], generated_fc[0])

        # If requested, return additional data for accuracy calculations
        if return_data:
            # Create a structured outcome measures dictionary
            outcome_measures = {
                'wab_aq': demo_data[3],  # WAB-AQ scores
                # Could add other outcome measures here
            }
            results = {
                'vae': vae,
                'X': X,
                'latents': latents,
                'demographics': demographics,
                'reconstructed_fc': reconstructed_fc,
                'generated_fc': generated_fc,
                'analysis_results': analysis_results,
                'outcome_measures': outcome_measures
            }
            return fig, results
        return fig
    except Exception as e:
        import traceback
        print(f"Error in run_fc_analysis: {str(e)}")
        print(traceback.format_exc())
        # Create a dummy figure with error message so the UI always gets
        # something to display instead of a raised exception.
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}",
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=12, color='red')
        plt.axis('off')
        # Return the error figure and empty results if requested
        if return_data:
            return fig, None
        return fig
if __name__ == "__main__":
    import argparse

    # Command-line entry point; flag names, defaults, and help text are
    # the public CLI surface and are kept exactly as published.
    cli_parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
    cli_parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
                            help='HuggingFace dataset ID or directory containing fMRI data')
    cli_parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
                            help='Path to demographic data CSV file')
    cli_parser.add_argument('--latent_dim', type=int, default=32,
                            help='Dimension of latent space')
    cli_parser.add_argument('--nepochs', type=int, default=1000,
                            help='Number of training epochs')
    cli_parser.add_argument('--bsize', type=int, default=16,
                            help='Batch size for training')
    # store_false: args.no_save defaults to True (save the model) and
    # becomes False when --no_save is passed on the command line.
    cli_parser.add_argument('--no_save', action='store_false',
                            help='Do not save the model')
    cli_parser.add_argument('--use_local', action='store_true',
                            help='Use local data instead of HuggingFace dataset')
    cli_args = cli_parser.parse_args()

    result_fig = run_fc_analysis(
        data_dir=cli_args.data_dir,
        demographic_file=cli_args.demographic_file,
        latent_dim=cli_args.latent_dim,
        nepochs=cli_args.nepochs,
        bsize=cli_args.bsize,
        save_model=cli_args.no_save,
        use_hf_dataset=not cli_args.use_local
    )
    result_fig.show()