import os
import sys

# Add the src directory to the path so we can import from demovae
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

import numpy as np
import torch
from pathlib import Path
import nibabel as nib
from data_preprocessing import preprocess_fmri_to_fc
from src.demovae.sklearn import DemoVAE
from analysis import analyze_fc_patterns
from visualization import plot_fc_matrices
from config import MODEL_CONFIG, DATASET_CONFIG
import pandas as pd
import io
from typing import List, Dict, Union, Tuple, Any


def train_fc_vae(X, demo_data, demo_types, model_config):
    """
    Train a DemoVAE model on functional connectivity matrices.

    Parameters
    ----------
    X : array-like
        Subject-by-feature FC data; coerced to a float32 ndarray if needed.
    demo_data : list
        One array per demographic variable; entries coerced to ndarrays in place.
    demo_types : list of str
        Per-variable type tags (e.g. 'continuous' / 'categorical') passed
        straight through to DemoVAE.fit.
    model_config : dict
        Requires 'latent_dim', 'nepochs', 'bsize'; optional 'loss_rec_mult',
        'loss_decor_mult', 'lr'.

    Returns
    -------
    (vae, X, demo_data, demo_types) with X/demo_data in their sanitized forms.
    """
    print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")

    # Ensure X is a numpy array with correct data type.
    if not isinstance(X, np.ndarray):
        print(f"Converting X from {type(X)} to numpy array")
        X = np.array(X, dtype=np.float32)

    # Ensure every demographic column is a numpy array (mutates the list in place).
    for i, d in enumerate(demo_data):
        if not isinstance(d, np.ndarray):
            print(f"Converting demographic {i} from {type(d)} to numpy array")
            demo_data[i] = np.array(d)

    # Replace NaN/Inf with zeros so training does not diverge on bad values.
    if np.isnan(X).any() or np.isinf(X).any():
        print("Warning: X contains NaN or Inf values. Replacing with zeros.")
        X = np.nan_to_num(X)

    # Create the VAE model. Loss-multiplier and lr defaults mirror the values
    # used elsewhere in the project config; confirm against DemoVAE's own defaults.
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config.get('loss_rec_mult', 100),
        loss_decor_mult=model_config.get('loss_decor_mult', 10),
        lr=model_config.get('lr', 1e-4),
        use_cuda=torch.cuda.is_available()
    )

    print("Fitting VAE model...")
    vae.fit(X, demo_data, demo_types)
    return vae, X, demo_data, demo_types


def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
    """
    Load fMRI data and demographics from a HuggingFace dataset or local files.

    Parameters
    ----------
    data_dir : str
        HuggingFace dataset ID when use_hf_dataset is True, otherwise a local
        directory containing .nii.gz files.
    demographic_file : str or None
        CSV path with demographic columns; only used when use_hf_dataset is False.
    use_hf_dataset : bool
        Select the HuggingFace loading path vs. local-file loading.

    Returns
    -------
    (data, demo_data, demo_types) where `data` is either a list of FC matrices,
    a list of .nii file references, or a sorted list of local Paths, and
    demo_types tags the four demographic columns [age, gender, mpo, wab_aq].
    """
    if use_hf_dataset:
        # Load from HuggingFace Datasets (lazy import keeps it optional for
        # the local-file path).
        from datasets import load_dataset

        print(f"Loading dataset from HuggingFace: {data_dir}")
        dataset = load_dataset(data_dir)
        print(f"Dataset columns: {dataset['train'].column_names}")

        # Build a demographics DataFrame directly from the dataset features.
        demo_df = pd.DataFrame({
            'ID': dataset['train']['ID'],
            'wab_aq': dataset['train']['wab_aq'],
            'age': dataset['train']['age'],
            'mpo': dataset['train']['mpo'],
            'education': dataset['train']['education'],
            'gender': dataset['train']['gender'],
            'handedness': dataset['train']['handedness']
        })
        print(f"Loaded demographic data with {len(demo_df)} subjects")

        # Map the dataset columns to the 4-column format the model expects.
        demo_data = [
            demo_df['age'].values,     # age at stroke -> age
            demo_df['gender'].values,  # sex -> gender
            demo_df['mpo'].values,     # months post stroke -> mpo
            demo_df['wab_aq'].values   # wab score -> wab_aq
        ]

        # Check for pre-computed FC matrices stored as dataset columns.
        fc_columns = [
            col for col in dataset['train'].column_names
            if col.startswith("fc_") or "_fc" in col
        ]
        if fc_columns:
            print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
            fc_matrices = [dataset['train'][fc_col] for fc_col in fc_columns]
            # If we have FC matrices, return them directly.
            demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
            return fc_matrices, demo_data, demo_types

        # If no FC matrices, look for .nii file columns instead.
        nii_files = [
            dataset['train'][col]
            for col in dataset['train'].column_names
            if col.endswith(".nii.gz") or col.endswith(".nii")
        ]
        if nii_files:
            print(f"Found {len(nii_files)} .nii files")
        else:
            print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
            # If no structured data is found, we can try to download raw files later.
    else:
        # Original local-file loading path: demographics from CSV, with
        # fallbacks between the two column-naming schemes seen in practice.
        demo_df = pd.read_csv(demographic_file)
        demo_data = [
            demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
            demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
            demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
            demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
        ]
        # Load fMRI files from disk.
        nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))

    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
    return nii_files, demo_data, demo_types


def run_fc_analysis(data_dir="SreekarB/OSFData", demographic_file=None, latent_dim=32,
                    nepochs=1000, bsize=16, save_model=True, use_hf_dataset=True,
                    return_data=False):
    """
    End-to-end pipeline: load data, build FC matrices, train the VAE,
    analyze latents, generate/reconstruct FC matrices, and plot results.

    Returns the matplotlib figure, or (figure, results_dict) when
    return_data is True; on failure returns an error figure (and None
    for the results when requested).
    """
    # Update the shared MODEL_CONFIG with user-specified parameters.
    # NOTE(review): this mutates the imported module-level dict, so the
    # settings persist for any later caller in the same process.
    MODEL_CONFIG.update({
        'latent_dim': latent_dim,
        'nepochs': nepochs,
        'bsize': bsize
    })

    try:
        print("Loading data...")
        nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)

        # For SreekarB/OSFData, directly generate synthetic FC matrices.
        if data_dir == "SreekarB/OSFData" and use_hf_dataset:
            print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
        # Check if we got FC matrices directly (items with a .shape attribute).
        elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
            print("Using pre-computed FC matrices...")
            X = np.stack([np.array(fc) for fc in nii_files])
        else:
            # Prepare data by converting fMRI volumes to FC matrices.
            print("Converting fMRI data to FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)

        # Print shapes and data types for debugging.
        print(f"X shape: {X.shape}, type: {type(X)}")
        for i, d in enumerate(demo_data):
            print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")

        print("Training VAE...")
        try:
            # Use the proper DemoVAE implementation from src/demovae/sklearn.py.
            vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
            if save_model:
                print("Saving model...")
                os.makedirs('models', exist_ok=True)
                vae.save('models/vae_model.pth')
                print("Model saved successfully.")
        except Exception as e:
            print(f"Error during VAE training: {e}")
            raise

        print("Getting latent representations...")
        latents = vae.get_latents(X)

        # Analyze demographic relationships (gender, index 1, is intentionally
        # excluded from this continuous-variable analysis).
        print("Analyzing demographic relationships...")
        demographics = {
            'age': demo_data[0],
            'months_post_onset': demo_data[2],
            'wab_aq': demo_data[3]
        }
        analysis_results = analyze_fc_patterns(latents, demographics)

        # Generate a new FC matrix for a synthetic subject. Arrays (not lists)
        # avoid DemoVAE's "expected np.ndarray (got list)" error.
        print("Generating new FC matrices...")
        new_demographics = [
            np.array([60.0], dtype=np.float64),  # age
            np.array(['M'], dtype=np.str_),      # gender
            np.array([12.0], dtype=np.float64),  # months post onset
            np.array([80.0], dtype=np.float64)   # wab score
        ]

        print("Demographic data types:")
        for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
            print(f" {name}: shape={data.shape}, dtype={data.dtype}")

        print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
        try:
            # First arg of transform is presumably the number of samples to
            # generate when no data is given — confirm against DemoVAE.transform.
            generated_fc = vae.transform(1, new_demographics, demo_types)
        except Exception as e:
            print(f"Error generating new FC matrix: {e}")
            # If the hard-coded gender label is causing issues, retry with the
            # first gender value actually seen in training data.
            print("Trying alternative generation approach...")
            new_demographics[1] = np.array([demo_data[1][0]])
            generated_fc = vae.transform(1, new_demographics, demo_types)

        reconstructed_fc = vae.transform(X, demo_data, demo_types)

        print("Creating visualizations...")
        fig = plot_fc_matrices(X[0], reconstructed_fc[0], generated_fc[0])

        # If requested, return additional data for accuracy calculations.
        if return_data:
            outcome_measures = {
                'wab_aq': demo_data[3],  # WAB-AQ scores
                # Could add other outcome measures here
            }
            results = {
                'vae': vae,
                'X': X,
                'latents': latents,
                'demographics': demographics,
                'reconstructed_fc': reconstructed_fc,
                'generated_fc': generated_fc,
                'analysis_results': analysis_results,
                'outcome_measures': outcome_measures
            }
            return fig, results
        return fig

    except Exception as e:
        import traceback
        print(f"Error in run_fc_analysis: {str(e)}")
        print(traceback.format_exc())

        # Create a dummy figure carrying the error message so callers always
        # get a figure back.
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}",
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=12, color='red')
        plt.axis('off')

        if return_data:
            return fig, None
        return fig


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
    parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
                        help='HuggingFace dataset ID or directory containing fMRI data')
    parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
                        help='Path to demographic data CSV file')
    parser.add_argument('--latent_dim', type=int, default=32,
                        help='Dimension of latent space')
    parser.add_argument('--nepochs', type=int, default=1000,
                        help='Number of training epochs')
    parser.add_argument('--bsize', type=int, default=16,
                        help='Batch size for training')
    # store_false: the flag defaults to True and passing --no_save flips it
    # to False, which is then fed straight into save_model.
    parser.add_argument('--no_save', action='store_false',
                        help='Do not save the model')
    parser.add_argument('--use_local', action='store_true',
                        help='Use local data instead of HuggingFace dataset')

    args = parser.parse_args()

    fig = run_fc_analysis(
        data_dir=args.data_dir,
        demographic_file=args.demographic_file,
        latent_dim=args.latent_dim,
        nepochs=args.nepochs,
        bsize=args.bsize,
        save_model=args.no_save,
        use_hf_dataset=not args.use_local
    )
    fig.show()