diff --git "a/app.py" "b/app.py" --- "a/app.py" +++ "b/app.py" @@ -1,2723 +1,265 @@ +""" +Simplified app for Huggingface Spaces. +Provides a simple UI for VAE training and visualization. +""" import os -import sys - -# Set Huggingface cache directory to avoid permission issues -os.environ['TRANSFORMERS_CACHE'] = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache') -os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True) - import gradio as gr -from main import run_analysis -from rcf_prediction import AphasiaTreatmentPredictor import numpy as np -# Configure matplotlib for headless environment -import matplotlib -matplotlib.use('Agg') # Use non-interactive backend -matplotlib.rcParams['figure.dpi'] = 100 -matplotlib.rcParams['savefig.dpi'] = 100 -import matplotlib.pyplot as plt -from data_preprocessing import preprocess_fmri_to_fc, process_single_fmri -from visualization import plot_fc_matrices, plot_learning_curves -import glob -from sklearn.metrics import mean_squared_error, r2_score -import json -import pickle import pandas as pd -import seaborn as sns +import matplotlib.pyplot as plt +from vae_model import DemoVAE, plot_learning_curves +import time +import tempfile import logging -from config import MODEL_CONFIG, PREDICTION_CONFIG +# Set up logging logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s') logger = logging.getLogger(__name__) -class AphasiaPredictionApp: - def __init__(self): - self.vae = None - self.predictor = None - self.trained = False - self.latent_dim = MODEL_CONFIG['latent_dim'] - self.last_treatment_file = None # Track the last treatment file used - - def train_models(self, data_dir, latent_dim, nepochs, bsize, num_samples): - """ - Train VAE and Random Forest models - """ - # Train VAE and Random Forest - logger.info(f"Training models with data from {data_dir}") - logger.info(f"VAE params: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}") - - # Default prediction parameters from our config - outcome_variable = PREDICTION_CONFIG.get('default_outcome', 'wab_aq') - logger.info(f"Prediction: type=regression, outcome={outcome_variable}") - - figures = {} - - try: - # Run the full analysis pipeline - # For HuggingFace dataset, we don't need the demographic file physically - # as we'll extract demographics directly from the dataset - if data_dir == "SreekarB/OSFData1": - logger.info("Using SreekarB/OSFData1 dataset, loading demographic data directly from the dataset API") - - try: - # Import HF dataset libraries - from datasets import load_dataset - import pandas as pd - import os - import tempfile - from huggingface_hub import hf_hub_download - - try: - # First try direct download of demographic file - logger.info(f"Attempting to directly download FC_graph_covariate_data.csv from {data_dir}") - - temp_dir = tempfile.mkdtemp(prefix="hf_demo_") - - try: - # Try to download the demographic CSV directly - csv_path = hf_hub_download( - repo_id=data_dir, - filename="FC_graph_covariate_data.csv", - repo_type="dataset", - cache_dir=temp_dir - ) - logger.info(f"✓ Successfully downloaded FC_graph_covariate_data.csv directly") - - # Read the demographic file manually - try: - # Try UTF-8 first - demo_df = pd.read_csv(csv_path) - except UnicodeDecodeError: - # Try alternative encodings - for encoding in ['latin1', 'cp1252', 'iso-8859-1']: - try: - logger.info(f"Trying {encoding} encoding...") - demo_df = pd.read_csv(csv_path, encoding=encoding) - logger.info(f"Successfully loaded with {encoding} encoding") - 
break - except Exception: - continue - else: - # Try as Excel if all encodings fail - try: - logger.info("Trying to read as Excel file...") - demo_df = pd.read_excel(csv_path) - except Exception: - raise ValueError("Could not read demographic file with any encoding") - - # Create a mock dataset with the demographic data - dataset = { - 'train': demo_df - } - logger.info(f"Created dataset from manual file reading with {len(demo_df)} rows") - - # Extract the column names to validate - columns = demo_df.columns.tolist() - logger.info(f"Dataset columns: {columns}") - - except Exception as direct_err: - logger.warning(f"Could not download demographic file directly: {direct_err}") - - # Fall back to dataset loading - logger.info("Falling back to standard dataset loading") - # Try to load the dataset with encoding parameters - dataset = load_dataset(data_dir) - logger.info(f"Successfully loaded dataset: {data_dir}") - - # Check if we have the 'train' split - if 'train' not in dataset: - logger.error(f"Dataset {data_dir} does not have a 'train' split") - raise ValueError(f"Dataset {data_dir} is missing the required 'train' split") - - # Extract the column names to validate - columns = dataset['train'].column_names - logger.info(f"Dataset columns: {columns}") - - except UnicodeDecodeError as ude: - logger.error(f"Unicode decode error when loading dataset: {ude}") - - # Try to download the files directly without using datasets API - logger.warning("Trying direct file download instead of using datasets API") - - # Create a temporary directory for downloads - temp_dir = tempfile.mkdtemp(prefix="direct_download_") - - # Try different filenames and extensions for demographic data - demo_files = [ - "FC_graph_covariate_data.csv", - "FC_graph_covariate_data.xlsx", - "demographics.csv", - "demographics.xlsx", - "subjects.csv", - "patients.csv" - ] - - # Try each file - for demo_file in demo_files: - try: - file_path = hf_hub_download( - repo_id=data_dir, - filename=demo_file, - repo_type="dataset", - cache_dir=temp_dir - ) - - logger.info(f"Found demographic file: {demo_file}") - - # Try to read with different encodings - if demo_file.endswith('.csv'): - for encoding in ['utf-8', 'latin1', 'cp1252', 'iso-8859-1']: - try: - demo_df = pd.read_csv(file_path, encoding=encoding) - logger.info(f"Successfully read {demo_file} with {encoding} encoding") - break - except UnicodeDecodeError: - continue - else: - logger.warning(f"Could not read {demo_file} with any encoding") - continue - elif demo_file.endswith('.xlsx'): - try: - demo_df = pd.read_excel(file_path) - logger.info(f"Successfully read Excel file {demo_file}") - except Exception as excel_err: - logger.warning(f"Could not read Excel file {demo_file}: {excel_err}") - continue - - # If we got here, we successfully read a file - dataset = { - 'train': demo_df - } - columns = demo_df.columns.tolist() - logger.info(f"Loaded demographic data with {len(demo_df)} rows") - logger.info(f"Columns: {columns}") - break - - except Exception as file_err: - logger.warning(f"Could not download or read {demo_file}: {file_err}") - - else: - # If all files failed, raise an error - raise ValueError("Could not find or read any demographic data file from the dataset. 
Please provide a valid FC_graph_covariate_data.csv file.") - - # Check for required demographic columns - required_columns = ['ID', 'wab_aq', 'age', 'mpo', 'education', 'gender'] - missing_columns = [col for col in required_columns if col not in columns] - - if missing_columns: - # Try alternative column names - column_mapping = { - 'ID': ['ID', 'id', 'subject_id', 'Subject', 'PatientID'], - 'wab_aq': ['wab_aq', 'wab', 'WAB', 'WAB_AQ', 'aphasia_score'], - 'age': ['age', 'Age', 'age_at_stroke', 'patient_age'], - 'mpo': ['mpo', 'MPO', 'months_post_onset', 'months_post_stroke'], - 'education': ['education', 'educate', 'edu', 'years_education', 'educ'], - 'gender': ['gender', 'sex', 'Gender', 'Sex'] - } - - for missing_col in missing_columns: - for alt_col in column_mapping[missing_col]: - if alt_col in columns: - # Found an alternative - logger.info(f"Mapped column {alt_col} to {missing_col}") - if isinstance(dataset['train'], pd.DataFrame): - # If using DataFrame - dataset['train'][missing_col] = dataset['train'][alt_col] - else: - # If using HF dataset - dataset['train'] = dataset['train'].add_column( - missing_col, dataset['train'][alt_col] - ) - missing_columns.remove(missing_col) - break - - # If we still have missing columns after mapping, raise an error - if missing_columns: - logger.error(f"Missing required columns after mapping: {missing_columns}") - raise ValueError(f"The demographic data is missing required columns: {missing_columns}. Please ensure your FC_graph_covariate_data.csv file contains these columns or equivalent alternatives.") - - # First, check if FC_graph_covariate_data.csv exists in the dataset - try: - from huggingface_hub import hf_hub_download - import tempfile - - temp_dir = tempfile.mkdtemp(prefix="hf_csv_") - logger.info(f"Looking for FC_graph_covariate_data.csv in dataset {data_dir}") - - try: - csv_path = hf_hub_download( - repo_id=data_dir, - filename="FC_graph_covariate_data.csv", - repo_type="dataset", - cache_dir=temp_dir - ) - logger.info(f"✓ Successfully found FC_graph_covariate_data.csv in the dataset!") - - # Use the file directly instead of the API - demographic_file = csv_path - logger.info(f"Using FC_graph_covariate_data.csv from dataset: {demographic_file}") - except Exception as e: - logger.info(f"FC_graph_covariate_data.csv not found in dataset (this is okay): {e}") - # Fall back to API - demographic_file = "FROM_DATASET_API" - except Exception as e: - logger.warning(f"Error checking for FC_graph_covariate_data.csv: {e}") - demographic_file = "FROM_DATASET_API" - - treatment_file = "FROM_DATASET_API" # We'll generate this - - except Exception as e: - logger.error(f"Error loading HuggingFace dataset: {e}", exc_info=True) - raise - else: - # For local directories, look for the actual files - demographic_file = os.path.join(data_dir, "FC_graph_covariate_data.csv") - treatment_file = os.path.join(data_dir, "treatment_outcomes.csv") - - # Check data directory for files - if not os.path.exists(demographic_file): - # Try app directory as fallback - app_dir_demographic = os.path.join(os.path.dirname(os.path.abspath(__file__)), "FC_graph_covariate_data.csv") - if os.path.exists(app_dir_demographic): - demographic_file = app_dir_demographic - logger.info(f"Using FC_graph_covariate_data.csv from app directory: {demographic_file}") - else: - logger.error(f"FC_graph_covariate_data.csv not found in data directory or app directory") - raise FileNotFoundError(f"Demographic file not found. 
Please ensure FC_graph_covariate_data.csv exists in {data_dir} or the application directory.") - - # Create a simple fallback treatment outcomes file that will be used if no actual data is found - fallback_file = os.path.join('results', 'treatment_outcomes.csv') - try: - # Create a simple fallback treatment outcomes file - os.makedirs('results', exist_ok=True) - mock_outcomes = pd.DataFrame([ - {'subject_id': 'P001', 'treatment_type': 'Standard', 'outcome_score': 5.2}, - {'subject_id': 'P002', 'treatment_type': 'Intensive', 'outcome_score': 7.8}, - {'subject_id': 'P003', 'treatment_type': 'Standard', 'outcome_score': 3.1}, - {'subject_id': 'P004', 'treatment_type': 'Intensive', 'outcome_score': 9.4}, - {'subject_id': 'P005', 'treatment_type': 'Control', 'outcome_score': 1.2} - ]) - mock_outcomes.to_csv(fallback_file, index=False) - logger.info(f"Created standard treatment outcomes file with 5 subjects") - except Exception as e: - logger.error(f"Failed to create standard outcomes file: {e}") - - # Set default treatment file path to our fallback file - treatment_file = fallback_file - - # For SreekarB/OSFData1 dataset, optionally look for real treatment data - if data_dir == "SreekarB/OSFData1": - # Check if the user wants to skip behavioral data processing - skip_behavioral = PREDICTION_CONFIG.get('skip_behavioral_data', False) - - if skip_behavioral: - # Skip behavioral data processing entirely - logger.info("Skipping behavioral data processing as requested in config") - else: - # Try to find behavioral_data.csv in the dataset - try: - from huggingface_hub import hf_hub_download - import tempfile - - temp_dir = tempfile.mkdtemp(prefix="hf_behavioral_") - logger.info(f"Looking for behavioral_data.csv in dataset {data_dir}") - - try: - csv_path = hf_hub_download( - repo_id=data_dir, - filename="behavioral_data.csv", - repo_type="dataset", - cache_dir=temp_dir - ) - logger.info(f"✓ Successfully found behavioral_data.csv in the dataset!") - - # Process behavioral data to extract treatment outcomes - try: - real_treatment_file = process_behavioral_data_to_outcomes(csv_path) - treatment_file = real_treatment_file # Use the real treatment file if processing succeeded - # Store the treatment file path for later use - self.last_treatment_file = treatment_file - logger.info(f"Using processed behavioral data for treatment outcomes") - except Exception as proc_err: - logger.warning(f"Couldn't process behavioral data: {proc_err}, using standard outcomes") - # Keep using the fallback file - except Exception as e: - logger.warning(f"behavioral_data.csv not found or couldn't be processed: {e}") - - # Try to find any treatment outcomes file - try: - # Use our treatment outcomes file finder - real_treatment_file = find_treatment_outcomes_file(data_dir) - logger.info(f"Found treatment outcomes file: {real_treatment_file}") - - # Use the found file - treatment_file = real_treatment_file - # Store the treatment file path for later use - self.last_treatment_file = treatment_file - logger.info(f"Using real treatment outcomes file") - except Exception as find_err: - logger.warning(f"Couldn't find treatment outcomes file: {find_err}, using standard outcomes") - # Keep using the fallback file - except Exception as e: - logger.warning(f"Error during treatment data lookup: {e}, using standard outcomes") - # Keep using the fallback file - # Only check for treatment_file if we're not using the SreekarB/OSFData1 dataset - elif not os.path.exists(treatment_file): - # Try app directory as fallback - app_dir_treatment = 
os.path.join(os.path.dirname(os.path.abspath(__file__)), "treatment_outcomes.csv") - if os.path.exists(app_dir_treatment): - treatment_file = app_dir_treatment - logger.info(f"Using treatment outcomes file from app directory: {treatment_file}") - else: - logger.error(f"Treatment outcomes file not found in data directory or app directory") - raise FileNotFoundError(f"Treatment outcomes file not found. Please ensure treatment_outcomes.csv exists in {data_dir} or the application directory.") - - logger.info(f"Using demographic file: {demographic_file}") - logger.info(f"Using treatment file: {treatment_file}") - - # Special handling for HuggingFace dataset - if data_dir == "SreekarB/OSFData1": - # For NIfTI files, we need to search the API or download regardless of demographic source - logger.info("Searching for NIfTI files in the dataset...") - - # First check if NIfTI files exist in a local directory - local_nii_files = [] - - # Check different possible local paths, starting with user-specified directory - possible_paths = [] - - # Add user-specified directory from config if available - if PREDICTION_CONFIG.get('local_nii_dir'): - user_dir = PREDICTION_CONFIG.get('local_nii_dir') - if os.path.exists(user_dir): - possible_paths.append(user_dir) - logger.info(f"Checking user-specified NIfTI directory: {user_dir}") - - # Add other standard paths to check - possible_paths.extend([ - os.path.join(os.path.dirname(os.path.abspath(__file__)), "nii_files"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "nifti"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "fmri"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "nii_files"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "nifti"), - "/tmp/nii_files" # In case files were manually placed here - ]) - - for path in possible_paths: - if os.path.exists(path): - # Check for .nii or .nii.gz files - nii_files_here = [] - nii_files_here.extend(glob.glob(os.path.join(path, "*.nii"))) - nii_files_here.extend(glob.glob(os.path.join(path, "*.nii.gz"))) - - if nii_files_here: - local_nii_files.extend(nii_files_here) - logger.info(f"Found {len(nii_files_here)} local NIfTI files in {path}") - - if local_nii_files: - logger.info(f"Using {len(local_nii_files)} local NIfTI files instead of searching HuggingFace dataset") - - # Log filenames to help with debugging - for i, nii_file in enumerate(local_nii_files[:5]): # Log first 5 files - logger.info(f"Local NIfTI file {i+1}: {os.path.basename(nii_file)}") - - if len(local_nii_files) > 5: - logger.info(f"... and {len(local_nii_files) - 5} more files") - - nii_files = local_nii_files - else: - # If no local files found, find NIfTI files using our comprehensive search function - logger.info("No local NIfTI files found. Searching in the HuggingFace dataset...") - nii_files = find_nifti_files_in_hf_dataset(data_dir, dataset) - - # Log what was found - if nii_files: - logger.info(f"Found {len(nii_files)} NIfTI files in the dataset") - - # Log filenames to help with debugging - for i, nii_file in enumerate(nii_files[:5]): # Log first 5 files - logger.info(f"NIfTI file {i+1}: {os.path.basename(nii_file)}") - - if len(nii_files) > 5: - logger.info(f"... and {len(nii_files) - 5} more files") - else: - logger.warning("No NIfTI files found in the dataset. 
This will likely cause an error later.") - - if demographic_file == "FROM_DATASET_API": - logger.info("Using dataset API for demographics rather than files") - - # Extract demographics and prepare dataset for run_analysis - # Get demographics directly from the dataset - demo_data = [ - dataset['train']['age'], - dataset['train']['gender'], - dataset['train']['mpo'], - dataset['train']['wab_aq'] - ] - - # Create demo_types - demo_types = ['continuous', 'categorical', 'continuous', 'continuous'] - - # Run with API data - results = run_analysis( - data_dir=data_dir, - demographic_file=None, # No physical file, data from API - treatment_file=treatment_file, # Using the treatment file (real or synthetic) - latent_dim=latent_dim, - nepochs=nepochs, - bsize=bsize, - save_model=True, - use_hf_dataset=True, - hf_dataset=dataset, - hf_nii_files=nii_files, - hf_demo_data=demo_data, - hf_demo_types=demo_types, - max_samples=num_samples - ) - else: - logger.info(f"Using FC_graph_covariate_data.csv from dataset: {demographic_file}") - - # Run with CSV file but still pass the NIfTI files - results = run_analysis( - data_dir=data_dir, - demographic_file=demographic_file, # Using the CSV file from dataset - treatment_file=treatment_file, # Now using the treatment file (real or synthetic) - latent_dim=latent_dim, - nepochs=nepochs, - bsize=bsize, - save_model=True, - use_hf_dataset=True, # Still using HF for NIfTI - hf_nii_files=nii_files, # Pass the found NIfTI files - max_samples=num_samples - ) - else: - # Standard call for local files - results = run_analysis( - data_dir=data_dir, - demographic_file=demographic_file, - treatment_file=treatment_file, - latent_dim=latent_dim, - nepochs=nepochs, - bsize=bsize, - save_model=True, - max_samples=num_samples - ) - - # Get the VAE figure from results - vae_fig = results.get('figures', {}).get('vae') - - figures['vae'] = vae_fig - - if results: - self.vae = results.get('vae') - self.predictor = results.get('predictor') - latents = results.get('latents') - demographics = results.get('demographics') - predictor_cv_results = results.get('predictor_cv_results') - - # Store the latent dimension - self.latent_dim = latent_dim - - # Mark models as trained - self.trained = True - - # Prepare prediction visualization if available - if self.predictor and predictor_cv_results: - try: - # Get the outcome variable data - outcomes = None - if demographics: - if outcome_variable == 'wab_aq' and 'wab_aq' in demographics: - outcomes = demographics['wab_aq'] - elif outcome_variable == 'age' and 'age' in demographics: - outcomes = demographics['age'] - elif (outcome_variable == 'mpo' or outcome_variable == 'months_post_onset') and 'mpo' in demographics: - outcomes = demographics['mpo'] - else: - # Try to find the outcome in demographics data - for key in demographics: - if outcome_variable.lower() in key.lower(): - outcomes = demographics[key] - logger.info(f"Found matching outcome variable: {key}") - break - - if outcomes is None: - logger.warning(f"Could not find outcome variable '{outcome_variable}' in demographics") - # Create a dummy array to prevent errors - if 'predictions' in predictor_cv_results: - outcomes = np.zeros_like(predictor_cv_results['predictions']) - else: - logger.warning("Cannot create prediction plots without outcome data") - except Exception as e: - logger.error(f"Error getting outcome variable: {e}") - outcomes = None - - # Create plots if we have the necessary data - if outcomes is not None and 'prediction_stds' in predictor_cv_results and 'predictions' 
in predictor_cv_results: - # Create prediction plots - prediction_fig = self.create_prediction_plots( - latents, - demographics, - outcomes, - predictor_cv_results['predictions'], - predictor_cv_results['prediction_stds'] - ) - figures['prediction'] = prediction_fig - - # Create feature importance plot if available - try: - feature_importance = self.predictor.get_feature_importance() - if feature_importance is not None: - importance_fig = self.create_importance_plot(feature_importance) - figures['importance'] = importance_fig - except Exception as e: - logger.warning(f"Could not create feature importance plot: {e}") - - logger.info("Training completed successfully") - - # Create learning curve plots if available - if 'fold_metrics' in predictor_cv_results: - learning_fig = self.create_learning_curve_plot( - predictor_cv_results['fold_metrics'] - ) - figures['learning'] = learning_fig - - except Exception as e: - logger.error(f"Error in training: {str(e)}", exc_info=True) - error_fig = plt.figure(figsize=(10, 6)) - - # Provide more helpful error message for common issues - error_message = str(e) - if "demographic" in error_message.lower() and "not found" in error_message.lower(): - error_message = f"Demographic file not found. Please ensure FC_graph_covariate_data.csv exists in your data directory or application directory." - elif "treatment" in error_message.lower() and "not found" in error_message.lower(): - error_message = f"Treatment outcomes file not found. Please ensure treatment_outcomes.csv exists in your data directory or application directory." - elif "cuda" in error_message.lower() or "gpu" in error_message.lower(): - error_message = f"GPU/CUDA error detected. Try running with CPU only." - - plt.text(0.5, 0.5, f"Error: {error_message}", - horizontalalignment='center', verticalalignment='center', - fontsize=12, color='red', wrap=True) - plt.axis('off') - figures['error'] = error_fig - - return figures - - def predict_treatment(self, fmri_file=None, age=50, sex="M", - months_post_stroke=12, wab_score=50, fc_matrix=None): - """ - Predict treatment outcome for a patient - - Args: - fmri_file: Path to patient's fMRI file - age: Patient's age at stroke - sex: Patient's sex (M/F) - months_post_stroke: Months since stroke - wab_score: Current WAB score - fc_matrix: Pre-processed FC matrix (if fMRI file not provided) - - Returns: - Prediction results and visualization - """ - if not self.trained: - return "Please train the models first!", None - - try: - # Process fMRI to FC matrix if provided - if fmri_file and not fc_matrix: - logger.info(f"Processing fMRI file: {fmri_file}") - # Use the single fMRI processing function - fc_matrix = process_single_fmri(fmri_file) - - if fc_matrix is None: - return "Please provide either an fMRI file or an FC matrix", None - - # Ensure FC matrix is properly shaped - if isinstance(fc_matrix, list): - fc_matrix = np.array(fc_matrix) - - # Get latent representation - logger.info("Extracting latent representation from FC matrix") - if len(fc_matrix.shape) == 2: # If matrix is 2D (e.g., 264x264) - # Convert to flattened upper triangular form - n = fc_matrix.shape[0] - indices = np.triu_indices(n, k=1) - fc_flattened = fc_matrix[indices] - fc_flattened = fc_flattened.reshape(1, -1) - latent = self.vae.get_latents(fc_flattened) - else: - # Assume already flattened - latent = self.vae.get_latents(fc_matrix.reshape(1, -1)) - - # Prepare demographics - demographics = { - 'age': np.array([float(age)]), - 'gender': np.array([sex]), - 'mpo': 
np.array([float(months_post_stroke)]), - 'wab_aq': np.array([float(wab_score)]), - 'education': np.array([16.0]) # Default value for education - } - - logger.info("Making prediction") - # Make prediction - if self.predictor is None: - return "Predictor model not trained", None - - # Make prediction using the model's predict method - prediction, prediction_std = self.predictor.predict(latent, demographics) - - # Create visualization - fig = self.plot_treatment_trajectory( - current_score=wab_score, - predicted_score=prediction[0], - months_post_stroke=months_post_stroke, - prediction_std=prediction_std[0] - ) - - result_text = f"Predicted treatment outcome: {prediction[0]:.2f} ± {2*prediction_std[0]:.2f}" - logger.info(result_text) - - return result_text, fig - - except Exception as e: - error_msg = f"Error in prediction: {str(e)}" - logger.error(error_msg, exc_info=True) - error_fig = plt.figure(figsize=(10, 6)) - - # Provide more helpful error message for common issues - if "fmri_file" in str(e).lower() or "file not found" in str(e).lower(): - error_msg = "Error: fMRI file not found or invalid. Please provide a valid NIfTI file." - elif "fc_matrix" in str(e).lower(): - error_msg = "Error: Invalid FC matrix format. Please ensure the matrix is properly formatted." - elif "predictor" in str(e).lower() and "none" in str(e).lower(): - error_msg = "Error: Prediction model not trained. Please train the model first." - elif "cuda" in str(e).lower() or "gpu" in str(e).lower(): - error_msg = "Error: GPU/CUDA error. Try running with CPU only." - else: - error_msg = f"Error in prediction: {str(e)}" - - plt.text(0.5, 0.5, error_msg, - horizontalalignment='center', verticalalignment='center', - fontsize=12, color='red', wrap=True) - plt.axis('off') - return error_msg, error_fig - - def plot_treatment_trajectory(self, current_score, predicted_score, - months_post_stroke, prediction_std, - treatment_duration=6): - """ - Create a visualization of predicted treatment trajectory - - Args: - current_score: Current WAB score - predicted_score: Predicted WAB score after treatment - months_post_stroke: Current months post stroke - prediction_std: Standard deviation of prediction - treatment_duration: Duration of treatment in months - - Returns: - matplotlib figure - """ - fig = plt.figure(figsize=(10, 6)) - - # X-axis: months - x = np.array([months_post_stroke, months_post_stroke + treatment_duration]) - - # Y-axis: WAB scores - y = np.array([current_score, predicted_score]) - - # Plot the trajectory - plt.plot(x, y, 'bo-', linewidth=2, label='Predicted Trajectory') - - # Add confidence interval - plt.fill_between( - x, - [y[0], y[1] - 2*prediction_std], - [y[0], y[1] + 2*prediction_std], - alpha=0.2, color='blue', label='95% Confidence Interval' - ) - - # Add reference lines - if current_score < predicted_score: - improvement = predicted_score - current_score - plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5, - label=f'Current WAB = {current_score:.1f}') - plt.axhline(y=predicted_score, color='g', linestyle='--', alpha=0.5, - label=f'Predicted WAB = {predicted_score:.1f} (+{improvement:.1f})') - else: - decline = current_score - predicted_score - plt.axhline(y=current_score, color='r', linestyle='--', alpha=0.5, - label=f'Current WAB = {current_score:.1f}') - plt.axhline(y=predicted_score, color='orange', linestyle='--', alpha=0.5, - label=f'Predicted WAB = {predicted_score:.1f} (-{decline:.1f})') - - # Add labels and title - plt.xlabel('Months Post Stroke') - plt.ylabel('WAB Score') - 
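# --- Aside: a minimal sketch of the confidence band drawn by
# plot_treatment_trajectory above, assuming the predictor returns a
# per-subject standard deviation and that prediction errors are roughly
# Gaussian, so mean +/- 2*sigma approximates a 95% interval. The scores,
# months, and sigma below are illustrative values, not from the data.
import numpy as np
import matplotlib.pyplot as plt

months = np.array([12.0, 18.0])          # now vs. end of a 6-month treatment
scores = np.array([50.0, 58.3])          # current and predicted WAB score
sigma = 2.1                              # predictor's uncertainty estimate

fig, ax = plt.subplots(figsize=(8, 5))
ax.plot(months, scores, 'bo-', label='Predicted trajectory')
# The band is zero-width at the observed score and +/- 2*sigma at the endpoint,
# matching the fill_between call in the method above.
ax.fill_between(months,
                [scores[0], scores[1] - 2 * sigma],
                [scores[0], scores[1] + 2 * sigma],
                alpha=0.2, label='~95% interval')
ax.set_xlabel('Months post stroke')
ax.set_ylabel('WAB score')
ax.legend()
# --- end aside ---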
plt.title('Predicted Treatment Trajectory') - plt.legend(loc='best') - - # Set y-axis limits - plt.ylim([0, 100]) - - plt.tight_layout() - return fig - - def create_prediction_plots(self, latents, demographics, y_true, y_pred, y_std): - """Create prediction performance plots""" - fig = plt.figure(figsize=(12, 8)) - - # Create a 2x2 grid for plots - gs = plt.GridSpec(2, 2, figure=fig) - - # Plot predicted vs actual values - ax1 = fig.add_subplot(gs[0, 0]) - - # Regression plots - # Scatter plot - ax1.scatter(y_true, y_pred, alpha=0.7) - - # Add perfect prediction line - min_val = min(np.min(y_true), np.min(y_pred)) - max_val = max(np.max(y_true), np.max(y_pred)) - ax1.plot([min_val, max_val], [min_val, max_val], 'r--') - - ax1.set_xlabel('Actual Values') - ax1.set_ylabel('Predicted Values') - ax1.set_title('Predicted vs. Actual Values') - - # Add R² to the plot - r2 = r2_score(y_true, y_pred) - ax1.text(0.05, 0.95, f'R² = {r2:.4f}', transform=ax1.transAxes, - bbox=dict(facecolor='white', alpha=0.5)) - - # Plot residuals - ax2 = fig.add_subplot(gs[0, 1]) - residuals = y_true - y_pred - ax2.scatter(y_pred, residuals, alpha=0.7) - ax2.axhline(y=0, color='r', linestyle='--') - ax2.set_xlabel('Predicted Values') - ax2.set_ylabel('Residuals') - ax2.set_title('Residual Plot') - - # Plot prediction errors - ax3 = fig.add_subplot(gs[1, 0]) - ax3.errorbar(range(len(y_pred)), y_pred, yerr=2*y_std, fmt='o', alpha=0.7, - label='Predicted ± 2σ') - ax3.plot(range(len(y_true)), y_true, 'rx', alpha=0.7, label='Actual') - ax3.set_xlabel('Sample Index') - ax3.set_ylabel('Value') - ax3.set_title('Prediction with Error Bars') - ax3.legend() - - # Plot error distribution - ax4 = fig.add_subplot(gs[1, 1]) - ax4.hist(residuals, bins=20, alpha=0.7) - ax4.axvline(x=0, color='r', linestyle='--') - ax4.set_xlabel('Prediction Error') - ax4.set_ylabel('Frequency') - ax4.set_title('Error Distribution') - - plt.tight_layout() - return fig - - def create_importance_plot(self, feature_importance, top_n=15): - """Create feature importance plot""" - # If feature_importance is a DataFrame, use it directly - if isinstance(feature_importance, pd.DataFrame): - importance_df = feature_importance - else: - # Create DataFrame - importance_df = pd.DataFrame({ - 'feature': [f'Feature {i}' for i in range(len(feature_importance))], - 'importance': feature_importance - }) - - # Get top N features - top_features = importance_df.sort_values('importance', ascending=False).head(top_n) - - # Create plot - fig = plt.figure(figsize=(10, 6)) - plt.barh(range(len(top_features)), top_features['importance'], align='center') - plt.yticks(range(len(top_features)), top_features['feature']) - plt.xlabel('Importance') - plt.ylabel('Features') - plt.title(f'Top {top_n} Features by Importance') - plt.tight_layout() - - return fig - - def create_learning_curve_plot(self, fold_metrics): - """Create learning curve plots from cross-validation results""" - fig = plt.figure(figsize=(12, 6)) - - # For regression, show R² and RMSE - ax1 = plt.subplot(1, 2, 1) - ax2 = plt.subplot(1, 2, 2) - - # Plot R² for each fold - for i, metrics in enumerate(fold_metrics): - ax1.plot(i+1, metrics['r2'], 'bo') - - # Plot average R² - avg_r2 = np.mean([m['r2'] for m in fold_metrics]) - ax1.axhline(y=avg_r2, color='r', linestyle='--', - label=f'Average R² = {avg_r2:.4f}') - - ax1.set_xlabel('Fold') - ax1.set_ylabel('R²') - ax1.set_title('R² by Fold') - ax1.set_xticks(range(1, len(fold_metrics)+1)) - ax1.legend() - - # Plot RMSE for each fold - for i, metrics in 
enumerate(fold_metrics): - ax2.plot(i+1, metrics['rmse'], 'go') - - # Plot average RMSE - avg_rmse = np.mean([m['rmse'] for m in fold_metrics]) - ax2.axhline(y=avg_rmse, color='r', linestyle='--', - label=f'Average RMSE = {avg_rmse:.4f}') - - ax2.set_xlabel('Fold') - ax2.set_ylabel('RMSE') - ax2.set_title('RMSE by Fold') - ax2.set_xticks(range(1, len(fold_metrics)+1)) - ax2.legend() - - plt.tight_layout() - return fig +# Make sure directories exist +os.makedirs('models', exist_ok=True) +os.makedirs('results', exist_ok=True) -def calculate_fc_accuracy(original_fc, reconstructed_fc): - """ - Calculate accuracy metrics between original and reconstructed FC matrices - """ - # Mean Squared Error (lower is better) - mse = mean_squared_error(original_fc, reconstructed_fc) - - # Root Mean Squared Error (lower is better) - rmse = np.sqrt(mse) - - # R² Score (higher is better, 1 is perfect) - r2 = r2_score(original_fc, reconstructed_fc) - - # Correlation between matrices (higher is better) - corr = np.corrcoef(original_fc.flatten(), reconstructed_fc.flatten())[0, 1] - - # Custom similarity score based on normalized dot product (higher is better) - norm_dot = np.dot(original_fc.flatten(), reconstructed_fc.flatten()) / ( - np.linalg.norm(original_fc.flatten()) * np.linalg.norm(reconstructed_fc.flatten())) - - return { - "MSE": float(mse), - "RMSE": float(rmse), - "R²": float(r2), - "Correlation": float(corr), - "Cosine Similarity": float(norm_dot) - } +# Global app state +app_state = { + 'vae': None, + 'latents': None, + 'demographics': None, + 'fc_data': None, + 'vae_trained': False +} -def save_latents(latents, demographics, subjects=None, file_path='latents.pkl'): - """ - Save latent representations and associated demographics to file - """ - os.makedirs('results', exist_ok=True) - - # Create a dictionary with latents and demographics - data = { - 'latents': latents, - 'demographics': demographics - } - - if subjects is not None: - data['subjects'] = subjects - - # Save as pickle for easy loading in Python - with open(os.path.join('results', file_path), 'wb') as f: - pickle.dump(data, f) - - # Also save as JSON for more universal access - json_data = { - 'latents': latents.tolist() if isinstance(latents, np.ndarray) else latents, - 'demographics': {k: v.tolist() if isinstance(v, np.ndarray) else v - for k, v in demographics.items()} - } - - if subjects is not None: - json_data['subjects'] = subjects - - with open(os.path.join('results', file_path.replace('.pkl', '.json')), 'w') as f: - json.dump(json_data, f) - - return os.path.join('results', file_path) +# Function to convert vector to matrix for visualization +def vector_to_matrix(vector, size=10): + """Convert a vector to a square matrix for visualization""" + matrix = np.zeros((size, size)) + idx = 0 + # Fill upper triangle and mirror + for i in range(size): + for j in range(i+1, size): + matrix[i, j] = matrix[j, i] = vector[idx % len(vector)] + idx += 1 + # Set diagonal to 1.0 + np.fill_diagonal(matrix, 1.0) + return matrix -# Function to process behavioral data into treatment outcomes -def process_behavioral_data_to_outcomes(behavioral_file): - """ - Process behavioral_data.csv to create a treatment outcomes file - - The behavioral data contains: - - patient_id: Patient identifier - - Session: Session number - - Session Type: Baseline (B), Treatment, Post Treatment - - sess_acc: Session accuracy - - We'll convert this to treatment outcomes by: - 1. Finding baseline and post-treatment sessions for each patient - 2. 
Calculating improvement (post - baseline) - 3. Creating a treatment_outcomes.csv file with subject_id, treatment_type, outcome_score - - Args: - behavioral_file: Path to behavioral_data.csv - - Returns: - Path to generated treatment_outcomes.csv file - """ - # Create a simple mock outcomes file as a fallback - os.makedirs('results', exist_ok=True) - fallback_file = os.path.join('results', 'fallback_treatment_outcomes.csv') - - # Create a simple outcomes file with dummy data (useful as last resort) - try: - mock_outcomes = pd.DataFrame([ - {'subject_id': 'P001', 'treatment_type': 'Standard', 'outcome_score': 5.2}, - {'subject_id': 'P002', 'treatment_type': 'Intensive', 'outcome_score': 7.8}, - {'subject_id': 'P003', 'treatment_type': 'Standard', 'outcome_score': 3.1}, - {'subject_id': 'P004', 'treatment_type': 'Intensive', 'outcome_score': 9.4}, - {'subject_id': 'P005', 'treatment_type': 'Control', 'outcome_score': 1.2} - ]) - mock_outcomes.to_csv(fallback_file, index=False) - logger.info(f"Created fallback treatment outcomes file with 5 subjects") - except Exception as e: - logger.error(f"Failed to create fallback file: {e}") - - logger.info(f"Processing behavioral data from {behavioral_file}") - - # Create output file path - os.makedirs('results', exist_ok=True) - outcomes_file = os.path.join('results', 'behavioral_treatment_outcomes.csv') - +def train_vae(fc_file, demo_file, epochs=20, latent_dim=16, batch_size=8, progress=gr.Progress()): + """Train a VAE model on uploaded data""" try: - # Read the behavioral data with error handling for different formats - import pandas as pd - import numpy as np + # Reset state + app_state['vae_trained'] = False + app_state['vae'] = None + app_state['latents'] = None - try: - # Try regular CSV format first - behavioral_df = pd.read_csv(behavioral_file) - logger.info(f"Loaded behavioral data with {len(behavioral_df)} rows") - except UnicodeDecodeError as e: - logger.warning(f"Unicode decode error with behavioral file: {e}") + # Ensure uploaded files exist + if not fc_file or not os.path.exists(fc_file.name): + return "Error: Missing FC matrix file", None, None - # Try with different encodings - for encoding in ['latin1', 'cp1252', 'iso-8859-1']: - try: - logger.info(f"Trying {encoding} encoding...") - behavioral_df = pd.read_csv(behavioral_file, encoding=encoding) - logger.info(f"Successfully loaded with {encoding} encoding") - break - except Exception as enc_error: - logger.warning(f"Failed with {encoding} encoding: {enc_error}") - else: - # If all encodings fail, try Excel format - try: - logger.info("Trying to read as Excel file...") - behavioral_df = pd.read_excel(behavioral_file) - logger.info(f"Successfully loaded as Excel file with {len(behavioral_df)} rows") - except Exception as xl_error: - logger.error(f"Failed to read as Excel: {xl_error}") - raise ValueError(f"Could not read behavioral data file in any format") - - # Print column names for debugging - logger.info(f"Behavioral data columns: {behavioral_df.columns.tolist()}") - - # Try alternative column names for required fields - column_mapping = { - 'ID': ['ID', 'patient_id', 'subject_id', 'Subject', 'PatientID', 'id', 'patient', 'subj', 'sub'], - 'Session': ['Session', 'session', 'Session_Number', 'SessionNum', 'sess_num', 'session_num', 'time', 'timepoint'], - 'Session Type': ['Session Type', 'SessionType', 'Type', 'session_type', 'sess_type', 'phase', 'treatment_phase', 'study_phase', 'condition'], - 'sess_acc': ['sess_acc', 'Accuracy', 'accuracy', 'acc', 'session_accuracy', 'score', 
'performance', 'wab', 'wab_score', 'value'] - } - - # Attempt to map columns - mapped_columns = {} - for target_col, alt_cols in column_mapping.items(): - if target_col in behavioral_df.columns: - mapped_columns[target_col] = target_col + # Load FC data + try: + progress(0.1, "Loading FC data...") + if fc_file.name.endswith('.npy'): + X = np.load(fc_file.name) + elif fc_file.name.endswith('.csv'): + X = pd.read_csv(fc_file.name).values else: - for alt_col in alt_cols: - if alt_col in behavioral_df.columns: - mapped_columns[target_col] = alt_col - logger.info(f"Mapped column {alt_col} to {target_col}") - break - - # Check what columns we found - logger.info(f"Mapped columns: {mapped_columns}") + # Try to interpret as text + X = np.loadtxt(fc_file.name) + + logger.info(f"Loaded FC data with shape: {X.shape}") + app_state['fc_data'] = X + except Exception as e: + logger.error(f"Error loading FC data: {e}") + return f"Error loading FC data: {str(e)}", None, None - # Determine how to proceed based on what we found - if 'ID' not in mapped_columns: - # Try to create patient IDs if not found - if 'ID' not in behavioral_df.columns: - logger.warning("No patient ID column found, creating synthetic IDs") - # Look for any identifier-like columns - for col in behavioral_df.columns: - if any(id_term in col.lower() for id_term in ['id', 'subject', 'patient', 'participant']): - behavioral_df['ID'] = behavioral_df[col] - mapped_columns['ID'] = col - logger.info(f"Using {col} as patient ID") - break + # Load demographic data if provided + try: + progress(0.2, "Loading demographic data...") + if demo_file and os.path.exists(demo_file.name): + demo_df = pd.read_csv(demo_file.name) + logger.info(f"Loaded demographics with shape: {demo_df.shape}") + + # Try to extract standard demographics + demographics = [] + + # Age + if 'age' in demo_df.columns: + age = demo_df['age'].values + elif 'age_at_stroke' in demo_df.columns: + age = demo_df['age_at_stroke'].values else: - # Create sequential IDs if no identifier found - behavioral_df['ID'] = [f"P{i+1:03d}" for i in range(len(behavioral_df))] - mapped_columns['ID'] = 'ID' - logger.warning("Created sequential patient IDs") - - # Handle session identification - if 'Session' not in mapped_columns: - # Try to create session numbers if not found - if 'Session' not in behavioral_df.columns: - logger.warning("No session number column found, creating sequential session numbers") - # Check if we have any time-related columns - time_columns = [col for col in behavioral_df.columns if any(time_term in col.lower() for time_term in ['time', 'session', 'visit', 'week'])] - if time_columns: - behavioral_df['Session'] = behavioral_df[time_columns[0]] - mapped_columns['Session'] = time_columns[0] - logger.info(f"Using {time_columns[0]} as session number") + age = np.random.normal(60, 10, len(X)) + logger.warning("Age column not found, using synthetic data") + demographics.append(age) + + # Sex + if 'sex' in demo_df.columns: + sex = demo_df['sex'].values + elif 'gender' in demo_df.columns: + sex = demo_df['gender'].values else: - # Create sequential session numbers for each patient - if 'ID' in mapped_columns: - behavioral_df['Session'] = behavioral_df.groupby(mapped_columns['ID']).cumcount() + 1 - else: - behavioral_df['Session'] = range(1, len(behavioral_df) + 1) - mapped_columns['Session'] = 'Session' - logger.warning("Created sequential session numbers") - - # Handle session type - if 'Session Type' not in mapped_columns: - # Try to create session types if not found - if 'Session 
Type' not in behavioral_df.columns: - logger.warning("No session type column found, inferring from session sequence") - # Create simple session type based on sequence: first=Baseline, last=Post, middle=Treatment - behavioral_df['Session Type'] = 'Treatment' - - # Group by patient ID if available - if 'ID' in mapped_columns: - # Get min and max session for each patient - session_col = mapped_columns.get('Session', 'Session') - id_col = mapped_columns.get('ID', 'ID') - - # Get first and last session for each patient - for patient in behavioral_df[id_col].unique(): - patient_sessions = behavioral_df[behavioral_df[id_col] == patient][session_col].sort_values() - if len(patient_sessions) > 0: - first_session = patient_sessions.iloc[0] - last_session = patient_sessions.iloc[-1] - - # Mark first as Baseline, last as Post - behavioral_df.loc[(behavioral_df[id_col] == patient) & - (behavioral_df[session_col] == first_session), - 'Session Type'] = 'Baseline' - - behavioral_df.loc[(behavioral_df[id_col] == patient) & - (behavioral_df[session_col] == last_session), - 'Session Type'] = 'Post Treatment' + sex = np.random.choice(['M', 'F'], len(X)) + logger.warning("Sex column not found, using synthetic data") + demographics.append(sex) + + # Months post stroke + if 'months_post_stroke' in demo_df.columns: + mps = demo_df['months_post_stroke'].values + elif 'mpo' in demo_df.columns: + mps = demo_df['mpo'].values + else: + mps = np.random.normal(24, 12, len(X)) + logger.warning("Months post stroke column not found, using synthetic data") + demographics.append(mps) + + # WAB score + if 'wab_score' in demo_df.columns: + wab = demo_df['wab_score'].values + elif 'wab_aq' in demo_df.columns: + wab = demo_df['wab_aq'].values else: - # Just use the first and last rows - if len(behavioral_df) > 0: - behavioral_df.loc[0, 'Session Type'] = 'Baseline' - if len(behavioral_df) > 1: - behavioral_df.loc[len(behavioral_df)-1, 'Session Type'] = 'Post Treatment' + wab = np.random.normal(65, 15, len(X)) + logger.warning("WAB score column not found, using synthetic data") + demographics.append(wab) - mapped_columns['Session Type'] = 'Session Type' - logger.warning("Created session types based on sequence") - - # Handle accuracy/score - if 'sess_acc' not in mapped_columns: - # Find any numeric columns that might contain scores - numeric_cols = behavioral_df.select_dtypes(include=['number']).columns.tolist() - score_candidates = [col for col in numeric_cols if any(score_term in col.lower() for score_term in - ['score', 'acc', 'wab', 'value', 'measure', 'perf', 'test'])] - - if score_candidates: - behavioral_df['sess_acc'] = behavioral_df[score_candidates[0]] - mapped_columns['sess_acc'] = score_candidates[0] - logger.info(f"Using {score_candidates[0]} as accuracy score") - elif numeric_cols: - # Just use the first numeric column - behavioral_df['sess_acc'] = behavioral_df[numeric_cols[0]] - mapped_columns['sess_acc'] = numeric_cols[0] - logger.warning(f"Using first numeric column {numeric_cols[0]} as accuracy score") else: - # No suitable column found - raise ValueError("No suitable accuracy/score column found in behavioral data") - - # Now work with the mapped columns - id_col = mapped_columns.get('ID', 'ID') - session_col = mapped_columns.get('Session', 'Session') - type_col = mapped_columns.get('Session Type', 'Session Type') - acc_col = mapped_columns.get('sess_acc', 'sess_acc') - - # Extract baseline and post-treatment sessions - outcome_data = [] - - # Get unique patient IDs - patient_ids = behavioral_df[id_col].unique() 
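# --- Aside: a compact sketch of the per-patient improvement computation that
# the loop below performs, assuming one behavioral table with patient ID,
# session-type, and accuracy columns (the column names here are the defaults
# used above, but are otherwise illustrative). Each patient's outcome is the
# mean post-treatment accuracy minus the mean baseline accuracy.
import pandas as pd

def improvement_per_patient(df, id_col='ID', type_col='Session Type',
                            acc_col='sess_acc'):
    is_base = df[type_col].str.contains('base|pre|^B$', case=False, regex=True)
    is_post = df[type_col].str.contains('post|final', case=False, regex=True)
    baseline = df[is_base].groupby(id_col)[acc_col].mean()
    post = df[is_post].groupby(id_col)[acc_col].mean()
    # Subtracting index-aligned Series keeps only patients that have both a
    # baseline and a post-treatment session; the rest become NaN and are dropped.
    return (post - baseline).dropna().rename('outcome_score')

# Example: a patient improving from 0.40 at baseline to 0.65 post-treatment
# gets outcome_score = 0.25.
# --- end aside ---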
- logger.info(f"Found {len(patient_ids)} unique patients") - - for patient_id in patient_ids: - patient_data = behavioral_df[behavioral_df[id_col] == patient_id] - logger.info(f"Processing patient {patient_id} with {len(patient_data)} sessions") - - # Try to identify baseline and post sessions by string matching if possible - try: - # Look for Baseline sessions (may be labeled as 'B', 'Baseline', etc.) - baseline_mask = ( - patient_data[type_col].str.contains('B', case=False) | - patient_data[type_col].str.contains('base', case=False) | - patient_data[type_col].str.contains('pre', case=False) - ) - baseline_sessions = patient_data[baseline_mask] - - # Look for Post Treatment sessions - post_mask = ( - patient_data[type_col].str.contains('Post', case=False) | - patient_data[type_col].str.contains('final', case=False) | - ((patient_data[type_col].str.contains('Treatment', case=False)) & - (~patient_data[type_col].str.contains('Pre', case=False))) - ) - post_sessions = patient_data[post_mask] - except AttributeError: - # In case the column doesn't support string operations - logger.warning(f"Column {type_col} doesn't support string operations, using first/last approach") - baseline_sessions = pd.DataFrame() - post_sessions = pd.DataFrame() - - # If we can't find labeled sessions, use first and last session - if len(baseline_sessions) == 0 or len(post_sessions) == 0: - # Sort by session number if possible - try: - patient_data = patient_data.sort_values(session_col) - except: - logger.warning(f"Could not sort by {session_col}, using data as-is") - - baseline_sessions = patient_data.iloc[[0]] # First session - post_sessions = patient_data.iloc[[-1]] # Last session - logger.info(f"Using first/last approach for patient {patient_id}") - - # If we have both baseline and post sessions, calculate improvement - if len(baseline_sessions) > 0 and len(post_sessions) > 0: - # Use the average if multiple sessions - try: - baseline_acc = baseline_sessions[acc_col].mean() - post_acc = post_sessions[acc_col].mean() - - # Calculate improvement - improvement = post_acc - baseline_acc - - # Determine treatment type - if type_col in patient_data.columns: - try: - # Get middle sessions (between baseline and post) - all_sessions = patient_data.sort_values(session_col) - first_session = all_sessions[session_col].iloc[0] - last_session = all_sessions[session_col].iloc[-1] - - middle_mask = ( - (all_sessions[session_col] > first_session) & - (all_sessions[session_col] < last_session) - ) - middle_sessions = all_sessions[middle_mask] - - if len(middle_sessions) > 0 and type_col in middle_sessions.columns: - # Use most common treatment type - treatment_type = middle_sessions[type_col].mode()[0] - else: - # Default treatment type - treatment_type = "Standard" - except: - treatment_type = "Standard" - else: - treatment_type = "Standard" - - # Append to outcomes - outcome_data.append({ - 'subject_id': patient_id, - 'treatment_type': treatment_type, - 'outcome_score': improvement - }) - logger.info(f"Patient {patient_id}: Baseline={baseline_acc:.2f}, Post={post_acc:.2f}, Improvement={improvement:.2f}") - except Exception as e: - logger.warning(f"Could not calculate improvement for patient {patient_id}: {e}") - - # Create DataFrame and save - if outcome_data: - outcomes_df = pd.DataFrame(outcome_data) - outcomes_df.to_csv(outcomes_file, index=False) - logger.info(f"Created treatment outcomes file with {len(outcomes_df)} patients") - # Store the treatment file path for later use - self.last_treatment_file = outcomes_file - 
return outcomes_file - else: - # If we couldn't extract outcomes per patient, try a simpler approach - logger.warning("Could not extract patient-level outcomes, trying simpler approach") - - try: - # Calculate overall pre/post changes - behavioral_df = behavioral_df.sort_values(session_col) - first_half = behavioral_df.iloc[:len(behavioral_df)//2] - second_half = behavioral_df.iloc[len(behavioral_df)//2:] + logger.info("No demographics file provided, using synthetic data") + demographics = [ + np.random.normal(60, 10, len(X)), # age + np.random.choice(['M', 'F'], len(X)), # sex + np.random.normal(24, 12, len(X)), # months post stroke + np.random.normal(65, 15, len(X)) # WAB score + ] - pre_score = first_half[acc_col].mean() - post_score = second_half[acc_col].mean() - improvement = post_score - pre_score - - # Create a simple outcomes file - outcomes_df = pd.DataFrame([ - { - 'subject_id': 'GROUP', - 'treatment_type': 'Standard', - 'outcome_score': improvement - } - ]) - outcomes_df.to_csv(outcomes_file, index=False) - logger.warning(f"Created simplified treatment outcomes with group improvement: {improvement:.2f}") - # Store the treatment file path for later use - self.last_treatment_file = outcomes_file - return outcomes_file - except Exception as e: - logger.error(f"Could not create even simplified outcomes: {e}") - logger.warning("Falling back to predefined treatment outcomes") - return fallback_file + app_state['demographics'] = demographics + demo_types = ['continuous', 'categorical', 'continuous', 'continuous'] - except Exception as e: - logger.error(f"Error processing behavioral data: {e}", exc_info=True) - logger.warning("Using fallback treatment outcomes file due to error") - # Return the fallback file instead of raising an error - return fallback_file - -# Function to look for treatment outcome files in the dataset -def find_treatment_outcomes_file(data_dir): - """ - Look for treatment outcomes file in the dataset - - Args: - data_dir: Dataset directory or HuggingFace dataset ID - - Returns: - Path to treatment outcomes file or None if not found - """ - logger.info(f"Looking for treatment outcomes file in {data_dir}") - - # Create a temporary directory for downloads - import tempfile - from huggingface_hub import hf_hub_download - - temp_dir = tempfile.mkdtemp(prefix="hf_treatment_") - - # Try different filenames for treatment outcomes - outcome_files = [ - "treatment_outcomes.csv", - "outcomes.csv", - "treatment_outcomes.xlsx", - "outcomes.xlsx", - "behavioral_data.csv", - "behavioral_data.xlsx", - "behavioral.csv", - "behavioral.xlsx", - "treatment_results.csv" - ] - - # Try each file - for outcome_file in outcome_files: - try: - file_path = hf_hub_download( - repo_id=data_dir, - filename=outcome_file, - repo_type="dataset", - cache_dir=temp_dir - ) - - logger.info(f"Found treatment outcomes file: {outcome_file}") - return file_path - - except Exception as e: - logger.debug(f"Could not find {outcome_file}: {e}") - - # If we get here, no files were found - logger.error("No treatment outcomes file found in the dataset") - - # Create a fallback file - fallback_file = os.path.join('results', 'fallback_treatment_outcomes.csv') - try: - # Create a simple fallback treatment outcomes file - os.makedirs('results', exist_ok=True) - mock_outcomes = pd.DataFrame([ - {'subject_id': 'P001', 'treatment_type': 'Standard', 'outcome_score': 5.2}, - {'subject_id': 'P002', 'treatment_type': 'Intensive', 'outcome_score': 7.8}, - {'subject_id': 'P003', 'treatment_type': 'Standard', 'outcome_score': 
3.1}, - {'subject_id': 'P004', 'treatment_type': 'Intensive', 'outcome_score': 9.4}, - {'subject_id': 'P005', 'treatment_type': 'Control', 'outcome_score': 1.2} - ]) - mock_outcomes.to_csv(fallback_file, index=False) - logger.warning("Created and using fallback treatment outcomes file") - return fallback_file - except Exception as e: - logger.error(f"Failed to create fallback file: {e}") - raise FileNotFoundError(f"No treatment outcomes file found in {data_dir} and could not create fallback. Please provide a treatment_outcomes.csv file with columns: subject_id, treatment_type, outcome_score.") - -# Function to search and download NIfTI files from HuggingFace datasets -def find_nifti_files_in_hf_dataset(dataset_name, dataset=None): - """ - Find and download NIfTI files from a HuggingFace dataset - - Args: - dataset_name: Name of the HuggingFace dataset - dataset: Optional pre-loaded dataset object - - Returns: - List of downloaded NIfTI file paths - """ - logger.info(f"Searching for NIfTI files in dataset: {dataset_name}") - - # If dataset is not provided, load it - if dataset is None: - from datasets import load_dataset - try: - dataset = load_dataset(dataset_name) - logger.info(f"Loaded dataset: {dataset_name}") except Exception as e: - logger.error(f"Error loading dataset: {e}") - return [] - - # Load from HuggingFace Datasets - nii_files = [] - - # Create a temp directory for downloads - import tempfile - from huggingface_hub import hf_hub_download - import shutil - import json - - temp_dir = tempfile.mkdtemp(prefix="hf_nifti_") - logger.info(f"Created temporary directory for NIfTI files: {temp_dir}") - - # Log dataset information for debugging - logger.info(f"Dataset info: type={type(dataset)}") - if dataset is not None: - if isinstance(dataset, dict): - logger.info(f"Dataset is a dictionary with keys: {list(dataset.keys())}") - if 'train' in dataset: - train_type = type(dataset['train']) - logger.info(f"Train split type: {train_type}") - if hasattr(dataset['train'], 'shape'): - logger.info(f"Train split shape: {dataset['train'].shape}") - elif hasattr(dataset['train'], '__len__'): - logger.info(f"Train split length: {len(dataset['train'])}") - - # Log first few rows for pandas DataFrames - if isinstance(dataset['train'], pd.DataFrame): - try: - logger.info(f"DataFrame columns: {dataset['train'].columns.tolist()}") - logger.info(f"DataFrame preview: \n{dataset['train'].head(2).to_string()}") - except Exception as e: - logger.error(f"Error logging DataFrame info: {e}") - elif isinstance(dataset, pd.DataFrame): - logger.info(f"Dataset is a pandas DataFrame with shape: {dataset.shape}") - try: - logger.info(f"DataFrame columns: {dataset.columns.tolist()}") - logger.info(f"DataFrame preview: \n{dataset.head(2).to_string()}") - except Exception as e: - logger.error(f"Error logging DataFrame info: {e}") - - try: - # First approach: Check if there are any columns containing file paths - nii_columns = [] - - # Handle both HuggingFace dataset and pandas DataFrame - if isinstance(dataset, dict) and 'train' in dataset: - # It's a HuggingFace dataset object - try: - if hasattr(dataset['train'], 'column_names'): - # Standard HuggingFace dataset - columns = dataset['train'].column_names - else: - # It might be a pandas DataFrame - columns = dataset['train'].columns.tolist() - - for col in columns: - # Check if column name suggests NIfTI files - if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower(): - nii_columns.append(col) - # Or check if column contains file paths - elif 
len(dataset['train']) > 0: - # Try to get first value, handling both Dataset and DataFrame - try: - if hasattr(dataset['train'], '__getitem__'): - first_val = dataset['train'][0][col] - else: - first_val = dataset['train'][col].iloc[0] - - if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')): - nii_columns.append(col) - except Exception as e: - logger.debug(f"Error checking first value of column {col}: {e}") - except Exception as e: - logger.error(f"Error inspecting dataset columns: {e}") - elif isinstance(dataset, pd.DataFrame): - # It's just a pandas DataFrame directly - try: - columns = dataset.columns.tolist() - - for col in columns: - # Check if column name suggests NIfTI files - if 'nii' in col.lower() or 'nifti' in col.lower() or 'fmri' in col.lower(): - nii_columns.append(col) - # Or check if column contains file paths - elif len(dataset) > 0: - try: - first_val = dataset[col].iloc[0] - if isinstance(first_val, str) and (first_val.endswith('.nii') or first_val.endswith('.nii.gz')): - nii_columns.append(col) - except Exception as e: - logger.debug(f"Error checking first value of column {col}: {e}") - except Exception as e: - logger.error(f"Error inspecting DataFrame columns: {e}") - else: - logger.error(f"Unexpected dataset type: {type(dataset)}") + logger.error(f"Error processing demographics: {e}") + return f"Error processing demographics: {str(e)}", None, None + + # Initialize model + progress(0.3, "Initializing model...") + model = DemoVAE(nepochs=epochs, batch_size=batch_size, latent_dim=latent_dim) + + # Train model + progress(0.4, "Training VAE model...") + train_losses, val_losses = model.fit(X, demographics, demo_types) + + # Save model + progress(0.7, "Saving model...") + model.save('models/vae_model.pt') + app_state['vae'] = model + app_state['vae_trained'] = True + + # Generate latent representations + progress(0.8, "Generating latent representations...") + latents = model.get_latents(X) + app_state['latents'] = latents + np.save('results/latents.npy', latents) + + # Create visualizations + progress(0.9, "Creating visualizations...") + + # Learning curves + learning_fig = plot_learning_curves(model.train_losses, model.val_losses) + learning_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + learning_fig.savefig(learning_img.name) + plt.close(learning_fig) + + # FC visualization + progress(0.95, "Creating FC visualizations...") + reconstructed = model.transform(X, demographics, demo_types) + np.save('results/reconstructed.npy', reconstructed) + + generated = model.transform(1, [d[0] for d in demographics], demo_types) + np.save('results/generated.npy', generated) + + fc_fig, axes = plt.subplots(1, 3, figsize=(15, 5)) + original_matrix = vector_to_matrix(X[0]) + recon_matrix = vector_to_matrix(reconstructed[0]) + gen_matrix = vector_to_matrix(generated[0]) + + # Plot matrices + titles = ['Original', 'Reconstructed', 'Generated'] + for i, matrix in enumerate([original_matrix, recon_matrix, gen_matrix]): + im = axes[i].imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1) + axes[i].set_title(titles[i]) + axes[i].axis('off') + + fc_fig.subplots_adjust(right=0.8) + cbar_ax = fc_fig.add_axes([0.85, 0.15, 0.05, 0.7]) + fc_fig.colorbar(im, cax=cbar_ax) + + fc_img = tempfile.NamedTemporaryFile(suffix='.png', delete=False) + fc_fig.savefig(fc_img.name) + plt.close(fc_fig) + + progress(1.0, "Training complete!") + return "Training completed successfully!", learning_img.name, fc_img.name - if nii_columns: - logger.info(f"Found columns 
that may contain NIfTI files: {nii_columns}") - - for col in nii_columns: - logger.info(f"Processing column '{col}'...") - - # Handle different dataset types - try: - # Get the column data - if isinstance(dataset, dict) and 'train' in dataset: - if hasattr(dataset['train'], 'column_names'): - # It's a standard HuggingFace dataset - col_data = dataset['train'][col] - else: - # It's a DataFrame wrapped in dict - col_data = dataset['train'][col].values - elif isinstance(dataset, pd.DataFrame): - # It's a DataFrame directly - col_data = dataset[col].values - else: - logger.error(f"Unexpected dataset type: {type(dataset)}") - continue - - # Process the column data - for i, item in enumerate(col_data): - if not isinstance(item, str): - logger.info(f"Item {i} in column {col} is not a string but {type(item)}") - continue - - if not (item.endswith('.nii') or item.endswith('.nii.gz')): - logger.info(f"Item {i} in column {col} is not a NIfTI file: {item}") - continue - - logger.info(f"Downloading {item} from dataset {dataset_name}...") - except Exception as e: - logger.error(f"Error processing column {col}: {e}") - - try: - # Attempt to download with explicit filename - file_path = hf_hub_download( - repo_id=dataset_name, - filename=item, - repo_type="dataset", - cache_dir=temp_dir - ) - nii_files.append(file_path) - logger.info(f"✓ Successfully downloaded {item}") - except Exception as e1: - logger.error(f"Error downloading with explicit filename: {e1}") - - # Second attempt: try with the item's basename - try: - basename = os.path.basename(item) - logger.info(f"Trying with basename: {basename}") - file_path = hf_hub_download( - repo_id=dataset_name, - filename=basename, - repo_type="dataset", - cache_dir=temp_dir - ) - nii_files.append(file_path) - logger.info(f"✓ Successfully downloaded {basename}") - except Exception as e2: - logger.error(f"Error downloading with basename: {e2}") - - # Third attempt: check if it's a binary blob in the dataset - try: - # Handle different dataset types for binary data - binary_data = None - - if isinstance(dataset, dict) and 'train' in dataset: - if hasattr(dataset['train'], '__getitem__') and hasattr(dataset['train'][i], 'keys') and 'bytes' in dataset['train'][i]: - # Standard HuggingFace dataset with binary data - binary_data = dataset['train'][i]['bytes'] - elif hasattr(dataset['train'], 'iloc') and 'bytes' in dataset['train'].columns: - # DataFrame with bytes column - binary_data = dataset['train'].iloc[i]['bytes'] - elif isinstance(dataset, pd.DataFrame) and 'bytes' in dataset.columns: - # Direct DataFrame with bytes column - binary_data = dataset.iloc[i]['bytes'] - - if binary_data is not None: - logger.info("Found binary data in dataset, saving to temporary file...") - temp_file = os.path.join(temp_dir, basename) - with open(temp_file, 'wb') as f: - f.write(binary_data) - nii_files.append(temp_file) - logger.info(f"✓ Saved binary data to {temp_file}") - except Exception as e3: - logger.error(f"Error handling binary data: {e3}") - - # Last resort: look for the file locally - local_path = os.path.join(os.getcwd(), item) - if os.path.exists(local_path): - nii_files.append(local_path) - logger.info(f"✓ Found {item} locally") - else: - logger.warning(f"❌ Warning: Could not find {item} anywhere") - - # Second approach: Try to find NIfTI files in dataset repository directly - if not nii_files: - logger.info("No NIfTI files found in dataset columns. 
Trying direct repository search...") - - try: - from huggingface_hub import list_repo_files, hf_hub_download - - # Try to list all files in the repository - try: - logger.info("Listing all repository files...") - all_repo_files = list_repo_files(dataset_name, repo_type="dataset") - logger.info(f"Found {len(all_repo_files)} files in repository") - - # First prioritize P*_rs.nii files - p_rs_files = [f for f in all_repo_files if f.endswith('_rs.nii') and f.startswith('P')] - - # Then include all other NIfTI files - other_nii_files = [f for f in all_repo_files if (f.endswith('.nii') or f.endswith('.nii.gz')) and f not in p_rs_files] - - # Combine, with P*_rs.nii files first - nii_repo_files = p_rs_files + other_nii_files - - if nii_repo_files: - preview = nii_repo_files[:5] if len(nii_repo_files) > 5 else nii_repo_files - logger.info(f"Found {len(nii_repo_files)} NIfTI files in repository: {preview}...") - - # Download each file - for nii_file in nii_repo_files: - try: - file_path = hf_hub_download( - repo_id=dataset_name, - filename=nii_file, - repo_type="dataset", - cache_dir=temp_dir - ) - nii_files.append(file_path) - logger.info(f"✓ Downloaded {nii_file}") - except Exception as e: - logger.error(f"Error downloading {nii_file}: {e}") - except Exception as e: - logger.error(f"Error listing repository files: {e}") - logger.info("Will try alternative approaches...") - - # If repo listing fails, try with common NIfTI file patterns directly - if not nii_files: - logger.info("Trying common NIfTI file patterns...") - - # Focus specifically on P*_rs.nii pattern - patterns = [] - - # Generate P01_rs.nii through P30_rs.nii - for i in range(1, 31): # Try subjects 1-30 - patterns.append(f"P{i:02d}_rs.nii") - - # Also try with .nii.gz extension - for i in range(1, 31): - patterns.append(f"P{i:02d}_rs.nii.gz") - - # Include a few other common patterns as fallbacks - patterns.extend([ - "sub-01_task-rest_bold.nii.gz", # BIDS format - "fmri.nii.gz", "bold.nii.gz", - "rest.nii.gz" - ]) - - for pattern in patterns: - try: - logger.info(f"Trying to download {pattern}...") - file_path = hf_hub_download( - repo_id=dataset_name, - filename=pattern, - repo_type="dataset", - cache_dir=temp_dir - ) - nii_files.append(file_path) - logger.info(f"✓ Successfully downloaded {pattern}") - except Exception as e: - logger.debug(f"Failed to download {pattern}") - - # If we still couldn't find any files, check if data files are nested - if not nii_files: - logger.info("Checking for nested data files...") - nested_paths = ["data/", "raw/", "nii/", "derivatives/", "fmri/", "nifti/"] - - for path in nested_paths: - for pattern in patterns: - nested_file = f"{path}{pattern}" - try: - logger.info(f"Trying to download {nested_file}...") - file_path = hf_hub_download( - repo_id=dataset_name, - filename=nested_file, - repo_type="dataset", - cache_dir=temp_dir - ) - nii_files.append(file_path) - logger.info(f"✓ Successfully downloaded {nested_file}") - # If we found one file in this directory, try to find all files in it - try: - all_files_in_dir = [f for f in all_repo_files if f.startswith(path)] - nii_files_in_dir = [f for f in all_files_in_dir if f.endswith('.nii') or f.endswith('.nii.gz')] - logger.info(f"Found {len(nii_files_in_dir)} additional NIfTI files in {path}") - - for nii_file in nii_files_in_dir: - if nii_file != nested_file: # Skip the one we already downloaded - try: - file_path = hf_hub_download( - repo_id=dataset_name, - filename=nii_file, - repo_type="dataset", - cache_dir=temp_dir - ) - 
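+# The simplified train_vae above calls vector_to_matrix() when plotting the
+# Original/Reconstructed/Generated panels, but the new module only imports
+# DemoVAE and plot_learning_curves. A minimal sketch of that helper, assuming
+# the input vector is the strict upper triangle of an n x n symmetric FC
+# matrix (len(vec) == n*(n-1)/2, row-major order) with a unit diagonal; drop
+# this if vae_model already exports the real implementation.
+def vector_to_matrix(vec):
+    vec = np.asarray(vec).ravel()
+    n = int((1 + np.sqrt(1 + 8 * len(vec))) / 2)  # invert n*(n-1)/2 == len(vec)
+    mat = np.ones((n, n))                         # unit diagonal
+    iu = np.triu_indices(n, k=1)
+    mat[iu] = vec                                 # fill upper triangle
+    mat[iu[1], iu[0]] = vec                       # mirror into lower triangle
+    return mat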
nii_files.append(file_path) - logger.info(f"✓ Downloaded {nii_file}") - except Exception as e: - logger.error(f"Error downloading {nii_file}: {e}") - except Exception as e: - logger.error(f"Error finding additional files in {path}: {e}") - except Exception as e: - pass - - except Exception as e: - logger.error(f"Error during repository exploration: {e}") - - # If we still don't have any files, try to search for P*_rs.nii pattern specifically - if not nii_files: - logger.info("Trying to find files matching P*_rs.nii pattern specifically...") - - try: - # List all files in the repository (if we haven't already) - if 'all_repo_files' not in locals(): - from huggingface_hub import list_repo_files - try: - all_repo_files = list_repo_files(dataset_name, repo_type="dataset") - except Exception as e: - logger.error(f"Error listing repo files: {e}") - all_repo_files = [] - - # Look for files matching the pattern exactly (P*_rs.nii) - pattern_files = [f for f in all_repo_files if '_rs.nii' in f and f.startswith('P')] - - # If we don't find any exact matches, try a more relaxed pattern - if not pattern_files: - pattern_files = [f for f in all_repo_files if 'rs.nii' in f.lower()] - - if pattern_files: - logger.info(f"Found {len(pattern_files)} files matching rs.nii pattern") - - # Download each file - for pattern_file in pattern_files: - try: - file_path = hf_hub_download( - repo_id=dataset_name, - filename=pattern_file, - repo_type="dataset", - cache_dir=temp_dir - ) - nii_files.append(file_path) - logger.info(f"✓ Downloaded {pattern_file}") - except Exception as e: - logger.error(f"Error downloading {pattern_file}: {e}") - except Exception as e: - logger.error(f"Error searching for pattern files: {e}") - - logger.info(f"Found total of {len(nii_files)} NIfTI files") except Exception as e: - logger.error(f"Unexpected error during NIfTI file search: {e}") - - return nii_files - -# Make sure directories exist for saving results and models -os.makedirs('results', exist_ok=True) -os.makedirs('models', exist_ok=True) - -def add_synthetic_data_notification(figure, prefer_real_data, use_synthetic_nifti, use_synthetic_fc): - """Add notification to figure if synthetic data was used despite preference for real data""" - if figure and prefer_real_data and (use_synthetic_nifti or use_synthetic_fc): - plt.figure(figure.number) - plt.figtext(0.5, 0.01, "Real Data Preferred - Synthetic Data Used Only When Necessary", - fontsize=9, color='blue', ha='center') - figure.canvas.draw() - return figure + logger.error(f"Error in VAE training: {str(e)}") + return f"Error: {str(e)}", None, None -def create_learning_figure(vae): - """Helper function to create VAE learning curve figure""" - plt.close('all') # Close previous figures - - # First check if loss data exists in the VAE object - has_train_losses = hasattr(vae, 'train_losses') and isinstance(vae.train_losses, (list, tuple)) and len(vae.train_losses) > 0 - has_val_losses = hasattr(vae, 'val_losses') and isinstance(vae.val_losses, (list, tuple)) and len(vae.val_losses) > 0 - - # Log the status for debugging - if has_train_losses: - logger.info(f"Found training losses: {len(vae.train_losses)} points") - else: - logger.warning("No training loss data found in VAE model") - - if has_val_losses: - logger.info(f"Found validation losses: {len(vae.val_losses)} points") - else: - logger.warning("No validation loss data found in VAE model") - - # If we have both train and validation losses, create the learning curve - if has_train_losses and has_val_losses: - 
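+# The removed version created these directories at import time; the new
+# train_vae still writes models/vae_model.pt, results/latents.npy,
+# results/reconstructed.npy and results/generated.npy, so the guard is still
+# needed:
+os.makedirs('models', exist_ok=True)
+os.makedirs('results', exist_ok=True)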
logger.info(f"Creating learning curve with {len(vae.train_losses)} loss points") - try: - fig = plot_learning_curves(vae.train_losses, vae.val_losses) - # Force rendering - fig.canvas.draw() - logger.info("Successfully created learning curve figure") - return fig - except Exception as e: - logger.error(f"Error creating learning curve: {e}") - # Fall through to the default figure below - - # If we're missing one type of loss data but have the other, we can create a partial plot - elif has_train_losses: - logger.info("Creating learning curve with training losses only") - try: - # Create dummy validation losses (same as training but offset) - dummy_val = [t * 1.1 for t in vae.train_losses] - fig = plot_learning_curves(vae.train_losses, dummy_val) - plt.title("VAE Learning Curve (Training Only)") - plt.figtext(0.5, 0.01, "Note: Validation data unavailable", - ha='center', fontsize=10, color='red') - fig.canvas.draw() - logger.info("Created partial learning curve with training data only") - return fig - except Exception as e: - logger.error(f"Error creating partial learning curve: {e}") - # Fall through to the default figure below - - # Create a default figure if no loss data is available or plotting failed - logger.warning("No complete loss data found - creating placeholder learning figure") - fig = plt.figure(figsize=(10, 6)) - plt.title("VAE Learning Curve Data Unavailable", color='darkred') - plt.xlabel("Epoch") - plt.ylabel("Loss") - plt.text(0.5, 0.5, "Learning curves will appear here after training", - ha='center', va='center', transform=plt.gca().transAxes, - fontsize=14) - plt.text(0.5, 0.4, "Try using more training epochs to see learning progress", - ha='center', va='center', transform=plt.gca().transAxes, - fontsize=12, color='darkblue') - plt.grid(True, alpha=0.3) - plt.axis('on') - fig.canvas.draw() - return fig - -def find_real_nifti_files(max_samples=2): - """Find real NIfTI files in the dataset or local directories, limited to the specified number""" - real_files = [] - - # Check possible local paths for real NIfTI files - possible_paths = [] - - # Add user-specified directory from config if available - if PREDICTION_CONFIG.get('local_nii_dir'): - user_dir = PREDICTION_CONFIG.get('local_nii_dir') - if os.path.exists(user_dir): - possible_paths.append(user_dir) - - # Add standard paths to check - possible_paths.extend([ - os.path.join(os.path.dirname(os.path.abspath(__file__)), "nii_files"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "nifti"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "fmri"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "nii_files"), - os.path.join(os.path.dirname(os.path.abspath(__file__)), "data", "nifti"), - "/tmp/nii_files", - "/tmp/hf_nifti*" # HuggingFace temp directories - ]) - - # Search all possible paths - for path_pattern in possible_paths: - # Handle glob patterns - import glob - for path in glob.glob(path_pattern): - if os.path.exists(path): - # Look for .nii and .nii.gz files - nii_files = [] - nii_files.extend(glob.glob(os.path.join(path, "*.nii"))) - nii_files.extend(glob.glob(os.path.join(path, "*.nii.gz"))) - real_files.extend(nii_files) - - # Limit to the specified number of files - if len(real_files) >= max_samples: - real_files = real_files[:max_samples] - logger.info(f"Limited to {max_samples} NIfTI files: {[os.path.basename(f) for f in real_files]}") - return real_files - - # If we found fewer than the specified number, create synthetic files - if len(real_files) < max_samples: - 
import tempfile - import numpy as np - import nibabel as nib - - logger.info(f"Found only {len(real_files)} real NIfTI files, creating synthetic files to reach {max_samples}") - - # Create a temporary directory for synthetic files - temp_dir = tempfile.mkdtemp(prefix="synthetic_nifti_") - - # Create synthetic files (or however many we need to reach max_samples) - for i in range(max_samples - len(real_files)): - # Create random data - vol_shape = (60, 75, 60, 50) # x, y, z, time - data = np.random.randn(*vol_shape) - - # Create NIfTI file - nii_img = nib.Nifti1Image(data, np.eye(4)) - - # Save to temp directory - file_path = os.path.join(temp_dir, f"P{i+1:02d}_rs.nii.gz") - nib.save(nii_img, file_path) - - real_files.append(file_path) - - logger.info(f"Created {max_samples - len(real_files)} synthetic files") - - logger.info(f"Using {len(real_files)} NIfTI files: {[os.path.basename(f) for f in real_files]}") - return real_files - -def process_vae_results(results, prefer_real_data=True, use_synthetic_nifti=False, use_synthetic_fc=False): - """Process VAE results and create visualizations with data preference handling""" - # Get the VAE and figures - vae = results.get('vae') - fc_figure = results.get('figures', {}).get('vae') - - # Create learning curve figure - learning_fig = create_learning_figure(vae) - - # Add notification if synthetic data was used despite preference for real data - fc_figure = add_synthetic_data_notification(fc_figure, prefer_real_data, - use_synthetic_nifti, use_synthetic_fc) - - return fc_figure, learning_fig - -def create_interface(): - """Create the Gradio interface""" - app = AphasiaPredictionApp() - - with gr.Blocks(title="Aphasia Treatment Trajectory Prediction") as interface: - gr.Markdown("# Aphasia Treatment Trajectory Prediction") - - with gr.Tabs(): - # Tab 1: Configuration - with gr.Tab("1. 
Configuration"): - with gr.Row(): - with gr.Column(scale=1): - data_dir = gr.Textbox( - label="Data Directory or HuggingFace Dataset ID", - value="SreekarB/OSFData1" - ) - local_nii_dir = gr.Textbox( - label="Local NIfTI Files Directory (Optional)", - value="", - placeholder="/path/to/nii_files", - info="If provided, NIfTI files from this directory will be used instead of searching the dataset" - ) - latent_dim = gr.Slider( - minimum=8, maximum=64, step=8, - label="Latent Dimensions", value=32 - ) - nepochs = gr.Slider( - minimum=100, maximum=5000, step=100, - label="Number of Epochs", value=200 # Reduced for faster demos - ) - - with gr.Column(scale=1): - bsize = gr.Slider( - minimum=8, maximum=64, step=8, - label="Batch Size", value=16 - ) - num_samples = gr.Slider( - minimum=2, maximum=30, step=1, - label="Number of Samples to Use", value=2, - info="Maximum number of NIfTI images to download and process" - ) - use_hf_dataset = gr.Checkbox( - label="Use HuggingFace Dataset", value=True - ) - - with gr.Accordion("Advanced Data Options", open=False): - skip_behavioral = gr.Checkbox( - label="Skip Behavioral Data Processing", - value=PREDICTION_CONFIG.get('skip_behavioral_data', True), - info="Use pre-defined treatment outcomes instead of processing behavioral data" - ) - use_synthetic_nifti = gr.Checkbox( - label="Use Synthetic NIfTI Data", - value=True, - info="Generate synthetic NIfTI files if real ones aren't found" - ) - use_synthetic_fc = gr.Checkbox( - label="Use Synthetic FC Matrices", - value=True, - info="Generate synthetic FC matrices if processing fails" - ) - - # Add data preference controls - with gr.Accordion("Data Visualization Preferences", open=False): - prefer_real_data = gr.Checkbox( - label="Prefer Real Data for Plots", - value=True, - info="When enabled, synthetic data will only be used if real data processing fails" - ) - - train_vae_btn = gr.Button("Train VAE Model", variant="primary") - - gr.Markdown("After configuring the parameters, click 'Train VAE Model' and proceed to the 'VAE Results' tab →") - - # Tab 2: VAE Results - with gr.Tab("2. VAE Training Results"): - gr.Markdown("### Functional Connectivity Matrix Visualization") - - with gr.Row(): - with gr.Column(scale=3): - fc_plot = gr.Plot(label="FC Matrices (Original/Reconstructed/Generated)") - with gr.Column(scale=1): - with gr.Group(): - gr.Markdown("### Subject Demographics") - demo_info = gr.Textbox(label="Subject Information", interactive=False) - - with gr.Row(): - generate_fc_btn = gr.Button("Generate FC Matrix from Real Data", variant="secondary") - - with gr.Row(): - learning_plot = gr.Plot(label="VAE Learning Curves") - - gr.Markdown("After reviewing VAE results, proceed to the 'Random Forest Prediction' tab →") - - # Tab 3: Random Forest Prediction - with gr.Tab("3. 
Random Forest Prediction"): - gr.Markdown("### Random Forest Model Training") - gr.Markdown("First complete the VAE training in the Configuration tab, then configure and train the Random Forest model below:") - - with gr.Row(): - with gr.Column(scale=1): - prediction_type = gr.Radio( - label="Prediction Type", - choices=["regression"], - value="regression" - ) - outcome_variable = gr.Dropdown( - label="Outcome Variable", - choices=["wab_aq", "age", "mpo", "education"], - value="wab_aq" - ) - - with gr.Column(scale=1): - rf_n_estimators = gr.Slider( - minimum=10, maximum=500, step=10, - label="Number of Trees", value=100 - ) - rf_max_depth = gr.Slider( - minimum=3, maximum=50, step=1, - label="Max Tree Depth", value=10, - info="Set to 0 for unlimited depth" - ) - rf_cv_folds = gr.Slider( - minimum=2, maximum=10, step=1, - label="Cross-validation Folds", value=5 - ) - - train_rf_btn = gr.Button("Train Random Forest Model", variant="primary") - - gr.Markdown("### Random Forest Results") - - with gr.Row(): - with gr.Column(scale=1): - importance_plot = gr.Plot(label="Feature Importance") - with gr.Column(scale=1): - prediction_plot = gr.Plot(label="Prediction Performance") - - with gr.Row(): - rf_metrics = gr.Textbox(label="Model Performance Metrics") - - gr.Markdown("After Random Forest training completes, proceed to the 'Treatment Prediction' tab →") - - # Tab 4: Predict Treatment - with gr.Tab("4. Treatment Prediction"): - gr.Markdown("### Predict Individual Treatment Outcomes") - gr.Markdown("After completing VAE and Random Forest training, you can predict treatment outcomes for individual patients:") - - with gr.Row(): - with gr.Column(scale=1): - fmri_file = gr.File(label="Patient fMRI Data (NIfTI file)") - with gr.Column(scale=1): - with gr.Group(): - gr.Markdown("### Patient Demographics") - age = gr.Number(label="Age at Stroke", value=60) - sex = gr.Dropdown(choices=["M", "F"], label="Sex", value="M") - months = gr.Number(label="Months Post Stroke", value=12) - wab = gr.Number(label="Current WAB Score", value=50) - - predict_btn = gr.Button("Predict Treatment Outcome", variant="primary") - - with gr.Row(): - prediction_text = gr.Textbox(label="Prediction Result") - - with gr.Row(): - trajectory_plot = gr.Plot(label="Predicted Treatment Trajectory") - - # Define various handler functions for the different tabs - - # Store shared state between tabs - app_state = { - 'vae': None, - 'latents': None, - 'demographics': None, - 'predictor': None, - 'vae_trained': False, - 'rf_trained': False - } +def create_ui(): + """Create the Gradio UI""" + with gr.Blocks(title="FC Matrix VAE Demo") as app: + gr.Markdown("# Functional Connectivity VAE Demo") + gr.Markdown("Upload FC matrices and train a VAE model to analyze them.") + + with gr.Tab("Train VAE"): + with gr.Row(): + with gr.Column(): + fc_file = gr.File(label="FC Matrix File (CSV or NPY)") + demo_file = gr.File(label="Demographics File (CSV, optional)") + + with gr.Row(): + epochs = gr.Slider(5, 100, 20, step=5, label="Training Epochs") + latent_dim = gr.Slider(8, 64, 16, step=4, label="Latent Dimension") + batch_size = gr.Slider(4, 32, 8, step=4, label="Batch Size") + + train_btn = gr.Button("Train VAE Model") + status = gr.Textbox(label="Status") + + with gr.Column(): + learning_plot = gr.Image(label="Learning Curves") + fc_plot = gr.Image(label="FC Matrices") + + train_btn.click( + fn=train_vae, + inputs=[fc_file, demo_file, epochs, latent_dim, batch_size], + outputs=[status, learning_plot, fc_plot] + ) - # Tab 2: Generate FC Matrix 
handler - def handle_generate_fc(progress=gr.Progress()): - """Generate FC matrices from a real NIfTI file with demographic information""" - if not app_state['vae_trained'] or app_state['vae'] is None: - error_msg = "You must train the VAE model first in the Configuration tab" - return gr.update(value=error_msg), None + with gr.Tab("About"): + gr.Markdown(""" + ## About this App - try: - # Initialize progress tracking - progress(0, desc="Starting FC matrix generation") - - # Get number of samples from config or use default - max_samples = PREDICTION_CONFIG.get('num_samples', 2) - - # Find real NIfTI files up to the specified number - progress(0.1, desc="Searching for NIfTI files") - real_files = find_real_nifti_files(max_samples=max_samples) - - if not real_files: - error_msg = "No real NIfTI files found. Please train the model with real data first." - return gr.update(value=error_msg), None - - # Select a random file - import random - selected_file = random.choice(real_files) - filename = os.path.basename(selected_file) - - # Try to extract patient ID - patient_id = "Unknown" - if 'P' in filename and '_' in filename: - patient_id = filename.split('_')[0] - - # Process the file to get FC matrix - try: - # For testing, first try with synthetic data allowed to be more reliable - progress(0.3, desc=f"Processing {filename}") - fc_matrix = process_single_fmri(selected_file, allow_synthetic=True) - logger.info(f"Successfully processed file: {filename}") - progress(0.4, desc="Extracted FC matrix") - - # Create fixed demographics for testing - if "P1" in filename or "P01" in filename: - # First patient - age = 65.0 - sex = "M" - months_post_stroke = 12.0 - wab_score = 45.0 - else: - # Second patient - age = 55.0 - sex = "F" - months_post_stroke = 6.0 - wab_score = 65.0 - - demo_info = f"Subject ID: {patient_id}\nAge: {age:.1f}\nSex: {sex}\nMonths Post Stroke: {months_post_stroke:.1f}\nWAB Score: {wab_score:.1f}" - - # Get reconstructed and generated FC matrices - progress(0.5, desc="Preparing VAE inputs") - vae = app_state['vae'] - - # Prepare demo data in the format expected by the VAE - demo_data = [ - np.array([age]), - np.array([sex]), - np.array([months_post_stroke]), - np.array([wab_score]) - ] - demo_types = ['continuous', 'categorical', 'continuous', 'continuous'] - - # Reshape FC matrix if needed - if len(fc_matrix.shape) == 1: - fc_matrix = fc_matrix.reshape(1, -1) - - # Get reconstruction - progress(0.6, desc="Generating reconstruction") - reconstructed = vae.transform(fc_matrix, demo_data, demo_types) - - # Also generate a new FC matrix with the same demographics - progress(0.7, desc="Generating synthetic FC") - generated = vae.transform(1, demo_data, demo_types) - - # Create visualization - progress(0.8, desc="Creating visualization") - plt.close('all') - from visualization import plot_fc_matrices, vector_to_matrix - - # Convert from vector to matrix if needed - if len(fc_matrix.shape) == 2 and fc_matrix.shape[0] == 1: - original = vector_to_matrix(fc_matrix[0]) - else: - original = vector_to_matrix(fc_matrix) - - if len(reconstructed.shape) == 2 and reconstructed.shape[0] == 1: - recon = vector_to_matrix(reconstructed[0]) - else: - recon = vector_to_matrix(reconstructed) - - if len(generated.shape) == 2 and generated.shape[0] == 1: - gen = vector_to_matrix(generated[0]) - else: - gen = vector_to_matrix(generated) - - fc_fig = plot_fc_matrices(original, recon, gen) - plt.suptitle(f"FC Matrices for {patient_id} - Testing Mode", fontsize=14) - fc_fig.canvas.draw() - - # Complete the 
progress bar - progress(1.0, desc="FC visualization complete!") - - return gr.update(value=demo_info), fc_fig - - except Exception as e: - logger.error(f"Error processing NIfTI file: {e}") - error_msg = f"Error processing NIfTI file: {str(e)}\nTrying another file..." - - # Try the other test file if available - if len(real_files) > 1: - # Just use the other file - other_file = [f for f in real_files if f != selected_file][0] - selected_file = other_file - error_msg += f"\nTrying: {os.path.basename(selected_file)}" - - try: - # For testing, allow synthetic data to increase reliability - fc_matrix = process_single_fmri(selected_file, allow_synthetic=True) - - # Use fixed demographics for the second test file - age = 55.0 - sex = "F" - months_post_stroke = 6.0 - wab_score = 65.0 - - demo_info = f"Subject ID: {os.path.basename(selected_file).split('_')[0]}\nAge: {age:.1f}\nSex: {sex}\nMonths Post Stroke: {months_post_stroke:.1f}\nWAB Score: {wab_score:.1f}" - - # Rest of processing code would be here, but for brevity - # we'll use a fallback approach for testing - - # Create simple matrices for testing - plt.close('all') - fc_fig = plt.figure(figsize=(15, 5)) - plt.suptitle(f"Test FC Matrices (Fallback)", fontsize=14) - plt.figtext(0.5, 0.5, "Testing Mode - Using alternative file", ha='center') - fc_fig.canvas.draw() - - return gr.update(value=demo_info), fc_fig - - except Exception as e2: - error_msg += f"\nFailed again: {str(e2)}" - - return gr.update(value=error_msg), None - - except Exception as e: - logger.error(f"Error generating FC from real data: {e}") - error_msg = f"Error generating FC from real data: {str(e)}" - return gr.update(value=error_msg), None - - # Tab 1: VAE Training Handler - def handle_vae_training(data_dir, local_nii_dir, latent_dim, nepochs, bsize, num_samples, use_hf_dataset, - skip_behavioral, use_synthetic_nifti, use_synthetic_fc, prefer_real_data=True): - """Train the VAE model and display FC visualization and learning curves""" - # Store config values - PREDICTION_CONFIG['skip_behavioral_data'] = skip_behavioral - PREDICTION_CONFIG['use_synthetic_nifti'] = False # Force to use only real NIfTI data - PREDICTION_CONFIG['use_synthetic_fc'] = False # Force to use only real FC matrices - PREDICTION_CONFIG['strict_real_data'] = True # Strictly use real data only - PREDICTION_CONFIG['num_samples'] = num_samples + This app trains a Variational Autoencoder (VAE) on functional connectivity (FC) matrices. - # Store the local NIfTI directory if provided - if local_nii_dir and os.path.exists(local_nii_dir): - PREDICTION_CONFIG['local_nii_dir'] = local_nii_dir - logger.info(f"Using local NIfTI directory: {local_nii_dir}") - else: - PREDICTION_CONFIG['local_nii_dir'] = None + ### Features: + * Load FC matrices from CSV or NPY files + * Incorporate demographic data (age, sex, etc.) 
+ * Visualize learning curves + * Compare original, reconstructed and generated FC matrices - # Log info - logger.info(f"Training VAE model with data from: {data_dir}") - logger.info(f"VAE parameters: latent_dim={latent_dim}, epochs={nepochs}, batch_size={bsize}") + ### Input Format: + * FC matrices should be provided as vectors (flattened upper triangular portion of symmetric matrices) + * Demographics file should be CSV with columns for age, sex, months_post_stroke, and wab_score - # Create a subset of app.train_models functionality that just trains the VAE - try: - # Start by setting up data for the VAE - from vae_model import DemoVAE - from data_preprocessing import load_and_preprocess_data, download_and_cache_dataset - from main import run_analysis - import numpy as np - import os - - # Prepare VAE training parameters - MODEL_CONFIG.update({ - 'latent_dim': latent_dim, - 'nepochs': nepochs, - 'bsize': bsize - }) - - # First, find and preprocess data - logger.info("Looking for data in directory and preprocessing...") - - # This part is similar to app.train_models but only focuses on VAE - if data_dir == "SreekarB/OSFData1" and use_hf_dataset: - # Use our improved download and cache function for HuggingFace datasets - logger.info(f"Using improved HuggingFace dataset caching for {data_dir}") - try: - # Get the dataset and cached NIfTI files - _, _, nii_files = download_and_cache_dataset(data_dir) - - if nii_files and len(nii_files) > 0: - logger.info(f"Successfully found {len(nii_files)} NIfTI files in dataset cache") - - # Limit to the specified number of samples if needed - if num_samples is not None and len(nii_files) > num_samples: - logger.info(f"Limiting to {num_samples} NIfTI files as specified") - nii_files = nii_files[:num_samples] - - # Use the cached files in the train_models function - results = app.train_models( - data_dir=data_dir, - latent_dim=latent_dim, - nepochs=nepochs, - bsize=bsize, - num_samples=num_samples, - hf_nii_files=nii_files # Use our cached files - ) - else: - # Fallback to standard train_models - logger.warning("No NIfTI files found in cache, using standard train_models") - results = app.train_models( - data_dir=data_dir, - latent_dim=latent_dim, - nepochs=nepochs, - bsize=bsize, - num_samples=num_samples - ) - except Exception as cache_err: - logger.error(f"Error using cached dataset: {cache_err}") - # Fallback to standard train_models - results = app.train_models( - data_dir=data_dir, - latent_dim=latent_dim, - nepochs=nepochs, - bsize=bsize, - num_samples=num_samples - ) - else: - # Local directory case or non-HuggingFace dataset - results = app.train_models( - data_dir=data_dir, - latent_dim=latent_dim, - nepochs=nepochs, - bsize=bsize, - num_samples=num_samples - ) - - # Store results in app_state - app_state['vae'] = results.get('vae', None) - app_state['latents'] = results.get('latents', None) - app_state['demographics'] = results.get('demographics', None) - app_state['vae_trained'] = True - - # Process VAE results with data preference - fc_figure, learning_fig = process_vae_results( - results, - prefer_real_data=prefer_real_data, - use_synthetic_nifti=use_synthetic_nifti, - use_synthetic_fc=use_synthetic_fc - ) - - # Return the visualizations - return [fc_figure, learning_fig] - except Exception as e: - logger.error(f"Error in VAE training: {str(e)}", exc_info=True) - error_fig = plt.figure(figsize=(10, 6)) - plt.text(0.5, 0.5, f"Error: {str(e)}", - horizontalalignment='center', verticalalignment='center', - fontsize=12, color='red', wrap=True) - 
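+# Example of the demographics CSV described in the About tab above and the
+# (demographics, demo_types) pair that train_vae forwards to DemoVAE.fit().
+# Column names follow the About text; treating 'sex' as the single
+# categorical column is an assumption based on the demo_types list used
+# elsewhere in this file.
+def load_demographics_example(path='demographics.csv'):
+    # Expected CSV layout:
+    #   age,sex,months_post_stroke,wab_score
+    #   65,M,12,45
+    #   55,F,6,65
+    demo_df = pd.read_csv(path)
+    demographics = [demo_df[c].values
+                    for c in ['age', 'sex', 'months_post_stroke', 'wab_score']]
+    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
+    return demographics, demo_types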
plt.axis('off') - - # Return error figures for both outputs - return [error_fig, error_fig] - - # Tab 2: Random Forest Training Handler - def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds, progress=gr.Progress()): - """Train the Random Forest model using the VAE latent representations""" - # Initialize progress tracking - progress(0, desc="Starting Random Forest training") - - # Try to load the VAE model if it's not already trained - if not app_state.get('vae_trained', False) or app_state.get('latents') is None: - progress(0.1, desc="Loading or training VAE model first") - try: - # Try to load the VAE model from disk - from vae_model import DemoVAE - vae_path = os.path.join('models', 'vae_model.pt') - if os.path.exists(vae_path): - logger.info("Loading saved VAE model...") - vae = DemoVAE() - vae.load(vae_path) - app_state['vae'] = vae - - # Only use real data for training and visualization - logger.info("Using loaded VAE model with real data only...") - - # Set flag to indicate VAE model is loaded, but not using synthetic data - app_state['vae_trained'] = True - - # Try to load previously saved latents if they exist - if os.path.exists('results/latents.npy'): - try: - logger.info("Loading saved latent representations...") - latents = np.load('results/latents.npy') - app_state['latents'] = latents - logger.info(f"Loaded {len(latents)} real latent vectors") - - # Try to load real demographics if available - if os.path.exists('temp_demographics.csv'): - logger.info("Loading demographics from temp_demographics.csv") - demo_df = pd.read_csv('temp_demographics.csv') - app_state['demographics'] = { - 'age_at_stroke': demo_df['age'].values, - 'sex': demo_df['sex'].values, - 'months_post_stroke': demo_df['months_post_stroke'].values, - 'wab_score': demo_df['wab_score'].values - } - else: - logger.warning("No real demographic data found") - except Exception as e: - logger.error(f"Error loading real latents: {e}") - logger.warning("Will not use synthetic data") - else: - logger.warning("No real latent representations found") - logger.warning("Will not use synthetic data") - else: - # Don't train with synthetic data in strict real data mode - logger.info("VAE model not found and using strict real data mode.") - logger.warning("Cannot train VAE model without real data") - - # Set state to indicate VAE is not trained - app_state['vae_trained'] = False - - # Show message about requiring real data - status_msg = "No VAE model available. Please train with real data first." - return { - tab_rf: gr.update(visible=False), - tab_vae: gr.update(visible=True), - status: status_msg, - vae_status: "Model not trained. Upload real data and train with it." 
- } - except Exception as e: - error_fig = plt.figure(figsize=(10, 6)) - message = f"Error: Unable to load or train VAE model: {str(e)}" - plt.text(0.5, 0.5, message, - horizontalalignment='center', verticalalignment='center', - fontsize=14, color='red') - plt.axis('off') - - # Return error for both outputs - return [error_fig, error_fig, f"Error: Unable to load or train VAE model: {str(e)}"] - - try: - # Update RF configuration - PREDICTION_CONFIG['default_outcome'] = outcome_variable - PREDICTION_CONFIG['n_estimators'] = rf_n_estimators - PREDICTION_CONFIG['max_depth'] = rf_max_depth if rf_max_depth > 0 else None - PREDICTION_CONFIG['cv_folds'] = rf_cv_folds - - # We only use regression for prediction - logger.info(f"Training Random Forest Regression model: outcome={outcome_variable}") - logger.info(f"RF parameters: n_estimators={rf_n_estimators}, max_depth={rf_max_depth}, cv_folds={rf_cv_folds}") - - # Get data from app_state - progress(0.4, desc="Preparing latent features and demographics") - latents = app_state['latents'] - demographics = app_state['demographics'] - - # Train Random Forest predictor - progress(0.5, desc="Setting up Random Forest predictor") - from rcf_prediction import AphasiaTreatmentPredictor - import pandas as pd - import numpy as np - - # Need to find treatment outcomes data - # This would normally be loaded in train_models, so we need - # to mock it here or load from app_state - if hasattr(app, 'last_treatment_file') and os.path.exists(app.last_treatment_file): - treatment_file = app.last_treatment_file - treatment_df = pd.read_csv(treatment_file) - treatment_outcomes = treatment_df['outcome_score'].values - - # Initialize predictor - predictor = AphasiaTreatmentPredictor( - n_estimators=rf_n_estimators, - max_depth=rf_max_depth if rf_max_depth > 0 else None - ) - - # Cross-validate - progress(0.6, desc="Performing cross-validation") - cv_results = predictor.cross_validate( - latents=latents, - demographics=demographics, - treatment_outcomes=treatment_outcomes, - n_splits=rf_cv_folds - ) - progress(0.7, desc="Cross-validation complete") - - # Fit final model - progress(0.8, desc="Training final model") - predictor.fit(latents, demographics, treatment_outcomes) - - # Store in app_state - progress(0.85, desc="Saving model") - app_state['predictor'] = predictor - app_state['rf_trained'] = True - - # Create feature importance plot - progress(0.9, desc="Creating visualizations") - importance_fig = predictor.plot_feature_importance() - - # Create prediction performance plot - predictions = cv_results['predictions'] - prediction_stds = cv_results['prediction_stds'] - - performance_fig = plt.figure(figsize=(8, 6)) - - # Check if we have valid predictions - if len(treatment_outcomes) > 0 and len(predictions) == len(treatment_outcomes): - # Only create scatter plot if we have matching data - plt.scatter(treatment_outcomes, predictions) - - # Reference line - min_val = min(np.min(treatment_outcomes), np.min(predictions)) - max_val = max(np.max(treatment_outcomes), np.max(predictions)) - plt.plot([min_val, max_val], [min_val, max_val], 'r--') - - # Confidence band - plt.fill_between(treatment_outcomes, - predictions - 2*prediction_stds, - predictions + 2*prediction_stds, - alpha=0.2, color='gray') - - plt.xlabel('Actual Outcome') - plt.ylabel('Predicted Outcome') - - # Get performance metrics - metrics_text = "" - mean_metrics = cv_results.get('mean_metrics', {}) - - r2 = mean_metrics.get('r2', 0) - rmse = mean_metrics.get('rmse', 0) - plt.title(f'Treatment Outcome 
Prediction\nR² = {r2:.3f}, RMSE = {rmse:.3f}') - metrics_text = f"Regression Model Performance:\nR² = {r2:.4f}\nRMSE = {rmse:.4f}" - else: - # Handle case with no data - plt.text(0.5, 0.5, "No prediction data available", - ha='center', va='center', transform=plt.gca().transAxes) - metrics_text = "No performance metrics available" - - plt.tight_layout() - - # Complete the progress - progress(1.0, desc="Random Forest training complete!") - return [importance_fig, performance_fig, metrics_text] - else: - # No treatment file available - error_fig = plt.figure(figsize=(10, 6)) - message = "Error: Treatment outcomes file not found. Please retrain the VAE in Tab 1." - plt.text(0.5, 0.5, message, - horizontalalignment='center', verticalalignment='center', - fontsize=14, color='red') - plt.axis('off') - - return [error_fig, error_fig, "Error: Treatment outcomes file not found."] - - except Exception as e: - logger.error(f"Error in RF training: {str(e)}", exc_info=True) - error_fig = plt.figure(figsize=(10, 6)) - plt.text(0.5, 0.5, f"Error: {str(e)}", - horizontalalignment='center', verticalalignment='center', - fontsize=12, color='red', wrap=True) - plt.axis('off') - - return [error_fig, error_fig, f"Error: {str(e)}"] - - # Connect the tab handlers - - # VAE Training tab with automatic tab switching - train_vae_btn.click( - fn=handle_vae_training, - inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, num_samples, use_hf_dataset, - skip_behavioral, use_synthetic_nifti, use_synthetic_fc, prefer_real_data], - outputs=[fc_plot, learning_plot] - ).then( - # After training completes, switch to the VAE Results tab - lambda: gr.Tabs(selected=1) # The VAE Results tab has index 1 - ) - - # Generate FC matrices from real data - generate_fc_btn.click( - fn=handle_generate_fc, - inputs=[], - outputs=[demo_info, fc_plot] - ) - - # Random Forest Training tab - train_rf_btn.click( - fn=handle_rf_training, - inputs=[prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds], - outputs=[importance_plot, prediction_plot, rf_metrics] - ) - - # Tab 3: Treatment Prediction Handler - def handle_treatment_prediction(fmri_file, age, sex, months, wab): - """Predict treatment outcome for a new patient""" - # Try to load models if they are not already trained - if not app_state.get('vae_trained', False) or not app_state.get('rf_trained', False): - try: - # First check for VAE model - from vae_model import DemoVAE - vae_path = os.path.join('models', 'vae_model.pt') - rf_path = os.path.join('models', 'predictor_model.pt') - - vae_loaded = False - rf_loaded = False - - # Try to load the VAE model - if not app_state.get('vae_trained', False) and os.path.exists(vae_path): - logger.info("Loading saved VAE model...") - vae = DemoVAE() - vae.load(vae_path) - app_state['vae'] = vae - app_state['vae_trained'] = True - vae_loaded = True - - # Try to load the RF model - if not app_state.get('rf_trained', False) and os.path.exists(rf_path): - logger.info("Loading saved RF predictor model...") - from main import RandomForestPredictor - - # Load the model - loaded_data = torch.load(rf_path) - predictor = RandomForestPredictor() - predictor.model = loaded_data['predictor_state'] - predictor.feature_importance = loaded_data.get('feature_importance', {}) - - app_state['predictor'] = predictor - app_state['rf_trained'] = True - rf_loaded = True - - # If we couldn't load both models in strict real data mode - if not (vae_loaded and rf_loaded): - logger.info("Strict real data mode: Not using synthetic data") - - # 
Show a message to the user - return { - status: "Cannot use synthetic data in strict real data mode. Please train with real data first.", - rf_status: "Not trained. Upload real data and train the VAE model first." - } - except Exception as e: - error_message = f"Error: Unable to load or train required models: {str(e)}" - error_fig = plt.figure(figsize=(10, 6)) - plt.text(0.5, 0.5, error_message, - horizontalalignment='center', verticalalignment='center', - fontsize=14, color='red') - plt.axis('off') - - return [error_message, error_fig] + ### Model Architecture: + * Simple feedforward VAE with demographic conditioning + * Latent space can be specified (default 16 dimensions) + * MSE reconstruction loss + """) - # Use the trained models from app_state for prediction - try: - # Set up prediction - if app_state.get('vae') is None or app_state.get('predictor') is None: - error_message = "Error: Models not properly available" - return [error_message, None] - - # Create a temporary prediction app with our trained models - temp_app = AphasiaPredictionApp() - temp_app.vae = app_state['vae'] - temp_app.predictor = app_state['predictor'] - temp_app.trained = True - temp_app.latent_dim = app_state['vae'].latent_dim if hasattr(app_state['vae'], 'latent_dim') else 32 - - # Make prediction - return temp_app.predict_treatment( - fmri_file=fmri_file, - age=age, - sex=sex, - months_post_stroke=months, - wab_score=wab - ) - except Exception as e: - logger.error(f"Error in treatment prediction: {str(e)}", exc_info=True) - return [f"Error in prediction: {str(e)}", None] - - # Connect the treatment prediction handler - predict_btn.click( - fn=handle_treatment_prediction, - inputs=[fmri_file, age, sex, months, wab], - outputs=[prediction_text, trajectory_plot] - ) - - # Add examples - gr.Examples( - examples=[ - ["SreekarB/OSFData1", "", 32, 200, 16, True, "regression", "wab_aq", True, False, False], # Standard training without synthetic data - ["SreekarB/OSFData1", "", 16, 100, 8, True, "regression", "wab_aq", True, False, False] # Faster training with smaller parameters - ], - inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset, - prediction_type, outcome_variable, skip_behavioral, - use_synthetic_nifti, use_synthetic_fc], - ) - - - - return interface + return app +# For local testing if __name__ == "__main__": - interface = create_interface() + app = create_ui() + app.launch() - # Check if running in Hugging Face Spaces - import os - if os.environ.get('SPACE_ID'): - # Running in Spaces - interface.launch() - else: - # Running locally - interface.launch(share=True) \ No newline at end of file +# For Huggingface Spaces +demo = create_ui() \ No newline at end of file
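+# Illustrative sketch of the kind of model the About tab describes: a simple
+# feedforward VAE with demographic conditioning and an MSE reconstruction
+# loss. DemoVAE in vae_model is the real model; the layer sizes and the
+# concatenate-demographics scheme here are assumptions for illustration only.
+import torch
+import torch.nn as nn
+
+class TinyConditionalVAE(nn.Module):
+    def __init__(self, input_dim, demo_dim, latent_dim=16, hidden=128):
+        super().__init__()
+        self.encoder = nn.Sequential(
+            nn.Linear(input_dim + demo_dim, hidden), nn.ReLU())
+        self.mu = nn.Linear(hidden, latent_dim)
+        self.logvar = nn.Linear(hidden, latent_dim)
+        self.decoder = nn.Sequential(
+            nn.Linear(latent_dim + demo_dim, hidden), nn.ReLU(),
+            nn.Linear(hidden, input_dim))
+
+    def forward(self, x, demo):
+        # Encode the FC vector together with its demographics.
+        h = self.encoder(torch.cat([x, demo], dim=1))
+        mu, logvar = self.mu(h), self.logvar(h)
+        # Reparameterization trick: z = mu + sigma * eps.
+        z = mu + torch.randn_like(mu) * torch.exp(0.5 * logvar)
+        recon = self.decoder(torch.cat([z, demo], dim=1))
+        return recon, mu, logvar
+
+def vae_loss(recon, x, mu, logvar):
+    # MSE reconstruction term plus the standard Gaussian KL divergence.
+    mse = nn.functional.mse_loss(recon, x, reduction='mean')
+    kld = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
+    return mse + kld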