Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| # Add the src directory to the path so we can import from demovae | |
| sys.path.append(os.path.join(os.path.dirname(__file__), 'src')) | |
| import numpy as np | |
| import torch | |
| from pathlib import Path | |
| import nibabel as nib | |
| from data_preprocessing import preprocess_fmri_to_fc | |
| from src.demovae.sklearn import DemoVAE | |
| from analysis import analyze_fc_patterns | |
| from visualization import plot_fc_matrices | |
| from config import MODEL_CONFIG, DATASET_CONFIG | |
| import pandas as pd | |
| import io | |
| from typing import List, Dict, Union, Tuple, Any | |
def train_fc_vae(X, demo_data, demo_types, model_config):
    """
    Train a DemoVAE model on functional connectivity (FC) matrices.

    Parameters
    ----------
    X : array-like
        Subject-by-feature matrix of vectorized FC values. Converted to a
        float32 numpy array if needed; NaN/Inf entries are replaced by zeros.
    demo_data : list of array-like
        One array per demographic variable; non-ndarray entries are converted.
    demo_types : list of str
        Type tag per demographic variable (e.g. 'continuous'/'categorical'),
        passed through to DemoVAE.fit unchanged.
    model_config : dict
        Must contain 'latent_dim', 'nepochs', 'bsize'; optionally
        'loss_rec_mult' (default 100), 'loss_decor_mult' (default 10),
        'lr' (default 1e-4).

    Returns
    -------
    tuple
        (fitted DemoVAE instance, cleaned X, cleaned demo_data, demo_types).
    """
    print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")
    # Ensure X is a numpy array with a numeric dtype the model expects.
    if not isinstance(X, np.ndarray):
        print(f"Converting X from {type(X)} to numpy array")
        X = np.array(X, dtype=np.float32)
    # Build a fresh list instead of mutating the caller's demo_data in place;
    # the cleaned list is returned, so callers see the converted arrays either way.
    cleaned = []
    for i, d in enumerate(demo_data):
        if not isinstance(d, np.ndarray):
            print(f"Converting demographic {i} from {type(d)} to numpy array")
            d = np.array(d)
        cleaned.append(d)
    demo_data = cleaned
    # Non-finite values would poison the training loss; zero them out.
    if np.isnan(X).any() or np.isinf(X).any():
        print("Warning: X contains NaN or Inf values. Replacing with zeros.")
        X = np.nan_to_num(X)
    # Create the VAE model (CUDA is used automatically when available).
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config.get('loss_rec_mult', 100),
        loss_decor_mult=model_config.get('loss_decor_mult', 10),
        lr=model_config.get('lr', 1e-4),
        use_cuda=torch.cuda.is_available()
    )
    print("Fitting VAE model...")
    vae.fit(X, demo_data, demo_types)
    return vae, X, demo_data, demo_types
def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
    """
    Load fMRI data and demographics from HuggingFace dataset or local files.

    Returns a tuple (data, demo_data, demo_types) where `data` is either a
    list of pre-computed FC matrices, a list of .nii entries, or a sorted
    list of local .nii.gz paths, and demo_data holds the four demographic
    arrays (age, gender, months-post-onset, WAB-AQ).
    """
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
    if use_hf_dataset:
        # Load from HuggingFace Datasets
        from datasets import load_dataset
        print(f"Loading dataset from HuggingFace: {data_dir}")
        dataset = load_dataset(data_dir)
        train = dataset['train']
        print(f"Dataset columns: {train.column_names}")
        # Build a demographics DataFrame straight from the dataset features.
        demo_df = pd.DataFrame({
            name: train[name]
            for name in ('ID', 'wab_aq', 'age', 'mpo', 'education', 'gender', 'handedness')
        })
        print(f"Loaded demographic data with {len(demo_df)} subjects")
        # Map dataset columns into the pipeline's expected order.
        demo_data = [demo_df[name].values for name in ('age', 'gender', 'mpo', 'wab_aq')]
        # Pre-computed FC matrices, if the dataset carries them.
        fc_columns = [c for c in train.column_names if c.startswith("fc_") or "_fc" in c]
        if fc_columns:
            print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
            fc_matrices = [train[c] for c in fc_columns]
            return fc_matrices, demo_data, demo_types
        # Otherwise fall back to any .nii columns in the dataset.
        nii_files = [train[c] for c in train.column_names if c.endswith((".nii.gz", ".nii"))]
        if nii_files:
            print(f"Found {len(nii_files)} .nii files")
        else:
            print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
            # If no structured data is found, we can try to download raw files later
    else:
        # Original local-file loading path: demographics from CSV, scans from disk.
        demo_df = pd.read_csv(demographic_file)

        def pick(primary, fallback):
            # Prefer the study-specific column name, fall back to the generic one.
            column = primary if primary in demo_df.columns else fallback
            return demo_df[column].values

        demo_data = [
            pick('age_at_stroke', 'age'),
            pick('sex', 'gender'),
            pick('months_post_stroke', 'mpo'),
            pick('wab_score', 'wab_aq'),
        ]
        nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))
    return nii_files, demo_data, demo_types
def run_fc_analysis(data_dir="SreekarB/OSFData",
                    demographic_file=None,
                    latent_dim=32,
                    nepochs=1000,
                    bsize=16,
                    save_model=True,
                    use_hf_dataset=True,
                    return_data=False):
    """
    Run the end-to-end FC analysis pipeline: load data, train a DemoVAE,
    analyze latent/demographic relationships, and generate/reconstruct FC
    matrices with a comparison figure.

    Parameters
    ----------
    data_dir : str
        HuggingFace dataset ID or local directory with fMRI data.
    demographic_file : str or None
        CSV of demographics (only used when use_hf_dataset is False).
    latent_dim, nepochs, bsize : int
        VAE hyperparameters; written into the shared MODEL_CONFIG.
    save_model : bool
        If True, save the fitted model to models/vae_model.pth.
    use_hf_dataset : bool
        Load from HuggingFace (True) or local files (False).
    return_data : bool
        If True, also return a dict of intermediate results.

    Returns
    -------
    matplotlib figure, or (figure, results dict) when return_data is True.
    On any error, a figure containing the error text is returned
    ((figure, None) when return_data is True) — this function never raises.
    """
    # Update MODEL_CONFIG with user-specified parameters (shared mutable
    # config dict; train_fc_vae reads from it below).
    MODEL_CONFIG.update({
        'latent_dim': latent_dim,
        'nepochs': nepochs,
        'bsize': bsize
    })
    try:
        # Load data
        print("Loading data...")
        nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
        # For SreekarB/OSFData, directly generate synthetic FC matrices
        if data_dir == "SreekarB/OSFData" and use_hf_dataset:
            print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
        # Check if we got FC matrices directly (anything array-like qualifies)
        elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
            print("Using pre-computed FC matrices...")
            # Convert list of FC matrices to numpy array
            X = np.stack([np.array(fc) for fc in nii_files])
        else:
            # Prepare data by converting fMRI to FC matrices
            print("Converting fMRI data to FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
        # Print shapes and data types for debugging
        print(f"X shape: {X.shape}, type: {type(X)}")
        for i, d in enumerate(demo_data):
            print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
        # Train VAE and get data
        print("Training VAE...")
        try:
            # Use the proper DemoVAE implementation from src/demovae/sklearn.py
            vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
            if save_model:
                print("Saving model...")
                os.makedirs('models', exist_ok=True)
                # Use the save method from DemoVAE
                vae.save('models/vae_model.pth')
                print("Model saved successfully.")
        except Exception as e:
            print(f"Error during VAE training: {e}")
            raise
        # Get latent representations
        print("Getting latent representations...")
        latents = vae.get_latents(X)
        # Analyze results (gender, demo_data[1], is intentionally excluded here)
        print("Analyzing demographic relationships...")
        demographics = {
            'age': demo_data[0],
            'months_post_onset': demo_data[2],
            'wab_aq': demo_data[3]
        }
        analysis_results = analyze_fc_patterns(latents, demographics)
        # Generate new FC matrix for a fixed demographic profile.
        print("Generating new FC matrices...")
        # Use numpy arrays to avoid "expected np.ndarray (got list)" errors.
        new_demographics = [
            np.array([60.0], dtype=np.float64),  # age
            np.array(['M'], dtype=np.str_),      # gender
            np.array([12.0], dtype=np.float64),  # months post onset
            np.array([80.0], dtype=np.float64)   # wab score
        ]
        # Verify the demographic data arrays match the expected types
        print("Demographic data types:")
        for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
            print(f"  {name}: shape={data.shape}, dtype={data.dtype}")
        print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
        try:
            generated_fc = vae.transform(1, new_demographics, demo_types)
        except Exception as e:
            print(f"Error generating new FC matrix: {e}")
            # Try with a fallback approach
            print("Trying alternative generation approach...")
            # If specific gender is causing issues, try the first gender from training data
            new_demographics[1] = np.array([demo_data[1][0]])
            generated_fc = vae.transform(1, new_demographics, demo_types)
        reconstructed_fc = vae.transform(X, demo_data, demo_types)
        # Visualize results
        print("Creating visualizations...")
        fig = plot_fc_matrices(X[0], reconstructed_fc[0], generated_fc[0])
        # If requested, return additional data for accuracy calculations
        if return_data:
            # Create a structured outcome measures dictionary
            outcome_measures = {
                'wab_aq': demo_data[3],  # WAB-AQ scores
                # Could add other outcome measures here
            }
            results = {
                'vae': vae,
                'X': X,
                'latents': latents,
                'demographics': demographics,
                'reconstructed_fc': reconstructed_fc,
                'generated_fc': generated_fc,
                'analysis_results': analysis_results,
                'outcome_measures': outcome_measures
            }
            return fig, results
        return fig
    except Exception as e:
        # Best-effort top-level boundary: log the traceback and return an
        # error figure instead of raising, so UI callers always get a figure.
        import traceback
        print(f"Error in run_fc_analysis: {str(e)}")
        print(traceback.format_exc())
        # Create a dummy figure with error message
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}",
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=12, color='red')
        plt.axis('off')
        # Return the error figure and empty results if requested
        if return_data:
            return fig, None
        return fig
if __name__ == "__main__":
    import argparse

    # Command-line entry point: parse hyperparameters and data options,
    # then run the full FC analysis pipeline and display the figure.
    parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
    add = parser.add_argument
    add('--data_dir', type=str, default='SreekarB/OSFData',
        help='HuggingFace dataset ID or directory containing fMRI data')
    add('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
        help='Path to demographic data CSV file')
    add('--latent_dim', type=int, default=32,
        help='Dimension of latent space')
    add('--nepochs', type=int, default=1000,
        help='Number of training epochs')
    add('--bsize', type=int, default=16,
        help='Batch size for training')
    # store_false: passing --no_save flips the (True) default off.
    add('--no_save', action='store_false',
        help='Do not save the model')
    add('--use_local', action='store_true',
        help='Use local data instead of HuggingFace dataset')
    args = parser.parse_args()

    fig = run_fc_analysis(
        data_dir=args.data_dir,
        demographic_file=args.demographic_file,
        latent_dim=args.latent_dim,
        nepochs=args.nepochs,
        bsize=args.bsize,
        save_model=args.no_save,
        use_hf_dataset=not args.use_local,
    )
    fig.show()