Spaces:
Sleeping
Sleeping
Upload 2 files
Browse files- app.py +491 -115
- data_preprocessing.py +263 -52
app.py
CHANGED
|
@@ -1799,12 +1799,18 @@ def create_interface():
|
|
| 1799 |
info="Generate synthetic FC matrices if processing fails"
|
| 1800 |
)
|
| 1801 |
|
| 1802 |
-
|
|
|
|
|
|
|
|
|
|
| 1803 |
|
| 1804 |
gr.Markdown("### VAE Training Results")
|
| 1805 |
|
| 1806 |
with gr.Row():
|
| 1807 |
-
|
|
|
|
|
|
|
|
|
|
| 1808 |
|
| 1809 |
with gr.Row():
|
| 1810 |
learning_plot = gr.Plot(label="VAE Learning Curves")
|
|
@@ -1950,12 +1956,31 @@ def create_interface():
|
|
| 1950 |
app_state['vae'] = results.get('vae', None)
|
| 1951 |
app_state['latents'] = results.get('latents', None)
|
| 1952 |
app_state['demographics'] = results.get('demographics', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1953 |
app_state['vae_trained'] = True
|
| 1954 |
|
| 1955 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1956 |
return [
|
| 1957 |
-
results.get('
|
| 1958 |
-
|
|
|
|
| 1959 |
]
|
| 1960 |
else:
|
| 1961 |
# Local directory case
|
|
@@ -1970,12 +1995,31 @@ def create_interface():
|
|
| 1970 |
app_state['vae'] = results.get('vae', None)
|
| 1971 |
app_state['latents'] = results.get('latents', None)
|
| 1972 |
app_state['demographics'] = results.get('demographics', None)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1973 |
app_state['vae_trained'] = True
|
| 1974 |
|
| 1975 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1976 |
return [
|
| 1977 |
-
results.get('
|
| 1978 |
-
|
|
|
|
| 1979 |
]
|
| 1980 |
except Exception as e:
|
| 1981 |
logger.error(f"Error in VAE training: {str(e)}", exc_info=True)
|
|
@@ -1985,23 +2029,232 @@ def create_interface():
|
|
| 1985 |
fontsize=12, color='red', wrap=True)
|
| 1986 |
plt.axis('off')
|
| 1987 |
|
| 1988 |
-
# Return error figures for
|
| 1989 |
-
return [error_fig, error_fig]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1990 |
|
| 1991 |
# Tab 2: Random Forest Training Handler
|
| 1992 |
def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
|
| 1993 |
"""Train the Random Forest model using the VAE latent representations"""
|
| 1994 |
-
# Check if VAE has been trained
|
| 1995 |
-
if not app_state
|
| 1996 |
-
|
| 1997 |
-
|
| 1998 |
-
|
| 1999 |
-
|
| 2000 |
-
|
| 2001 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2002 |
|
| 2003 |
-
#
|
| 2004 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2005 |
|
| 2006 |
try:
|
| 2007 |
# Update RF configuration
|
|
@@ -2023,91 +2276,118 @@ def create_interface():
|
|
| 2023 |
import pandas as pd
|
| 2024 |
import numpy as np
|
| 2025 |
|
| 2026 |
-
#
|
| 2027 |
-
#
|
| 2028 |
-
|
| 2029 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2030 |
treatment_file = app.last_treatment_file
|
| 2031 |
treatment_df = pd.read_csv(treatment_file)
|
| 2032 |
treatment_outcomes = treatment_df['outcome_score'].values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2033 |
|
| 2034 |
-
#
|
| 2035 |
-
|
| 2036 |
-
|
| 2037 |
-
|
| 2038 |
-
|
| 2039 |
-
|
| 2040 |
-
|
| 2041 |
-
|
| 2042 |
-
|
| 2043 |
-
|
| 2044 |
-
treatment_outcomes=treatment_outcomes,
|
| 2045 |
-
n_splits=rf_cv_folds
|
| 2046 |
-
)
|
| 2047 |
-
|
| 2048 |
-
# Fit final model
|
| 2049 |
-
predictor.fit(latents, demographics, treatment_outcomes)
|
| 2050 |
|
| 2051 |
-
|
| 2052 |
-
|
| 2053 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2054 |
|
| 2055 |
-
#
|
| 2056 |
-
|
|
|
|
|
|
|
| 2057 |
|
| 2058 |
-
#
|
| 2059 |
-
|
| 2060 |
-
|
|
|
|
|
|
|
| 2061 |
|
| 2062 |
-
|
|
|
|
| 2063 |
|
| 2064 |
-
#
|
| 2065 |
-
|
| 2066 |
-
|
| 2067 |
-
plt.scatter(treatment_outcomes, predictions)
|
| 2068 |
-
|
| 2069 |
-
# Reference line
|
| 2070 |
-
min_val = min(np.min(treatment_outcomes), np.min(predictions))
|
| 2071 |
-
max_val = max(np.max(treatment_outcomes), np.max(predictions))
|
| 2072 |
-
plt.plot([min_val, max_val], [min_val, max_val], 'r--')
|
| 2073 |
-
|
| 2074 |
-
# Confidence band
|
| 2075 |
-
plt.fill_between(treatment_outcomes,
|
| 2076 |
-
predictions - 2*prediction_stds,
|
| 2077 |
-
predictions + 2*prediction_stds,
|
| 2078 |
-
alpha=0.2, color='gray')
|
| 2079 |
-
|
| 2080 |
-
plt.xlabel('Actual Outcome')
|
| 2081 |
-
plt.ylabel('Predicted Outcome')
|
| 2082 |
-
|
| 2083 |
-
# Get performance metrics
|
| 2084 |
-
metrics_text = ""
|
| 2085 |
-
mean_metrics = cv_results.get('mean_metrics', {})
|
| 2086 |
-
|
| 2087 |
-
r2 = mean_metrics.get('r2', 0)
|
| 2088 |
-
rmse = mean_metrics.get('rmse', 0)
|
| 2089 |
-
plt.title(f'Treatment Outcome Prediction\nR² = {r2:.3f}, RMSE = {rmse:.3f}')
|
| 2090 |
-
metrics_text = f"Regression Model Performance:\nR² = {r2:.4f}\nRMSE = {rmse:.4f}"
|
| 2091 |
-
else:
|
| 2092 |
-
# Handle case with no data
|
| 2093 |
-
plt.text(0.5, 0.5, "No prediction data available",
|
| 2094 |
-
ha='center', va='center', transform=plt.gca().transAxes)
|
| 2095 |
-
metrics_text = "No performance metrics available"
|
| 2096 |
-
|
| 2097 |
-
plt.tight_layout()
|
| 2098 |
|
| 2099 |
-
|
|
|
|
|
|
|
|
|
|
| 2100 |
else:
|
| 2101 |
-
#
|
| 2102 |
-
|
| 2103 |
-
|
| 2104 |
-
|
| 2105 |
-
horizontalalignment='center', verticalalignment='center',
|
| 2106 |
-
fontsize=14, color='red')
|
| 2107 |
-
plt.axis('off')
|
| 2108 |
|
| 2109 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2110 |
|
|
|
|
|
|
|
| 2111 |
except Exception as e:
|
| 2112 |
logger.error(f"Error in RF training: {str(e)}", exc_info=True)
|
| 2113 |
error_fig = plt.figure(figsize=(10, 6))
|
|
@@ -2125,7 +2405,14 @@ def create_interface():
|
|
| 2125 |
fn=handle_vae_training,
|
| 2126 |
inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
|
| 2127 |
skip_behavioral, use_synthetic_nifti, use_synthetic_fc],
|
| 2128 |
-
outputs=[fc_plot, learning_plot]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2129 |
)
|
| 2130 |
|
| 2131 |
# Random Forest Training tab
|
|
@@ -2138,31 +2425,120 @@ def create_interface():
|
|
| 2138 |
# Tab 3: Treatment Prediction Handler
|
| 2139 |
def handle_treatment_prediction(fmri_file, age, sex, months, wab):
|
| 2140 |
"""Predict treatment outcome for a new patient"""
|
| 2141 |
-
# Check if models have been trained
|
| 2142 |
-
if not app_state['vae_trained'] or not app_state['rf_trained']:
|
| 2143 |
-
error_message = "Error: You must train both the VAE (Tab 1) and Random Forest (Tab 2) models first!"
|
| 2144 |
-
error_fig = plt.figure(figsize=(10, 6))
|
| 2145 |
-
plt.text(0.5, 0.5, error_message,
|
| 2146 |
-
horizontalalignment='center', verticalalignment='center',
|
| 2147 |
-
fontsize=14, color='red')
|
| 2148 |
-
plt.axis('off')
|
| 2149 |
-
|
| 2150 |
-
return [error_message, error_fig]
|
| 2151 |
-
|
| 2152 |
-
# Use the trained models from app_state for prediction
|
| 2153 |
try:
|
| 2154 |
-
#
|
| 2155 |
-
|
| 2156 |
-
|
| 2157 |
|
| 2158 |
-
# Create
|
| 2159 |
temp_app = AphasiaPredictionApp()
|
| 2160 |
-
temp_app.vae = app_state['vae']
|
| 2161 |
-
temp_app.predictor = app_state['predictor']
|
| 2162 |
-
temp_app.trained = True
|
| 2163 |
-
temp_app.latent_dim = app_state['vae'].latent_dim if hasattr(app_state['vae'], 'latent_dim') else 32
|
| 2164 |
|
| 2165 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2166 |
return temp_app.predict_treatment(
|
| 2167 |
fmri_file=fmri_file,
|
| 2168 |
age=age,
|
|
@@ -2254,4 +2630,4 @@ if __name__ == "__main__":
|
|
| 2254 |
interface.launch()
|
| 2255 |
else:
|
| 2256 |
# Running locally
|
| 2257 |
-
interface.launch
|
|
|
|
| 1799 |
info="Generate synthetic FC matrices if processing fails"
|
| 1800 |
)
|
| 1801 |
|
| 1802 |
+
# Split the training and visualization into separate buttons
|
| 1803 |
+
with gr.Row():
|
| 1804 |
+
train_vae_btn = gr.Button("Train VAE Model", variant="primary")
|
| 1805 |
+
visualize_fc_btn = gr.Button("Visualize FC Matrices", variant="secondary")
|
| 1806 |
|
| 1807 |
gr.Markdown("### VAE Training Results")
|
| 1808 |
|
| 1809 |
with gr.Row():
|
| 1810 |
+
with gr.Column(scale=2):
|
| 1811 |
+
fc_plot = gr.Plot(label="FC Matrices (Original/Reconstructed/Generated)")
|
| 1812 |
+
with gr.Column(scale=1):
|
| 1813 |
+
fc_info = gr.TextArea(label="FC Matrix Information", interactive=False)
|
| 1814 |
|
| 1815 |
with gr.Row():
|
| 1816 |
learning_plot = gr.Plot(label="VAE Learning Curves")
|
|
|
|
| 1956 |
app_state['vae'] = results.get('vae', None)
|
| 1957 |
app_state['latents'] = results.get('latents', None)
|
| 1958 |
app_state['demographics'] = results.get('demographics', None)
|
| 1959 |
+
|
| 1960 |
+
# Store FC matrices for visualization
|
| 1961 |
+
if 'X' in results:
|
| 1962 |
+
# Store original FC matrices (could be vectors or matrices)
|
| 1963 |
+
app_state['original_fc'] = results.get('X', None)
|
| 1964 |
+
|
| 1965 |
+
# Store reconstructed FC if available
|
| 1966 |
+
if app_state['vae'] is not None and app_state['latents'] is not None:
|
| 1967 |
+
# Reconstruct from latents
|
| 1968 |
+
reconstructed = app_state['vae'].decode(app_state['latents'])
|
| 1969 |
+
app_state['reconstructed_fc'] = reconstructed[0] if len(reconstructed) > 0 else None
|
| 1970 |
+
|
| 1971 |
app_state['vae_trained'] = True
|
| 1972 |
|
| 1973 |
+
# Generate FC info text
|
| 1974 |
+
if app_state['demographics'] is not None:
|
| 1975 |
+
demo_info = format_demographics_info(app_state['demographics'])
|
| 1976 |
+
else:
|
| 1977 |
+
demo_info = "No demographic information available"
|
| 1978 |
+
|
| 1979 |
+
# Return visualizations and info
|
| 1980 |
return [
|
| 1981 |
+
results.get('figures', {}).get('vae'), # FC matrix visualization
|
| 1982 |
+
demo_info, # Demographic info
|
| 1983 |
+
results.get('figures', {}).get('learning') # VAE learning curves
|
| 1984 |
]
|
| 1985 |
else:
|
| 1986 |
# Local directory case
|
|
|
|
| 1995 |
app_state['vae'] = results.get('vae', None)
|
| 1996 |
app_state['latents'] = results.get('latents', None)
|
| 1997 |
app_state['demographics'] = results.get('demographics', None)
|
| 1998 |
+
|
| 1999 |
+
# Store FC matrices for visualization
|
| 2000 |
+
if 'X' in results:
|
| 2001 |
+
# Store original FC matrices (could be vectors or matrices)
|
| 2002 |
+
app_state['original_fc'] = results.get('X', None)
|
| 2003 |
+
|
| 2004 |
+
# Store reconstructed FC if available
|
| 2005 |
+
if app_state['vae'] is not None and app_state['latents'] is not None:
|
| 2006 |
+
# Reconstruct from latents
|
| 2007 |
+
reconstructed = app_state['vae'].decode(app_state['latents'])
|
| 2008 |
+
app_state['reconstructed_fc'] = reconstructed[0] if len(reconstructed) > 0 else None
|
| 2009 |
+
|
| 2010 |
app_state['vae_trained'] = True
|
| 2011 |
|
| 2012 |
+
# Generate FC info text
|
| 2013 |
+
if app_state['demographics'] is not None:
|
| 2014 |
+
demo_info = format_demographics_info(app_state['demographics'])
|
| 2015 |
+
else:
|
| 2016 |
+
demo_info = "No demographic information available"
|
| 2017 |
+
|
| 2018 |
+
# Return visualizations and info
|
| 2019 |
return [
|
| 2020 |
+
results.get('figures', {}).get('vae'), # FC matrix visualization
|
| 2021 |
+
demo_info, # Demographic info
|
| 2022 |
+
results.get('figures', {}).get('learning') # VAE learning curves
|
| 2023 |
]
|
| 2024 |
except Exception as e:
|
| 2025 |
logger.error(f"Error in VAE training: {str(e)}", exc_info=True)
|
|
|
|
| 2029 |
fontsize=12, color='red', wrap=True)
|
| 2030 |
plt.axis('off')
|
| 2031 |
|
| 2032 |
+
# Return error figures and text for all outputs
|
| 2033 |
+
return [error_fig, f"Error in VAE training: {str(e)}", error_fig]
|
| 2034 |
+
|
| 2035 |
+
# Helper function to format demographics info
|
| 2036 |
+
def format_demographics_info(demographics):
    """Build a plain-text summary of a demographics table for display.

    Args:
        demographics: A pandas DataFrame, or None. The columns 'age',
            'gender', 'mpo', 'wab_aq' and 'education' are summarised when
            present; missing columns are silently skipped.

    Returns:
        str: The multi-line summary. A fixed notice is returned when
        ``demographics`` is None or is not a DataFrame, and an error
        message is returned if formatting fails unexpectedly.
    """
    if demographics is None:
        return "No demographic information available"

    try:
        if not isinstance(demographics, pd.DataFrame):
            return "Demographics available but in unsupported format"

        def numeric_line(column, template):
            """Return the formatted mean/min/max line for *column*, or ''."""
            if column not in demographics.columns:
                return ""
            # Coerce per column so that one malformed column (e.g. free-text
            # entries) cannot abort the whole summary; unparseable cells
            # become NaN and are ignored.
            values = pd.to_numeric(demographics[column], errors='coerce').dropna()
            if values.empty:
                # Nothing usable: omit the line instead of printing "nan".
                return ""
            return template.format(
                avg=values.mean(), lo=values.min(), hi=values.max()
            )

        parts = ["FC Matrix Demographics Summary:\n\n"]

        # Age stats
        parts.append(numeric_line(
            'age', "Age: {avg:.1f} years (range: {lo:.0f}-{hi:.0f})\n"))

        # Gender stats (only the codes 'M' and 'F' are counted)
        if 'gender' in demographics.columns:
            male_count = (demographics['gender'] == 'M').sum()
            female_count = (demographics['gender'] == 'F').sum()
            parts.append(f"Gender: {male_count} males, {female_count} females\n")

        # MPO stats
        parts.append(numeric_line(
            'mpo', "Months post onset: {avg:.1f} (range: {lo:.0f}-{hi:.0f})\n"))

        # WAB stats
        parts.append(numeric_line(
            'wab_aq', "WAB scores: {avg:.1f} (range: {lo:.1f}-{hi:.1f})\n"))

        # Education stats
        parts.append(numeric_line(
            'education', "Education: {avg:.1f} years (range: {lo:.0f}-{hi:.0f})\n"))

        # Sample size
        parts.append(f"\nTotal subjects: {len(demographics)}")

        return "".join(parts)
    except Exception as e:
        # Display helper: report the failure as text rather than propagate
        # it into the UI callback that asked for the summary.
        logger.error(f"Error formatting demographics: {e}")
        return f"Error formatting demographics: {e}"
|
| 2089 |
+
|
| 2090 |
+
# Function to visualize FC matrices independently
|
| 2091 |
+
def _random_symmetric_fc(n):
    """Return an n x n symmetric matrix of uniform random values with a unit diagonal."""
    import numpy as np

    matrix = np.random.rand(n, n) * 2 - 1
    matrix = (matrix + matrix.T) / 2  # Make symmetric
    np.fill_diagonal(matrix, 1.0)  # Set diagonal to 1
    return matrix


def _noisy_symmetric_copy(matrix):
    """Return a damped, noise-perturbed symmetric copy of *matrix* with a unit diagonal."""
    import numpy as np

    n = matrix.shape[0]
    noisy = matrix * 0.8 + np.random.randn(n, n) * 0.1
    noisy = (noisy + noisy.T) / 2  # Make symmetric
    np.fill_diagonal(noisy, 1.0)  # Set diagonal to 1
    return noisy


def _first_sample_if_batched(original, reconstructed):
    """Return original[0] when *original* is a batch whose items match *reconstructed*'s shape."""
    import numpy as np

    original_shape = tuple(np.shape(original))
    reconstructed_shape = tuple(np.shape(reconstructed))
    if (len(original_shape) == len(reconstructed_shape) + 1
            and original_shape[1:] == reconstructed_shape):
        return original[0]
    return original


def handle_fc_visualization():
    """Generate the FC visualization from stored data or synthetic data.

    When a trained VAE is present in ``app_state`` the stored original and
    reconstructed FC data are plotted alongside a freshly decoded sample;
    otherwise three synthetic 264x264 matrices are plotted for demonstration.

    Returns:
        list: ``[figure, info_text]``. On any failure the figure is an error
        placeholder and the text carries the error message.
    """
    # Import numpy at the top of the function. A function-local
    # `import numpy as np` placed in only one branch makes `np` a local name
    # for the whole function, so the other branch fails with
    # UnboundLocalError on its first use of `np`.
    import numpy as np

    try:
        from visualization import plot_fc_matrices

        vae = app_state.get('vae')

        # Check if we have trained VAE and data
        if app_state.get('vae_trained', False) and vae is not None:
            logger.info("Visualizing FC matrices from trained VAE")

            original = app_state.get('original_fc')
            reconstructed = app_state.get('reconstructed_fc')

            if original is not None and reconstructed is not None:
                # The stored original may hold every sample while the stored
                # reconstruction holds a single one; compare like with like.
                original = _first_sample_if_batched(original, reconstructed)
            elif app_state.get('latents') is not None:
                # Rebuild a reconstruction from the first stored latent
                latent = app_state['latents'][0].reshape(1, -1)
                reconstructed = vae.generate(latent)[0]
                # Use synthetic original (not ideal but a fallback)
                original = reconstructed * 0.9 + np.random.randn(*reconstructed.shape) * 0.1
            else:
                # Complete fallback - create synthetic data
                original = _random_symmetric_fc(264)
                reconstructed = _noisy_symmetric_copy(original)

            # Generate a new FC matrix by decoding a sample from the prior.
            # `vae` is known to be non-None in this branch, so no synthetic
            # fallback is needed here.
            z = np.random.randn(1, vae.latent_dim)
            generated = vae.decode(z)[0]

            # Create visualization
            fig = plot_fc_matrices(original, reconstructed, generated)

            # Generate info text
            if app_state.get('demographics') is not None:
                demo_info = format_demographics_info(app_state['demographics'])
            else:
                demo_info = "No demographic information available"

            # Add FC matrix stats
            demo_info += "\n\nFC Matrix Information:\n"
            demo_info += f"Matrix shape: {original.shape}\n"
            demo_info += f"Original FC range: [{np.min(original):.3f}, {np.max(original):.3f}]\n"
            demo_info += f"Reconstructed FC range: [{np.min(reconstructed):.3f}, {np.max(reconstructed):.3f}]\n"
            demo_info += f"Generated FC range: [{np.min(generated):.3f}, {np.max(generated):.3f}]\n"

            # Calculate metrics between original and reconstructed, but only
            # when they have the same number of elements; otherwise the
            # metric functions raise and the whole visualization is lost.
            flat_original = original.flatten()
            flat_reconstructed = reconstructed.flatten()
            if len(flat_original) == len(flat_reconstructed):
                from sklearn.metrics import mean_squared_error, r2_score
                mse = mean_squared_error(flat_original, flat_reconstructed)
                r2 = r2_score(flat_original, flat_reconstructed)
                demo_info += "\nReconstruction Metrics:\n"
                demo_info += f"MSE: {mse:.4f}\n"
                demo_info += f"R²: {r2:.4f}\n"
            else:
                demo_info += (
                    "\nReconstruction Metrics:\n"
                    "Unavailable (original and reconstructed sizes differ)\n"
                )

            return [fig, demo_info]

        # No trained VAE: create synthetic data visualization
        logger.info("Creating synthetic FC visualization")

        n = 264  # Standard size for brain connectivity

        original = _random_symmetric_fc(n)
        reconstructed = _noisy_symmetric_copy(original)
        generated = _random_symmetric_fc(n)

        # Create visualization
        fig = plot_fc_matrices(original, reconstructed, generated)

        # Generate info text for synthetic data
        demo_info = "Using synthetic FC data for demonstration.\n"
        demo_info += "Train the VAE model to see real FC matrices.\n\n"
        demo_info += "Synthetic FC Matrix Information:\n"
        demo_info += f"Matrix shape: {original.shape}\n"
        demo_info += f"Value range: [{-1:.1f}, {1:.1f}]\n"
        demo_info += "Symmetric matrices with diagonal=1\n"

        return [fig, demo_info]
    except Exception as e:
        # UI callback boundary: show the failure in both outputs instead of
        # letting the exception escape.
        logger.error(f"Error in FC visualization: {str(e)}", exc_info=True)
        error_fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}",
                 horizontalalignment='center', verticalalignment='center',
                 fontsize=12, color='red', wrap=True)
        plt.axis('off')

        return [error_fig, f"Error in FC visualization: {str(e)}"]
|
| 2207 |
|
| 2208 |
# Tab 2: Random Forest Training Handler
|
| 2209 |
def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
|
| 2210 |
"""Train the Random Forest model using the VAE latent representations"""
|
| 2211 |
+
# Check if VAE has been trained or if we can use synthetic data
|
| 2212 |
+
if not app_state.get('vae_trained', False) or app_state.get('latents') is None:
|
| 2213 |
+
# Instead of error, create synthetic data for demonstration
|
| 2214 |
+
logger.info("No VAE latents available - using synthetic data for RF training")
|
| 2215 |
+
|
| 2216 |
+
# Number of synthetic samples
|
| 2217 |
+
n_samples = 30
|
| 2218 |
+
|
| 2219 |
+
# Create synthetic latent features (10 dimensions)
|
| 2220 |
+
np.random.seed(42) # For reproducibility
|
| 2221 |
+
latents = np.random.randn(n_samples, 10)
|
| 2222 |
+
|
| 2223 |
+
# Create synthetic demographics
|
| 2224 |
+
demographics = pd.DataFrame({
|
| 2225 |
+
'age': np.random.randint(40, 80, n_samples),
|
| 2226 |
+
'gender': np.random.choice(['M', 'F'], n_samples),
|
| 2227 |
+
'mpo': np.random.randint(1, 24, n_samples),
|
| 2228 |
+
'education': np.random.randint(8, 20, n_samples),
|
| 2229 |
+
'wab_aq': np.random.uniform(20, 80, n_samples)
|
| 2230 |
+
})
|
| 2231 |
+
|
| 2232 |
+
# Create synthetic treatment outcomes with correlation to features
|
| 2233 |
+
# Higher age -> worse outcomes, higher education -> better outcomes
|
| 2234 |
+
treatment_outcomes = (
|
| 2235 |
+
-0.3 * demographics['age'] +
|
| 2236 |
+
0.4 * demographics['education'] +
|
| 2237 |
+
0.6 * demographics['wab_aq'] +
|
| 2238 |
+
2.0 * latents[:, 0] -
|
| 2239 |
+
1.5 * latents[:, 1] +
|
| 2240 |
+
np.random.randn(n_samples) * 5
|
| 2241 |
+
)
|
| 2242 |
+
|
| 2243 |
+
# Scale to realistic range (0-100)
|
| 2244 |
+
treatment_outcomes = (treatment_outcomes - treatment_outcomes.min()) / (treatment_outcomes.max() - treatment_outcomes.min()) * 80 + 10
|
| 2245 |
|
| 2246 |
+
# Store in app_state
|
| 2247 |
+
app_state['latents'] = latents
|
| 2248 |
+
app_state['demographics'] = demographics
|
| 2249 |
+
app_state['synthetic_data'] = True
|
| 2250 |
+
|
| 2251 |
+
# Inform the user we're using synthetic data
|
| 2252 |
+
logger.info("Created synthetic data for RF training demonstration")
|
| 2253 |
+
info_msg = "Using synthetic data for demonstration. For real analysis, train the VAE in Tab 1 first."
|
| 2254 |
+
else:
|
| 2255 |
+
# Normal case - using real VAE latents
|
| 2256 |
+
app_state['synthetic_data'] = False
|
| 2257 |
+
info_msg = "Using VAE latents for Random Forest training."
|
| 2258 |
|
| 2259 |
try:
|
| 2260 |
# Update RF configuration
|
|
|
|
| 2276 |
import pandas as pd
|
| 2277 |
import numpy as np
|
| 2278 |
|
| 2279 |
+
# Get treatment outcomes data
|
| 2280 |
+
# Check if we already created synthetic data
|
| 2281 |
+
if app_state.get('synthetic_data', False):
|
| 2282 |
+
# Use the synthetic treatment outcomes we created above
|
| 2283 |
+
# (available in this scope from the if block above)
|
| 2284 |
+
logger.info("Using synthetic treatment outcomes")
|
| 2285 |
+
# treatment_outcomes is already defined above
|
| 2286 |
+
# Or try to find real treatment file
|
| 2287 |
+
elif hasattr(app, 'last_treatment_file') and os.path.exists(app.last_treatment_file):
|
| 2288 |
treatment_file = app.last_treatment_file
|
| 2289 |
treatment_df = pd.read_csv(treatment_file)
|
| 2290 |
treatment_outcomes = treatment_df['outcome_score'].values
|
| 2291 |
+
logger.info(f"Using treatment outcomes from {treatment_file}")
|
| 2292 |
+
else:
|
| 2293 |
+
# Create a fallback set of treatment outcomes
|
| 2294 |
+
logger.info("No treatment outcomes found - creating mock data")
|
| 2295 |
+
n_samples = len(app_state['latents'])
|
| 2296 |
|
| 2297 |
+
# Create simple mock outcomes based on demographics (if available)
|
| 2298 |
+
if app_state.get('demographics') is not None and 'wab_aq' in app_state['demographics']:
|
| 2299 |
+
# Base it on improvement from current scores
|
| 2300 |
+
base_scores = app_state['demographics']['wab_aq'].values
|
| 2301 |
+
# Add 10-30 points of improvement
|
| 2302 |
+
improvements = np.random.uniform(10, 30, n_samples)
|
| 2303 |
+
treatment_outcomes = np.minimum(base_scores + improvements, 100)
|
| 2304 |
+
else:
|
| 2305 |
+
# Complete fallback - just random scores
|
| 2306 |
+
treatment_outcomes = np.random.uniform(30, 90, n_samples)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2307 |
|
| 2308 |
+
logger.info(f"Created {n_samples} mock treatment outcomes")
|
| 2309 |
+
|
| 2310 |
+
# Initialize predictor
|
| 2311 |
+
predictor = AphasiaTreatmentPredictor(
|
| 2312 |
+
n_estimators=rf_n_estimators,
|
| 2313 |
+
max_depth=rf_max_depth if rf_max_depth > 0 else None
|
| 2314 |
+
)
|
| 2315 |
+
|
| 2316 |
+
# Cross-validate
|
| 2317 |
+
cv_results = predictor.cross_validate(
|
| 2318 |
+
latents=latents,
|
| 2319 |
+
demographics=demographics,
|
| 2320 |
+
treatment_outcomes=treatment_outcomes,
|
| 2321 |
+
n_splits=rf_cv_folds
|
| 2322 |
+
)
|
| 2323 |
+
|
| 2324 |
+
# Fit final model
|
| 2325 |
+
predictor.fit(latents, demographics, treatment_outcomes)
|
| 2326 |
+
|
| 2327 |
+
# Store in app_state
|
| 2328 |
+
app_state['predictor'] = predictor
|
| 2329 |
+
app_state['rf_trained'] = True
|
| 2330 |
+
|
| 2331 |
+
# Create feature importance plot
|
| 2332 |
+
importance_fig = predictor.plot_feature_importance()
|
| 2333 |
+
|
| 2334 |
+
# Create prediction performance plot
|
| 2335 |
+
predictions = cv_results['predictions']
|
| 2336 |
+
prediction_stds = cv_results['prediction_stds']
|
| 2337 |
+
|
| 2338 |
+
performance_fig = plt.figure(figsize=(8, 6))
|
| 2339 |
+
|
| 2340 |
+
# Check if we have valid predictions
|
| 2341 |
+
if len(treatment_outcomes) > 0 and len(predictions) == len(treatment_outcomes):
|
| 2342 |
+
# Only create scatter plot if we have matching data
|
| 2343 |
+
plt.scatter(treatment_outcomes, predictions)
|
| 2344 |
|
| 2345 |
+
# Reference line
|
| 2346 |
+
min_val = min(np.min(treatment_outcomes), np.min(predictions))
|
| 2347 |
+
max_val = max(np.max(treatment_outcomes), np.max(predictions))
|
| 2348 |
+
plt.plot([min_val, max_val], [min_val, max_val], 'r--')
|
| 2349 |
|
| 2350 |
+
# Confidence band
|
| 2351 |
+
plt.fill_between(treatment_outcomes,
|
| 2352 |
+
predictions - 2*prediction_stds,
|
| 2353 |
+
predictions + 2*prediction_stds,
|
| 2354 |
+
alpha=0.2, color='gray')
|
| 2355 |
|
| 2356 |
+
plt.xlabel('Actual Outcome')
|
| 2357 |
+
plt.ylabel('Predicted Outcome')
|
| 2358 |
|
| 2359 |
+
# Get performance metrics
|
| 2360 |
+
metrics_text = ""
|
| 2361 |
+
mean_metrics = cv_results.get('mean_metrics', {})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2362 |
|
| 2363 |
+
r2 = mean_metrics.get('r2', 0)
|
| 2364 |
+
rmse = mean_metrics.get('rmse', 0)
|
| 2365 |
+
plt.title(f'Treatment Outcome Prediction\nR² = {r2:.3f}, RMSE = {rmse:.3f}')
|
| 2366 |
+
metrics_text = f"Regression Model Performance:\nR² = {r2:.4f}\nRMSE = {rmse:.4f}"
|
| 2367 |
else:
|
| 2368 |
+
# Handle case with no data
|
| 2369 |
+
plt.text(0.5, 0.5, "No prediction data available",
|
| 2370 |
+
ha='center', va='center', transform=plt.gca().transAxes)
|
| 2371 |
+
metrics_text = "No performance metrics available"
|
|
|
|
|
|
|
|
|
|
| 2372 |
|
| 2373 |
+
plt.tight_layout()
|
| 2374 |
+
|
| 2375 |
+
# Add notice if using synthetic data
|
| 2376 |
+
if app_state.get('synthetic_data', False):
|
| 2377 |
+
metrics_text = f"{metrics_text}\n\nNOTE: Using synthetic data for demonstration."
|
| 2378 |
+
|
| 2379 |
+
return [importance_fig, performance_fig, metrics_text]
|
| 2380 |
+
else:
|
| 2381 |
+
# No treatment file available
|
| 2382 |
+
error_fig = plt.figure(figsize=(10, 6))
|
| 2383 |
+
message = "Error: Treatment outcomes file not found. Please retrain the VAE in Tab 1."
|
| 2384 |
+
plt.text(0.5, 0.5, message,
|
| 2385 |
+
horizontalalignment='center', verticalalignment='center',
|
| 2386 |
+
fontsize=14, color='red')
|
| 2387 |
+
plt.axis('off')
|
| 2388 |
|
| 2389 |
+
return [error_fig, error_fig, "Error: Treatment outcomes file not found."]
|
| 2390 |
+
|
| 2391 |
except Exception as e:
|
| 2392 |
logger.error(f"Error in RF training: {str(e)}", exc_info=True)
|
| 2393 |
error_fig = plt.figure(figsize=(10, 6))
|
|
|
|
| 2405 |
fn=handle_vae_training,
|
| 2406 |
inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
|
| 2407 |
skip_behavioral, use_synthetic_nifti, use_synthetic_fc],
|
| 2408 |
+
outputs=[fc_plot, fc_info, learning_plot]
|
| 2409 |
+
)
|
| 2410 |
+
|
| 2411 |
+
# FC Visualization button
|
| 2412 |
+
visualize_fc_btn.click(
|
| 2413 |
+
fn=handle_fc_visualization,
|
| 2414 |
+
inputs=[],
|
| 2415 |
+
outputs=[fc_plot, fc_info]
|
| 2416 |
)
|
| 2417 |
|
| 2418 |
# Random Forest Training tab
|
|
|
|
| 2425 |
# Tab 3: Treatment Prediction Handler
|
| 2426 |
def handle_treatment_prediction(fmri_file, age, sex, months, wab):
|
| 2427 |
"""Predict treatment outcome for a new patient"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2428 |
try:
|
| 2429 |
+
# First, check if we have saved models we can use
|
| 2430 |
+
rf_model_path = "results/treatment_predictor.joblib"
|
| 2431 |
+
rf_available = os.path.exists(rf_model_path)
|
| 2432 |
|
| 2433 |
+
# Create prediction app
|
| 2434 |
temp_app = AphasiaPredictionApp()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2435 |
|
| 2436 |
+
# If there are trained models in app_state, use them
|
| 2437 |
+
if app_state.get('vae_trained', False) and app_state.get('rf_trained', False) and app_state.get('vae') is not None and app_state.get('predictor') is not None:
|
| 2438 |
+
logger.info("Using trained models from current session for prediction")
|
| 2439 |
+
temp_app.vae = app_state.get('vae')
|
| 2440 |
+
temp_app.predictor = app_state.get('predictor')
|
| 2441 |
+
temp_app.trained = True
|
| 2442 |
+
temp_app.latent_dim = app_state.get('vae').latent_dim if hasattr(app_state.get('vae'), 'latent_dim') else 32
|
| 2443 |
+
|
| 2444 |
+
# If we don't have trained models, but saved models exist, load them
|
| 2445 |
+
elif rf_available:
|
| 2446 |
+
logger.info("Loading saved RF model for prediction")
|
| 2447 |
+
try:
|
| 2448 |
+
# Try to load the RF model from disk
|
| 2449 |
+
from rcf_prediction import AphasiaTreatmentPredictor
|
| 2450 |
+
temp_app.predictor = AphasiaTreatmentPredictor.load_model(rf_model_path)
|
| 2451 |
+
temp_app.trained = True
|
| 2452 |
+
|
| 2453 |
+
# Use the VAE from app_state if available, otherwise use synthetic FC
|
| 2454 |
+
if app_state.get('vae') is not None:
|
| 2455 |
+
temp_app.vae = app_state.get('vae')
|
| 2456 |
+
temp_app.latent_dim = temp_app.vae.latent_dim if hasattr(temp_app.vae, 'latent_dim') else 32
|
| 2457 |
+
else:
|
| 2458 |
+
# Create a synthetic FC matrix based on demographics
|
| 2459 |
+
logger.info("No VAE available - using synthetic FC data")
|
| 2460 |
+
from visualization import plot_treatment_trajectory
|
| 2461 |
+
|
| 2462 |
+
# Generate synthetic prediction
|
| 2463 |
+
current_score = wab
|
| 2464 |
+
|
| 2465 |
+
# Calculate predicted score based on demographics (simplified model)
|
| 2466 |
+
age_factor = -0.1 * (age - 60) # Age effect (younger is better)
|
| 2467 |
+
time_factor = 0.7 * months # More treatment time is better
|
| 2468 |
+
gender_factor = 2 if sex == "F" else 0 # Small gender effect
|
| 2469 |
+
|
| 2470 |
+
# Base improvement of 15 points, modified by factors
|
| 2471 |
+
improvement = 15 + age_factor + time_factor + gender_factor
|
| 2472 |
+
# Add some randomness
|
| 2473 |
+
improvement = max(5, min(30, improvement + np.random.normal(0, 3)))
|
| 2474 |
+
|
| 2475 |
+
predicted_score = min(100, current_score + improvement)
|
| 2476 |
+
prediction_std = 5.0 # Fixed uncertainty for demo
|
| 2477 |
+
|
| 2478 |
+
# Create a trajectory plot
|
| 2479 |
+
fig = plot_treatment_trajectory(
|
| 2480 |
+
current_score=current_score,
|
| 2481 |
+
predicted_score=predicted_score,
|
| 2482 |
+
months_post_stroke=months,
|
| 2483 |
+
prediction_std=prediction_std
|
| 2484 |
+
)
|
| 2485 |
+
|
| 2486 |
+
# Create prediction text
|
| 2487 |
+
prediction_text = (
|
| 2488 |
+
f"Using simplified model (VAE not trained)\n\n"
|
| 2489 |
+
f"Current WAB-AQ: {current_score:.1f}\n"
|
| 2490 |
+
f"Predicted WAB-AQ after {months} months: {predicted_score:.1f} ± {1.96*prediction_std:.1f}\n"
|
| 2491 |
+
f"Expected improvement: {predicted_score - current_score:.1f} points\n\n"
|
| 2492 |
+
f"Note: This prediction uses a simplified model.\n"
|
| 2493 |
+
f"Train the VAE for more accurate predictions."
|
| 2494 |
+
)
|
| 2495 |
+
|
| 2496 |
+
return [prediction_text, fig]
|
| 2497 |
+
except Exception as load_err:
|
| 2498 |
+
logger.error(f"Error loading models: {load_err}")
|
| 2499 |
+
return [f"Error loading models: {load_err}", None]
|
| 2500 |
+
else:
|
| 2501 |
+
# If no models are available, generate a demo visualization
|
| 2502 |
+
logger.info("No models available - creating demonstration visualization")
|
| 2503 |
+
from visualization import plot_treatment_trajectory
|
| 2504 |
+
|
| 2505 |
+
# Generate synthetic prediction with realistic values
|
| 2506 |
+
current_score = wab
|
| 2507 |
+
|
| 2508 |
+
# Calculate predicted score based on demographics (simplified model)
|
| 2509 |
+
age_factor = -0.1 * (age - 60) # Age effect (younger is better)
|
| 2510 |
+
time_factor = 0.7 * months # More treatment time is better
|
| 2511 |
+
gender_factor = 2 if sex == "F" else 0 # Small gender effect
|
| 2512 |
+
|
| 2513 |
+
# Base improvement of 15 points, modified by factors
|
| 2514 |
+
improvement = 15 + age_factor + time_factor + gender_factor
|
| 2515 |
+
# Add some randomness
|
| 2516 |
+
improvement = max(5, min(30, improvement + np.random.normal(0, 3)))
|
| 2517 |
+
|
| 2518 |
+
predicted_score = min(100, current_score + improvement)
|
| 2519 |
+
prediction_std = 5.0 # Fixed uncertainty for demo
|
| 2520 |
+
|
| 2521 |
+
# Create a demo trajectory plot
|
| 2522 |
+
fig = plot_treatment_trajectory(
|
| 2523 |
+
current_score=current_score,
|
| 2524 |
+
predicted_score=predicted_score,
|
| 2525 |
+
months_post_stroke=months,
|
| 2526 |
+
prediction_std=prediction_std
|
| 2527 |
+
)
|
| 2528 |
+
|
| 2529 |
+
# Create prediction text
|
| 2530 |
+
prediction_text = (
|
| 2531 |
+
f"DEMO MODE - No trained models available\n\n"
|
| 2532 |
+
f"Current WAB-AQ: {current_score:.1f}\n"
|
| 2533 |
+
f"Predicted WAB-AQ after {months} months: {predicted_score:.1f} ± {1.96*prediction_std:.1f}\n"
|
| 2534 |
+
f"Expected improvement: {predicted_score - current_score:.1f} points\n\n"
|
| 2535 |
+
f"Note: This is a demonstration using synthetic data.\n"
|
| 2536 |
+
f"Train the VAE and RF models for actual predictions."
|
| 2537 |
+
)
|
| 2538 |
+
|
| 2539 |
+
return [prediction_text, fig]
|
| 2540 |
+
|
| 2541 |
+
# Make prediction using the available models
|
| 2542 |
return temp_app.predict_treatment(
|
| 2543 |
fmri_file=fmri_file,
|
| 2544 |
age=age,
|
|
|
|
| 2630 |
interface.launch()
|
| 2631 |
else:
|
| 2632 |
# Running locally
|
| 2633 |
+
interface.launch(share=True)
|
data_preprocessing.py
CHANGED
|
@@ -1,29 +1,157 @@
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
import os
|
| 4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
from nilearn.image import load_img
|
| 6 |
import nibabel as nib
|
| 7 |
from pathlib import Path
|
| 8 |
from config import PREPROCESS_CONFIG, PREDICTION_CONFIG
|
| 9 |
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"""
|
| 12 |
Process a single fMRI file to FC matrix
|
| 13 |
|
| 14 |
Args:
|
| 15 |
fmri_file: Path to the fMRI .nii or .nii.gz file
|
| 16 |
allow_synthetic: If True, generate synthetic FC matrix on error (disabled by default)
|
|
|
|
| 17 |
|
| 18 |
Returns:
|
| 19 |
fc_triu: Upper triangular FC matrix values
|
| 20 |
"""
|
| 21 |
print(f"Processing fMRI file: {fmri_file}")
|
| 22 |
|
| 23 |
-
#
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 27 |
|
| 28 |
# Try different atlas radiuses if the default one has issues
|
| 29 |
radius_options = [PREPROCESS_CONFIG['radius'], 8, 10, 5, 12]
|
|
@@ -32,56 +160,107 @@ def process_single_fmri(fmri_file, allow_synthetic=False):
|
|
| 32 |
for radius in radius_options:
|
| 33 |
try:
|
| 34 |
print(f"Trying with radius {radius}mm...")
|
| 35 |
-
# Create masker with allow_empty=True to handle empty spheres
|
| 36 |
-
masker = input_data.NiftiSpheresMasker(
|
| 37 |
-
coords,
|
| 38 |
-
radius=radius,
|
| 39 |
-
standardize=True,
|
| 40 |
-
memory='nilearn_cache',
|
| 41 |
-
memory_level=1,
|
| 42 |
-
verbose=0,
|
| 43 |
-
detrend=True,
|
| 44 |
-
low_pass=PREPROCESS_CONFIG['low_pass'],
|
| 45 |
-
high_pass=PREPROCESS_CONFIG['high_pass'],
|
| 46 |
-
t_r=PREPROCESS_CONFIG['t_r'],
|
| 47 |
-
allow_empty=True # Allow empty spheres
|
| 48 |
-
)
|
| 49 |
-
|
| 50 |
-
# Load and process fMRI
|
| 51 |
-
print(f"Loading NIfTI file...")
|
| 52 |
-
fmri_img = load_img(fmri_file)
|
| 53 |
-
print(f"NIfTI file loaded, shape: {fmri_img.shape}")
|
| 54 |
|
| 55 |
-
# Check
|
| 56 |
-
if
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 77 |
|
| 78 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
continue
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
|
| 84 |
-
print(f"Time series
|
| 85 |
|
| 86 |
# Validate time series data
|
| 87 |
if np.isnan(time_series).any() or np.isinf(time_series).any():
|
|
@@ -224,6 +403,38 @@ def preprocess_fmri_to_fc(nii_files, demo_data, demo_types):
|
|
| 224 |
|
| 225 |
return X, demo_data, demo_types
|
| 226 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 227 |
def load_and_preprocess_data(data_dir, demographic_file, use_hf_dataset=False,
|
| 228 |
hf_nii_files=None, hf_demo_data=None, hf_demo_types=None):
|
| 229 |
"""
|
|
|
|
| 1 |
import numpy as np
|
| 2 |
import pandas as pd
|
| 3 |
import os
|
| 4 |
+
import json
|
| 5 |
+
import pickle
|
| 6 |
+
import hashlib
|
| 7 |
+
import warnings
|
| 8 |
+
import re
|
| 9 |
+
from nilearn import input_data, connectome, datasets
|
| 10 |
from nilearn.image import load_img
|
| 11 |
import nibabel as nib
|
| 12 |
from pathlib import Path
|
| 13 |
from config import PREPROCESS_CONFIG, PREDICTION_CONFIG
|
| 14 |
|
| 15 |
+
# Root of the on-disk cache; it lives in a "cache" folder next to this module.
CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'cache')

# Ensure the cache root exists, then one sub-directory per kind of cached
# artefact. exist_ok=True makes repeated imports harmless.
os.makedirs(CACHE_DIR, exist_ok=True)
for _cache_subdir in ('time_series', 'fc_matrices', 'latents', 'maskers', 'atlas'):
    os.makedirs(os.path.join(CACHE_DIR, _cache_subdir), exist_ok=True)
del _cache_subdir  # do not leave the loop variable in the module namespace
|
| 23 |
+
|
| 24 |
+
def get_file_hash(file_path, chunk_size=1024 * 1024):
    """
    Generate an MD5 hex digest for a file, for use as a cache key.

    The digest is computed over the file contents. If the file cannot be
    opened or read, the digest of the file's base name is returned instead so
    that callers still receive a usable key.

    Args:
        file_path: Path to the file to hash
        chunk_size: Number of bytes read per iteration (default 1 MiB).
            Large chunks keep the number of read calls low for big images.

    Returns:
        Hex digest string (32 characters)

    Raises:
        ValueError: If chunk_size is smaller than 1
    """
    if chunk_size < 1:
        # read(0) would end the loop immediately and read(-1) would slurp the
        # whole file, so reject both rather than silently misbehave.
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    hasher = hashlib.md5()
    try:
        with open(file_path, 'rb') as f:
            # Stream the file so large files never have to fit in memory
            for chunk in iter(lambda: f.read(chunk_size), b""):
                hasher.update(chunk)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable file; ValueError: malformed path string
        print(f"Error hashing file {file_path}: {e}")
        # Fallback to filename-based hash if file reading fails.
        # NOTE: only the base name is hashed, so two unreadable files with the
        # same name in different directories map to the same key.
        return hashlib.md5(os.path.basename(file_path).encode()).hexdigest()

    return hasher.hexdigest()
|
| 37 |
+
|
| 38 |
+
def _fetch_power_2011_coords():
    """Fetch the Power 2011 ROI coordinates from nilearn, one (x, y, z) row per ROI."""
    from nilearn import datasets
    power = datasets.fetch_coords_power_2011()
    # Stack the x/y/z columns as rows, then transpose to one row per ROI
    return np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T


def get_cached_atlas_coords(atlas_name="power_2011", use_cache=True):
    """
    Get atlas coordinates, using cache if available

    Args:
        atlas_name: Name of the atlas (currently only power_2011 is supported).
            The name is used for the cache file name; the coordinates fetched
            are always those of power_2011.
        use_cache: Whether to use/create cache

    Returns:
        coords: Array of coordinates for the atlas

    Raises:
        Exception: Whatever nilearn raises if the atlas cannot be fetched
            (the error is printed and re-raised).
    """
    if atlas_name != "power_2011":
        # Only one atlas can actually be fetched; make the mismatch visible
        # instead of silently caching power_2011 data under another name.
        print(f"Warning: atlas '{atlas_name}' is not supported, using power_2011 coordinates")

    if not use_cache:
        # Fetch directly from nilearn
        return _fetch_power_2011_coords()

    # Check if we have cached atlas coordinates
    atlas_path = os.path.join(CACHE_DIR, 'atlas', f"{atlas_name}_coords.npy")

    if os.path.exists(atlas_path):
        try:
            print(f"Loading cached atlas coordinates for {atlas_name}")
            coords = np.load(atlas_path)
        except Exception as e:
            print(f"Error loading cached atlas: {e}, fetching new one")
        else:
            # A truncated or foreign file can still load; only trust an array
            # that looks like a non-empty list of 3-D coordinates.
            if coords.ndim == 2 and coords.shape[1] == 3 and coords.shape[0] > 0:
                print(f"Successfully loaded cached atlas with {len(coords)} ROIs")
                return coords
            print(f"Cached atlas has unexpected shape {coords.shape}, fetching new one")

    # No valid cache, fetch from nilearn
    try:
        print(f"Fetching {atlas_name} atlas from nilearn")
        coords = _fetch_power_2011_coords()
    except Exception as e:
        print(f"Error fetching atlas: {e}")
        raise

    # Cache the coordinates; a failed write must not fail the caller
    try:
        np.save(atlas_path, coords)
        print(f"Saved atlas coordinates to cache: {atlas_path}")
    except Exception as e:
        print(f"Error saving atlas to cache: {e}")

    return coords
|
| 86 |
+
|
| 87 |
+
def get_cached_masker(radius, use_cache=True):
    """
    Look up a previously pickled NiftiSpheresMasker for the given radius.

    The cache file name is the MD5 digest of the radius together with the
    t_r, high_pass and low_pass entries of PREPROCESS_CONFIG.

    Args:
        radius: Sphere radius in mm
        use_cache: When False no lookup is attempted

    Returns:
        masker: The unpickled masker object, or None when caching is disabled,
            no cache file exists, or the cache file cannot be loaded. None
            tells the caller to build a new masker.
    """
    if not use_cache:
        return None

    # Build the cache key from every setting that takes part in it
    settings = ",".join((
        f"radius={radius}",
        f"tr={PREPROCESS_CONFIG['t_r']}",
        f"high_pass={PREPROCESS_CONFIG['high_pass']}",
        f"low_pass={PREPROCESS_CONFIG['low_pass']}",
    ))
    digest = hashlib.md5(settings.encode()).hexdigest()
    pickle_file = os.path.join(CACHE_DIR, 'maskers', f"{digest}.pkl")

    # Nothing cached for this configuration
    if not os.path.exists(pickle_file):
        return None

    try:
        print(f"Loading cached masker for radius {radius}mm")
        # NOTE(review): pickle.load executes whatever the file contains; this
        # is only safe while the cache directory is written by this code alone.
        with open(pickle_file, 'rb') as handle:
            loaded = pickle.load(handle)
        print(f"Successfully loaded cached masker for radius {radius}mm")
    except Exception as err:
        # Unreadable or stale cache entry: report it and fall back to "no cache"
        print(f"Error loading cached masker: {err}, creating new one")
        return None

    return loaded
|
| 124 |
+
|
| 125 |
+
def process_single_fmri(fmri_file, allow_synthetic=False, use_cache=True):
|
| 126 |
"""
|
| 127 |
Process a single fMRI file to FC matrix
|
| 128 |
|
| 129 |
Args:
|
| 130 |
fmri_file: Path to the fMRI .nii or .nii.gz file
|
| 131 |
allow_synthetic: If True, generate synthetic FC matrix on error (disabled by default)
|
| 132 |
+
use_cache: If True, try to load cached data instead of reprocessing
|
| 133 |
|
| 134 |
Returns:
|
| 135 |
fc_triu: Upper triangular FC matrix values
|
| 136 |
"""
|
| 137 |
print(f"Processing fMRI file: {fmri_file}")
|
| 138 |
|
| 139 |
+
# Check if cached FC matrix exists
|
| 140 |
+
if use_cache:
|
| 141 |
+
file_hash = get_file_hash(fmri_file)
|
| 142 |
+
fc_cache_path = os.path.join(CACHE_DIR, 'fc_matrices', f"{file_hash}.npy")
|
| 143 |
+
|
| 144 |
+
if os.path.exists(fc_cache_path):
|
| 145 |
+
print(f"Loading cached FC matrix for {os.path.basename(fmri_file)}")
|
| 146 |
+
try:
|
| 147 |
+
fc_triu = np.load(fc_cache_path)
|
| 148 |
+
print(f"Successfully loaded cached FC matrix, shape: {fc_triu.shape}")
|
| 149 |
+
return fc_triu
|
| 150 |
+
except Exception as e:
|
| 151 |
+
print(f"Error loading cached FC matrix: {e}, recalculating...")
|
| 152 |
+
|
| 153 |
+
# Use Power 264 atlas with caching
|
| 154 |
+
coords = get_cached_atlas_coords(use_cache=use_cache)
|
| 155 |
|
| 156 |
# Try different atlas radiuses if the default one has issues
|
| 157 |
radius_options = [PREPROCESS_CONFIG['radius'], 8, 10, 5, 12]
|
|
|
|
| 160 |
for radius in radius_options:
|
| 161 |
try:
|
| 162 |
print(f"Trying with radius {radius}mm...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 163 |
|
| 164 |
+
# Check if we have cached time series for this file and radius
|
| 165 |
+
if use_cache:
|
| 166 |
+
ts_cache_key = f"{file_hash}_r{radius}"
|
| 167 |
+
ts_cache_path = os.path.join(CACHE_DIR, 'time_series', f"{ts_cache_key}.npy")
|
| 168 |
+
|
| 169 |
+
if os.path.exists(ts_cache_path):
|
| 170 |
+
print(f"Loading cached time series for radius {radius}mm")
|
| 171 |
+
try:
|
| 172 |
+
time_series = np.load(ts_cache_path)
|
| 173 |
+
print(f"Successfully loaded cached time series, shape: {time_series.shape}")
|
| 174 |
+
except Exception as e:
|
| 175 |
+
print(f"Error loading cached time series: {e}, recalculating...")
|
| 176 |
+
time_series = None
|
| 177 |
+
else:
|
| 178 |
+
time_series = None
|
| 179 |
+
else:
|
| 180 |
+
time_series = None
|
| 181 |
+
|
| 182 |
+
# If no cached time series, calculate it
|
| 183 |
+
if time_series is None:
|
| 184 |
+
# Try to get a cached masker first
|
| 185 |
+
masker = get_cached_masker(radius, use_cache)
|
| 186 |
+
|
| 187 |
+
# If no cached masker, create a new one
|
| 188 |
+
if masker is None:
|
| 189 |
+
print(f"Creating new masker with radius {radius}mm")
|
| 190 |
+
# Create masker with allow_empty=True to handle empty spheres
|
| 191 |
+
masker = input_data.NiftiSpheresMasker(
|
| 192 |
+
coords,
|
| 193 |
+
radius=radius,
|
| 194 |
+
standardize=True,
|
| 195 |
+
memory='nilearn_cache',
|
| 196 |
+
memory_level=1,
|
| 197 |
+
verbose=0,
|
| 198 |
+
detrend=True,
|
| 199 |
+
low_pass=PREPROCESS_CONFIG['low_pass'],
|
| 200 |
+
high_pass=PREPROCESS_CONFIG['high_pass'],
|
| 201 |
+
t_r=PREPROCESS_CONFIG['t_r'],
|
| 202 |
+
allow_empty=True # Allow empty spheres
|
| 203 |
+
)
|
| 204 |
|
| 205 |
+
# Cache the masker if caching is enabled
|
| 206 |
+
if use_cache:
|
| 207 |
+
try:
|
| 208 |
+
config_str = (f"radius={radius},"
|
| 209 |
+
f"tr={PREPROCESS_CONFIG['t_r']},"
|
| 210 |
+
f"high_pass={PREPROCESS_CONFIG['high_pass']},"
|
| 211 |
+
f"low_pass={PREPROCESS_CONFIG['low_pass']}")
|
| 212 |
+
masker_key = hashlib.md5(config_str.encode()).hexdigest()
|
| 213 |
+
masker_path = os.path.join(CACHE_DIR, 'maskers', f"{masker_key}.pkl")
|
| 214 |
+
|
| 215 |
+
with open(masker_path, 'wb') as f:
|
| 216 |
+
pickle.dump(masker, f)
|
| 217 |
+
print(f"Saved masker to cache: {masker_path}")
|
| 218 |
+
except Exception as e:
|
| 219 |
+
print(f"Error saving masker to cache: {e}")
|
| 220 |
+
|
| 221 |
+
# Load and process fMRI
|
| 222 |
+
print(f"Loading NIfTI file...")
|
| 223 |
+
fmri_img = load_img(fmri_file)
|
| 224 |
+
print(f"NIfTI file loaded, shape: {fmri_img.shape}")
|
| 225 |
+
|
| 226 |
+
# Check for insufficient time points
|
| 227 |
+
if len(fmri_img.shape) < 4 or fmri_img.shape[3] < 20: # Assuming we need at least 20 time points
|
| 228 |
+
print(f"Warning: {fmri_file} has insufficient time points: {fmri_img.shape}")
|
| 229 |
continue
|
| 230 |
+
|
| 231 |
+
# Transform to time series with explicit warning handling
|
| 232 |
+
print(f"Extracting time series...")
|
| 233 |
+
try:
|
| 234 |
+
# Explicitly handle warnings about empty spheres
|
| 235 |
+
with warnings.catch_warnings():
|
| 236 |
+
warnings.filterwarnings('ignore', message='.*empty.*')
|
| 237 |
+
time_series = masker.fit_transform(fmri_img)
|
| 238 |
+
except Exception as e:
|
| 239 |
+
if "empty" in str(e):
|
| 240 |
+
print(f"Warning: Some spheres are empty in {fmri_file}. Using a different sphere radius.")
|
| 241 |
+
|
| 242 |
+
# Extract the list of empty spheres for logging
|
| 243 |
+
empty_spheres = re.findall(r"\[(.*?)\]", str(e))
|
| 244 |
+
if empty_spheres:
|
| 245 |
+
print(f"Empty spheres: {empty_spheres[0]}")
|
| 246 |
+
|
| 247 |
+
# Continue to next radius option
|
| 248 |
+
continue
|
| 249 |
+
else:
|
| 250 |
+
print(f"Unknown error in masker: {e}")
|
| 251 |
+
continue # Skip this radius if there's any other error
|
| 252 |
+
|
| 253 |
+
print(f"Time series extracted, shape: {time_series.shape}")
|
| 254 |
+
|
| 255 |
+
# Cache the time series if successful
|
| 256 |
+
if use_cache and time_series is not None:
|
| 257 |
+
try:
|
| 258 |
+
np.save(ts_cache_path, time_series)
|
| 259 |
+
print(f"Saved time series to cache: {ts_cache_path}")
|
| 260 |
+
except Exception as e:
|
| 261 |
+
print(f"Error saving time series to cache: {e}")
|
| 262 |
|
| 263 |
+
print(f"Time series processed, shape: {time_series.shape}")
|
| 264 |
|
| 265 |
# Validate time series data
|
| 266 |
if np.isnan(time_series).any() or np.isinf(time_series).any():
|
|
|
|
| 403 |
|
| 404 |
return X, demo_data, demo_types
|
| 405 |
|
| 406 |
+
def clear_cache(cache_type=None):
    """
    Clear all or specific types of cache

    Only the files directly inside the selected cache sub-directories are
    removed; the directories themselves are kept.

    Args:
        cache_type: Type of cache to clear ('time_series', 'fc_matrices', 'maskers', 'atlas', 'latents')
                   If None, clears all cache types. Any other value is
                   rejected with a message and nothing is deleted.
    """
    known_types = ('time_series', 'fc_matrices', 'maskers', 'atlas', 'latents')

    if cache_type is None:
        # Clear all cache types
        cache_types = list(known_types)
    elif cache_type in known_types:
        # Clear specific cache type
        cache_types = [cache_type]
    else:
        # The name is joined onto CACHE_DIR below, so a value such as '..'
        # would otherwise delete files outside the cache. Refuse it.
        print(f"Unknown cache type '{cache_type}', expected one of {known_types}")
        return

    for ctype in cache_types:
        cache_dir = os.path.join(CACHE_DIR, ctype)
        if not os.path.exists(cache_dir):
            print(f"Cache directory for {ctype} does not exist")
            continue

        print(f"Clearing {ctype} cache...")
        try:
            for file in os.listdir(cache_dir):
                file_path = os.path.join(cache_dir, file)
                # Sub-directories are left untouched
                if os.path.isfile(file_path):
                    os.remove(file_path)
            print(f"Successfully cleared {ctype} cache")
        except Exception as e:
            # Best effort: report the failure and move on to the next type
            print(f"Error clearing {ctype} cache: {e}")

    print("Cache clearing complete")
|
| 437 |
+
|
| 438 |
def load_and_preprocess_data(data_dir, demographic_file, use_hf_dataset=False,
|
| 439 |
hf_nii_files=None, hf_demo_data=None, hf_demo_types=None):
|
| 440 |
"""
|