Upload 6 files
- README.md +3 -3
- app.py +170 -35
- config.py +1 -1
- huggingface_fc_visualization.py +1 -1
- test_hf_download.py +2 -2
- visualize_fc.py +1 -1
README.md
CHANGED
@@ -4,7 +4,7 @@ emoji: 🧠
 colorFrom: blue
 colorTo: pink
 sdk: gradio
-sdk_version:
+sdk_version: 3.36.1
 app_file: app.py
 pinned: false
 ---
@@ -23,7 +23,7 @@ This application implements a VAE model that:
 
 ## Dataset
 
-This demo uses the [SreekarB/
+This demo uses the [SreekarB/OSFData1](https://huggingface.co/datasets/SreekarB/OSFData1) dataset from HuggingFace, which contains:
 
 - NIfTI files in P01_rs.nii format containing fMRI data
 - Demographic information directly in the dataset:
@@ -40,7 +40,7 @@ The application processes the NIfTI files using the Power 264 atlas to create fu
 ## How to Use
 
 1. **Configure Parameters**:
-   - **Data Source**: By default, it uses the SreekarB/
+   - **Data Source**: By default, it uses the SreekarB/OSFData1 HuggingFace dataset
    - **Latent Dimensions**: Controls the size of the latent space (default: 32)
    - **Number of Epochs**: Training iterations (default: 200 for demo)
    - **Batch Size**: Training batch size (default: 16)
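For orientation on the dataset switch above, here is a minimal sketch of loading the SreekarB/OSFData1 dataset with the `datasets` library. It assumes the dataset is publicly readable (pass an auth token if it is gated) and only inspects, rather than assumes, the column layout:

```python
# Minimal sketch: load the HuggingFace dataset referenced in the README.
from datasets import load_dataset

dataset = load_dataset("SreekarB/OSFData1", split="train")
print(len(dataset), "records")
print(dataset.column_names)  # inspect demographic fields and file references
```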
app.py
CHANGED
@@ -49,8 +49,8 @@ class AphasiaPredictionApp:
 # Run the full analysis pipeline
 # For HuggingFace dataset, we don't need the demographic file physically
 # as we'll extract demographics directly from the dataset
-if data_dir == "SreekarB/
-logger.info("Using SreekarB/
+if data_dir == "SreekarB/OSFData1":
+logger.info("Using SreekarB/OSFData1 dataset, loading demographic data directly from the dataset API")
 
 try:
 # Import HF dataset libraries
@@ -298,8 +298,8 @@ class AphasiaPredictionApp:
 # Set default treatment file path to our fallback file
 treatment_file = fallback_file
 
-# For SreekarB/
-if data_dir == "SreekarB/
+# For SreekarB/OSFData1 dataset, optionally look for real treatment data
+if data_dir == "SreekarB/OSFData1":
 # Check if the user wants to skip behavioral data processing
 skip_behavioral = PREDICTION_CONFIG.get('skip_behavioral_data', False)
 
@@ -354,7 +354,7 @@ class AphasiaPredictionApp:
 except Exception as e:
 logger.warning(f"Error during treatment data lookup: {e}, using standard outcomes")
 # Keep using the fallback file
-# Only check for treatment_file if we're not using the SreekarB/
+# Only check for treatment_file if we're not using the SreekarB/OSFData1 dataset
 elif not os.path.exists(treatment_file):
 # Try app directory as fallback
 app_dir_treatment = os.path.join(os.path.dirname(os.path.abspath(__file__)), "treatment_outcomes.csv")
@@ -369,7 +369,7 @@ class AphasiaPredictionApp:
 logger.info(f"Using treatment file: {treatment_file}")
 
 # Special handling for HuggingFace dataset
-if data_dir == "SreekarB/
+if data_dir == "SreekarB/OSFData1":
 # For NIfTI files, we need to search the API or download regardless of demographic source
 logger.info("Searching for NIfTI files in the dataset...")
 
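The hunks above only route the NIfTI files from the dataset into the pipeline; the README states that FC matrices are built with the Power 264 atlas. A minimal, hedged sketch of that preprocessing step with nilearn is below. It illustrates the general technique rather than the app's exact code: the sphere radius and the use of `P01_rs.nii` as an example filename are assumptions.

```python
# Illustrative sketch (not the app's code): one NIfTI file -> Power 264 FC matrix.
import numpy as np
from nilearn import datasets
from nilearn.maskers import NiftiSpheresMasker
from nilearn.connectome import ConnectivityMeasure

power = datasets.fetch_coords_power_2011()  # 264 ROI coordinates
coords = np.vstack((power.rois["x"], power.rois["y"], power.rois["z"])).T

masker = NiftiSpheresMasker(seeds=coords, radius=5.0, standardize=True)
time_series = masker.fit_transform("P01_rs.nii")  # shape: (timepoints, 264)

fc_matrix = ConnectivityMeasure(kind="correlation").fit_transform([time_series])[0]
print(fc_matrix.shape)  # (264, 264)
```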
@@ -1883,7 +1883,7 @@ def create_interface():
 with gr.Column(scale=1):
 data_dir = gr.Textbox(
 label="Data Directory or HuggingFace Dataset ID",
-value="SreekarB/
+value="SreekarB/OSFData1"
 )
 local_nii_dir = gr.Textbox(
 label="Local NIfTI Files Directory (Optional)",
@@ -1972,7 +1972,7 @@ def create_interface():
 with gr.Column(scale=1):
 prediction_type = gr.Radio(
 label="Prediction Type",
-choices=["regression"
+choices=["regression"],
 value="regression"
 )
 outcome_variable = gr.Dropdown(
@@ -2230,7 +2230,7 @@ def create_interface():
 logger.info("Looking for data in directory and preprocessing...")
 
 # This part is similar to app.train_models but only focuses on VAE
-if data_dir == "SreekarB/
+if data_dir == "SreekarB/OSFData1":
 # Use dataset, similar to existing code in app.train_models
 # For brevity, we'll call the full train_models function but only
 # extract the VAE-related results
@@ -2298,17 +2298,68 @@
 # Tab 2: Random Forest Training Handler
 def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
 """Train the Random Forest model using the VAE latent representations"""
-#
-if not app_state
-
-
-
-
-
-
-
-
-
+# Try to load the VAE model if it's not already trained
+if not app_state.get('vae_trained', False) or app_state.get('latents') is None:
+try:
+# Try to load the VAE model from disk
+from vae_model import DemoVAE
+vae_path = os.path.join('models', 'vae_model.pt')
+if os.path.exists(vae_path):
+logger.info("Loading saved VAE model...")
+vae = DemoVAE()
+vae.load(vae_path)
+app_state['vae'] = vae
+
+# We also need latent representations for RF training
+# Use synthetic data if no real data is available
+from data_preprocessing import generate_synthetic_fc_matrices
+synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
+logger.info("Generating latent representations from synthetic data...")
+
+latents = vae.encode(synthetic_fc, synthetic_demo)
+app_state['latents'] = latents
+app_state['demographics'] = synthetic_demo
+app_state['vae_trained'] = True
+logger.info("Loaded VAE model and generated synthetic latents")
+else:
+# Train a simple VAE with synthetic data
+from vae_model import DemoVAE
+from data_preprocessing import generate_synthetic_fc_matrices
+
+logger.info("VAE model not found. Training a simple model with synthetic data...")
+
+# Generate synthetic data
+synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
+
+# Train a simple VAE
+vae = DemoVAE(latent_dim=10)
+vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
+
+# Get latent representations
+latents = vae.encode(synthetic_fc, synthetic_demo)
+
+# Save in app_state
+app_state['vae'] = vae
+app_state['latents'] = latents
+app_state['demographics'] = synthetic_demo
+app_state['vae_trained'] = True
+
+# Save the model for future use
+if not os.path.exists('models'):
+os.makedirs('models')
+vae.save('models/vae_model.pt')
+
+logger.info("Trained and saved a simple VAE model with synthetic data")
+except Exception as e:
+error_fig = plt.figure(figsize=(10, 6))
+message = f"Error: Unable to load or train VAE model: {str(e)}"
+plt.text(0.5, 0.5, message,
+horizontalalignment='center', verticalalignment='center',
+fontsize=14, color='red')
+plt.axis('off')
+
+# Return error for both outputs
+return [error_fig, error_fig, f"Error: Unable to load or train VAE model: {str(e)}"]
 
 try:
 # Update RF configuration
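To make the intent of the new fallback explicit: the handler needs latent representations plus an outcome vector so it can fit a Random Forest regressor with cross-validation. A minimal, hedged sketch of that step with scikit-learn follows; the app's own `RandomForestPredictor` and `DemoVAE` wrap similar logic, and the random stand-in data and hyperparameters here are illustrative assumptions only.

```python
# Illustrative sketch (not the app's RandomForestPredictor): latent features -> RF regression.
import numpy as np
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import cross_val_score

rng = np.random.default_rng(0)
latents = rng.normal(size=(30, 10))     # stand-in for vae.encode(fc_matrices, demographics)
outcomes = rng.normal(50, 10, size=30)  # stand-in for treatment outcomes (e.g. WAB-AQ)

rf = RandomForestRegressor(n_estimators=100, max_depth=None, random_state=0)
scores = cross_val_score(rf, latents, outcomes, cv=5, scoring="r2")
rf.fit(latents, outcomes)

print("CV R^2:", scores.mean())
print("Top features by importance:", np.argsort(rf.feature_importances_)[::-1][:3])
```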
@@ -2317,7 +2368,7 @@
 PREDICTION_CONFIG['max_depth'] = rf_max_depth if rf_max_depth > 0 else None
 PREDICTION_CONFIG['cv_folds'] = rf_cv_folds
 
-#
+# We only use regression for prediction
 logger.info(f"Training Random Forest Regression model: outcome={outcome_variable}")
 logger.info(f"RF parameters: n_estimators={rf_n_estimators}, max_depth={rf_max_depth}, cv_folds={rf_cv_folds}")
 
@@ -2455,22 +2506,106 @@
 # Tab 3: Treatment Prediction Handler
 def handle_treatment_prediction(fmri_file, age, sex, months, wab):
 """Predict treatment outcome for a new patient"""
-#
-if not app_state
-
-
-
-
-
-
-
-
+# Try to load models if they are not already trained
+if not app_state.get('vae_trained', False) or not app_state.get('rf_trained', False):
+try:
+# First check for VAE model
+from vae_model import DemoVAE
+vae_path = os.path.join('models', 'vae_model.pt')
+rf_path = os.path.join('models', 'predictor_model.pt')
+
+vae_loaded = False
+rf_loaded = False
+
+# Try to load the VAE model
+if not app_state.get('vae_trained', False) and os.path.exists(vae_path):
+logger.info("Loading saved VAE model...")
+vae = DemoVAE()
+vae.load(vae_path)
+app_state['vae'] = vae
+app_state['vae_trained'] = True
+vae_loaded = True
+
+# Try to load the RF model
+if not app_state.get('rf_trained', False) and os.path.exists(rf_path):
+logger.info("Loading saved RF predictor model...")
+from main import RandomForestPredictor
+
+# Load the model
+loaded_data = torch.load(rf_path)
+predictor = RandomForestPredictor()
+predictor.model = loaded_data['predictor_state']
+predictor.feature_importance = loaded_data.get('feature_importance', {})
+
+app_state['predictor'] = predictor
+app_state['rf_trained'] = True
+rf_loaded = True
+
+# If we couldn't load both models, train quick synthetic models
+if not (vae_loaded and rf_loaded):
+logger.info("Training synthetic models for demo purposes...")
+
+# Generate synthetic data
+from data_preprocessing import generate_synthetic_fc_matrices
+synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
+
+# Train VAE if needed
+if not vae_loaded:
+vae = DemoVAE(latent_dim=10)
+vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
+app_state['vae'] = vae
+app_state['vae_trained'] = True
+
+# Save for future use
+if not os.path.exists('models'):
+os.makedirs('models')
+vae.save('models/vae_model.pt')
+else:
+vae = app_state['vae']
+
+# Get latent representations for RF training
+latents = vae.encode(synthetic_fc, synthetic_demo)
+
+# Train RF if needed
+if not rf_loaded:
+from main import RandomForestPredictor
+
+# Create synthetic outcome data
+import numpy as np
+outcomes = np.random.normal(50, 10, size=len(synthetic_demo))
+
+# Train the RF model
+predictor = RandomForestPredictor()
+predictor.train(latents, outcomes)
+
+app_state['predictor'] = predictor
+app_state['rf_trained'] = True
+
+# Save for future use
+if not os.path.exists('models'):
+os.makedirs('models')
+torch.save({
+'predictor_state': predictor.model,
+'feature_importance': predictor.feature_importance
+}, 'models/predictor_model.pt')
+
+logger.info("Successfully trained synthetic models for demo")
+except Exception as e:
+error_message = f"Error: Unable to load or train required models: {str(e)}"
+error_fig = plt.figure(figsize=(10, 6))
+plt.text(0.5, 0.5, error_message,
+horizontalalignment='center', verticalalignment='center',
+fontsize=14, color='red')
+plt.axis('off')
+
+return [error_message, error_fig]
 
 # Use the trained models from app_state for prediction
 try:
 # Set up prediction
-if app_state
-
+if app_state.get('vae') is None or app_state.get('predictor') is None:
+error_message = "Error: Models not properly available"
+return [error_message, None]
 
 # Create a temporary prediction app with our trained models
 temp_app = AphasiaPredictionApp()
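For context on what the handler above ultimately computes: once a VAE and an RF predictor are available (loaded or freshly trained), prediction for a new patient reduces to encoding that patient's FC matrix plus demographics and passing the latent vector to the regressor. A minimal, hedged sketch follows; `encode_patient` is a hypothetical stand-in for the app's `vae.encode`, and the demographic layout (age, sex, months post-onset, WAB-AQ, matching the handler's arguments) is an illustrative assumption.

```python
# Illustrative sketch (not the app's code): predict an outcome for one new patient.
import numpy as np
from sklearn.ensemble import RandomForestRegressor

def encode_patient(fc_matrix, demographics):
    # Hypothetical stand-in for vae.encode: take the upper-triangle FC values
    # plus demographics and keep a small "latent" slice of them.
    triu = fc_matrix[np.triu_indices_from(fc_matrix, k=1)]
    features = np.concatenate([triu, demographics])
    return features[:10]  # pretend latent vector of dimension 10

rng = np.random.default_rng(0)
rf = RandomForestRegressor(random_state=0)
rf.fit(rng.normal(size=(30, 10)), rng.normal(50, 10, size=30))  # toy training set

new_fc = rng.normal(size=(264, 264))
new_fc = (new_fc + new_fc.T) / 2                       # symmetric stand-in FC matrix
new_demo = np.array([61.0, 1.0, 14.0, 52.5])           # age, sex, months post-onset, WAB-AQ
latent = encode_patient(new_fc, new_demo)
print(f"Predicted treatment outcome: {rf.predict(latent.reshape(1, -1))[0]:.1f}")
```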
@@ -2501,8 +2636,8 @@
 # Add examples
 gr.Examples(
 examples=[
-["SreekarB/
-["SreekarB/
+["SreekarB/OSFData1", "", 32, 200, 16, True, "regression", "wab_aq", True, False, False],  # Standard training without synthetic data
+["SreekarB/OSFData1", "", 16, 100, 8, True, "regression", "wab_aq", True, False, False]  # Faster training with smaller parameters
 ],
 inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
 prediction_type, outcome_variable, skip_behavioral,
config.py
CHANGED
@@ -18,7 +18,7 @@ PREPROCESS_CONFIG = {
 
 # Dataset configuration
 DATASET_CONFIG = {
-'name': 'SreekarB/
+'name': 'SreekarB/OSFData1',
 'split': 'train'
 }
 
huggingface_fc_visualization.py
CHANGED
@@ -373,7 +373,7 @@ def generate_comparison():
 print("Loading dataset from HuggingFace...")
 
 # Load the HuggingFace dataset using config
-dataset_name = DATASET_CONFIG.get('name', 'SreekarB/
+dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData1')
 dataset_split = DATASET_CONFIG.get('split', 'train')
 
 dataset = load_dataset(dataset_name, split=dataset_split)
test_hf_download.py
CHANGED
@@ -6,7 +6,7 @@ from datasets import load_dataset
 import numpy as np
 import pandas as pd
 
-def test_huggingface_download(dataset_name="SreekarB/
+def test_huggingface_download(dataset_name="SreekarB/OSFData1", revision=None, auth_token=None):
 """
 Test script to verify downloading NIfTI files from HuggingFace Datasets
 """
@@ -227,7 +227,7 @@ if __name__ == "__main__":
 # Process command line arguments
 import argparse
 parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
-parser.add_argument('--dataset', type=str, default="SreekarB/
+parser.add_argument('--dataset', type=str, default="SreekarB/OSFData1", help='HuggingFace dataset name')
 parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
 parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
 
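The updated defaults above change only the dataset ID. For reference, a minimal, hedged sketch of what such a download check can look like with `huggingface_hub` is shown below; the script's own implementation may differ, and the NIfTI file extensions are assumptions based on the README's P01_rs.nii naming.

```python
# Illustrative sketch: list and fetch NIfTI files from the dataset repo.
from huggingface_hub import list_repo_files, hf_hub_download

files = list_repo_files("SreekarB/OSFData1", repo_type="dataset")
nii_files = [f for f in files if f.endswith(".nii") or f.endswith(".nii.gz")]
print(f"Found {len(nii_files)} NIfTI files")

if nii_files:
    local_path = hf_hub_download(repo_id="SreekarB/OSFData1",
                                 filename=nii_files[0],
                                 repo_type="dataset")
    print("Downloaded to", local_path)
```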
visualize_fc.py
CHANGED
@@ -15,7 +15,7 @@ from config import PREDICTION_CONFIG
 
 def main():
 # Configuration
-data_dir = "SreekarB/
+data_dir = "SreekarB/OSFData1"  # HuggingFace dataset
 latent_dim = 16
 nepochs = 50
 batch_size = 4