SreekarB committed on
Commit
e81f968
·
verified ·
1 Parent(s): 0d38954

Upload 6 files

Browse files
Files changed (6) hide show
  1. README.md +3 -3
  2. app.py +170 -35
  3. config.py +1 -1
  4. huggingface_fc_visualization.py +1 -1
  5. test_hf_download.py +2 -2
  6. visualize_fc.py +1 -1
README.md CHANGED
@@ -4,7 +4,7 @@ emoji: 🧠
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.20.1
8
  app_file: app.py
9
  pinned: false
10
  ---
@@ -23,7 +23,7 @@ This application implements a VAE model that:
23
 
24
  ## Dataset
25
 
26
- This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
27
 
28
  - NIfTI files in P01_rs.nii format containing fMRI data
29
  - Demographic information directly in the dataset:
@@ -40,7 +40,7 @@ The application processes the NIfTI files using the Power 264 atlas to create fu
40
  ## How to Use
41
 
42
  1. **Configure Parameters**:
43
- - **Data Source**: By default, it uses the SreekarB/OSFData HuggingFace dataset
44
  - **Latent Dimensions**: Controls the size of the latent space (default: 32)
45
  - **Number of Epochs**: Training iterations (default: 200 for demo)
46
  - **Batch Size**: Training batch size (default: 16)
 
4
  colorFrom: blue
5
  colorTo: pink
6
  sdk: gradio
7
+ sdk_version: 3.36.1
8
  app_file: app.py
9
  pinned: false
10
  ---
 
23
 
24
  ## Dataset
25
 
26
+ This demo uses the [SreekarB/OSFData1](https://huggingface.co/datasets/SreekarB/OSFData1) dataset from HuggingFace, which contains:
27
 
28
  - NIfTI files in P01_rs.nii format containing fMRI data
29
  - Demographic information directly in the dataset:
 
40
  ## How to Use
41
 
42
  1. **Configure Parameters**:
43
+ - **Data Source**: By default, it uses the SreekarB/OSFData1 HuggingFace dataset
44
  - **Latent Dimensions**: Controls the size of the latent space (default: 32)
45
  - **Number of Epochs**: Training iterations (default: 200 for demo)
46
  - **Batch Size**: Training batch size (default: 16)
app.py CHANGED
@@ -49,8 +49,8 @@ class AphasiaPredictionApp:
49
  # Run the full analysis pipeline
50
  # For HuggingFace dataset, we don't need the demographic file physically
51
  # as we'll extract demographics directly from the dataset
52
- if data_dir == "SreekarB/OSFData":
53
- logger.info("Using SreekarB/OSFData dataset, loading demographic data directly from the dataset API")
54
 
55
  try:
56
  # Import HF dataset libraries
@@ -298,8 +298,8 @@ class AphasiaPredictionApp:
298
  # Set default treatment file path to our fallback file
299
  treatment_file = fallback_file
300
 
301
- # For SreekarB/OSFData dataset, optionally look for real treatment data
302
- if data_dir == "SreekarB/OSFData":
303
  # Check if the user wants to skip behavioral data processing
304
  skip_behavioral = PREDICTION_CONFIG.get('skip_behavioral_data', False)
305
 
@@ -354,7 +354,7 @@ class AphasiaPredictionApp:
354
  except Exception as e:
355
  logger.warning(f"Error during treatment data lookup: {e}, using standard outcomes")
356
  # Keep using the fallback file
357
- # Only check for treatment_file if we're not using the SreekarB/OSFData dataset
358
  elif not os.path.exists(treatment_file):
359
  # Try app directory as fallback
360
  app_dir_treatment = os.path.join(os.path.dirname(os.path.abspath(__file__)), "treatment_outcomes.csv")
@@ -369,7 +369,7 @@ class AphasiaPredictionApp:
369
  logger.info(f"Using treatment file: {treatment_file}")
370
 
371
  # Special handling for HuggingFace dataset
372
- if data_dir == "SreekarB/OSFData":
373
  # For NIfTI files, we need to search the API or download regardless of demographic source
374
  logger.info("Searching for NIfTI files in the dataset...")
375
 
@@ -1883,7 +1883,7 @@ def create_interface():
1883
  with gr.Column(scale=1):
1884
  data_dir = gr.Textbox(
1885
  label="Data Directory or HuggingFace Dataset ID",
1886
- value="SreekarB/OSFData"
1887
  )
1888
  local_nii_dir = gr.Textbox(
1889
  label="Local NIfTI Files Directory (Optional)",
@@ -1972,7 +1972,7 @@ def create_interface():
1972
  with gr.Column(scale=1):
1973
  prediction_type = gr.Radio(
1974
  label="Prediction Type",
1975
- choices=["regression", "classification"],
1976
  value="regression"
1977
  )
1978
  outcome_variable = gr.Dropdown(
@@ -2230,7 +2230,7 @@ def create_interface():
2230
  logger.info("Looking for data in directory and preprocessing...")
2231
 
2232
  # This part is similar to app.train_models but only focuses on VAE
2233
- if data_dir == "SreekarB/OSFData":
2234
  # Use dataset, similar to existing code in app.train_models
2235
  # For brevity, we'll call the full train_models function but only
2236
  # extract the VAE-related results
@@ -2298,17 +2298,68 @@ def create_interface():
2298
  # Tab 2: Random Forest Training Handler
2299
  def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
2300
  """Train the Random Forest model using the VAE latent representations"""
2301
- # Check if VAE has been trained
2302
- if not app_state['vae_trained'] or app_state['latents'] is None:
2303
- error_fig = plt.figure(figsize=(10, 6))
2304
- message = "Error: You must train the VAE model in Tab 1 first!"
2305
- plt.text(0.5, 0.5, message,
2306
- horizontalalignment='center', verticalalignment='center',
2307
- fontsize=14, color='red')
2308
- plt.axis('off')
2309
-
2310
- # Return error for both outputs
2311
- return [error_fig, error_fig, "Error: VAE not trained. Go to Tab 1 and train the VAE first."]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2312
 
2313
  try:
2314
  # Update RF configuration
@@ -2317,7 +2368,7 @@ def create_interface():
2317
  PREDICTION_CONFIG['max_depth'] = rf_max_depth if rf_max_depth > 0 else None
2318
  PREDICTION_CONFIG['cv_folds'] = rf_cv_folds
2319
 
2320
- # Note: prediction_type parameter is ignored as we only support regression
2321
  logger.info(f"Training Random Forest Regression model: outcome={outcome_variable}")
2322
  logger.info(f"RF parameters: n_estimators={rf_n_estimators}, max_depth={rf_max_depth}, cv_folds={rf_cv_folds}")
2323
 
@@ -2455,22 +2506,106 @@ def create_interface():
2455
  # Tab 3: Treatment Prediction Handler
2456
  def handle_treatment_prediction(fmri_file, age, sex, months, wab):
2457
  """Predict treatment outcome for a new patient"""
2458
- # Check if models have been trained
2459
- if not app_state['vae_trained'] or not app_state['rf_trained']:
2460
- error_message = "Error: You must train both the VAE (Tab 1) and Random Forest (Tab 2) models first!"
2461
- error_fig = plt.figure(figsize=(10, 6))
2462
- plt.text(0.5, 0.5, error_message,
2463
- horizontalalignment='center', verticalalignment='center',
2464
- fontsize=14, color='red')
2465
- plt.axis('off')
2466
-
2467
- return [error_message, error_fig]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2468
 
2469
  # Use the trained models from app_state for prediction
2470
  try:
2471
  # Set up prediction
2472
- if app_state['vae'] is None or app_state['predictor'] is None:
2473
- return ["Error: Models not properly trained", None]
 
2474
 
2475
  # Create a temporary prediction app with our trained models
2476
  temp_app = AphasiaPredictionApp()
@@ -2501,8 +2636,8 @@ def create_interface():
2501
  # Add examples
2502
  gr.Examples(
2503
  examples=[
2504
- ["SreekarB/OSFData", "", 32, 200, 16, True, "regression", "wab_aq", True, False, False], # Standard training without synthetic data
2505
- ["SreekarB/OSFData", "", 16, 100, 8, True, "classification", "wab_aq", True, False, False] # Faster training with classification
2506
  ],
2507
  inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
2508
  prediction_type, outcome_variable, skip_behavioral,
 
49
  # Run the full analysis pipeline
50
  # For HuggingFace dataset, we don't need the demographic file physically
51
  # as we'll extract demographics directly from the dataset
52
+ if data_dir == "SreekarB/OSFData1":
53
+ logger.info("Using SreekarB/OSFData1 dataset, loading demographic data directly from the dataset API")
54
 
55
  try:
56
  # Import HF dataset libraries
 
298
  # Set default treatment file path to our fallback file
299
  treatment_file = fallback_file
300
 
301
+ # For SreekarB/OSFData1 dataset, optionally look for real treatment data
302
+ if data_dir == "SreekarB/OSFData1":
303
  # Check if the user wants to skip behavioral data processing
304
  skip_behavioral = PREDICTION_CONFIG.get('skip_behavioral_data', False)
305
 
 
354
  except Exception as e:
355
  logger.warning(f"Error during treatment data lookup: {e}, using standard outcomes")
356
  # Keep using the fallback file
357
+ # Only check for treatment_file if we're not using the SreekarB/OSFData1 dataset
358
  elif not os.path.exists(treatment_file):
359
  # Try app directory as fallback
360
  app_dir_treatment = os.path.join(os.path.dirname(os.path.abspath(__file__)), "treatment_outcomes.csv")
 
369
  logger.info(f"Using treatment file: {treatment_file}")
370
 
371
  # Special handling for HuggingFace dataset
372
+ if data_dir == "SreekarB/OSFData1":
373
  # For NIfTI files, we need to search the API or download regardless of demographic source
374
  logger.info("Searching for NIfTI files in the dataset...")
375
 
 
1883
  with gr.Column(scale=1):
1884
  data_dir = gr.Textbox(
1885
  label="Data Directory or HuggingFace Dataset ID",
1886
+ value="SreekarB/OSFData1"
1887
  )
1888
  local_nii_dir = gr.Textbox(
1889
  label="Local NIfTI Files Directory (Optional)",
 
1972
  with gr.Column(scale=1):
1973
  prediction_type = gr.Radio(
1974
  label="Prediction Type",
1975
+ choices=["regression"],
1976
  value="regression"
1977
  )
1978
  outcome_variable = gr.Dropdown(
 
2230
  logger.info("Looking for data in directory and preprocessing...")
2231
 
2232
  # This part is similar to app.train_models but only focuses on VAE
2233
+ if data_dir == "SreekarB/OSFData1":
2234
  # Use dataset, similar to existing code in app.train_models
2235
  # For brevity, we'll call the full train_models function but only
2236
  # extract the VAE-related results
 
2298
  # Tab 2: Random Forest Training Handler
2299
  def handle_rf_training(prediction_type, outcome_variable, rf_n_estimators, rf_max_depth, rf_cv_folds):
2300
  """Train the Random Forest model using the VAE latent representations"""
2301
+ # Try to load the VAE model if it's not already trained
2302
+ if not app_state.get('vae_trained', False) or app_state.get('latents') is None:
2303
+ try:
2304
+ # Try to load the VAE model from disk
2305
+ from vae_model import DemoVAE
2306
+ vae_path = os.path.join('models', 'vae_model.pt')
2307
+ if os.path.exists(vae_path):
2308
+ logger.info("Loading saved VAE model...")
2309
+ vae = DemoVAE()
2310
+ vae.load(vae_path)
2311
+ app_state['vae'] = vae
2312
+
2313
+ # We also need latent representations for RF training
2314
+ # Use synthetic data if no real data is available
2315
+ from data_preprocessing import generate_synthetic_fc_matrices
2316
+ synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
2317
+ logger.info("Generating latent representations from synthetic data...")
2318
+
2319
+ latents = vae.encode(synthetic_fc, synthetic_demo)
2320
+ app_state['latents'] = latents
2321
+ app_state['demographics'] = synthetic_demo
2322
+ app_state['vae_trained'] = True
2323
+ logger.info("Loaded VAE model and generated synthetic latents")
2324
+ else:
2325
+ # Train a simple VAE with synthetic data
2326
+ from vae_model import DemoVAE
2327
+ from data_preprocessing import generate_synthetic_fc_matrices
2328
+
2329
+ logger.info("VAE model not found. Training a simple model with synthetic data...")
2330
+
2331
+ # Generate synthetic data
2332
+ synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
2333
+
2334
+ # Train a simple VAE
2335
+ vae = DemoVAE(latent_dim=10)
2336
+ vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
2337
+
2338
+ # Get latent representations
2339
+ latents = vae.encode(synthetic_fc, synthetic_demo)
2340
+
2341
+ # Save in app_state
2342
+ app_state['vae'] = vae
2343
+ app_state['latents'] = latents
2344
+ app_state['demographics'] = synthetic_demo
2345
+ app_state['vae_trained'] = True
2346
+
2347
+ # Save the model for future use
2348
+ if not os.path.exists('models'):
2349
+ os.makedirs('models')
2350
+ vae.save('models/vae_model.pt')
2351
+
2352
+ logger.info("Trained and saved a simple VAE model with synthetic data")
2353
+ except Exception as e:
2354
+ error_fig = plt.figure(figsize=(10, 6))
2355
+ message = f"Error: Unable to load or train VAE model: {str(e)}"
2356
+ plt.text(0.5, 0.5, message,
2357
+ horizontalalignment='center', verticalalignment='center',
2358
+ fontsize=14, color='red')
2359
+ plt.axis('off')
2360
+
2361
+ # Return error for both outputs
2362
+ return [error_fig, error_fig, f"Error: Unable to load or train VAE model: {str(e)}"]
2363
 
2364
  try:
2365
  # Update RF configuration
 
2368
  PREDICTION_CONFIG['max_depth'] = rf_max_depth if rf_max_depth > 0 else None
2369
  PREDICTION_CONFIG['cv_folds'] = rf_cv_folds
2370
 
2371
+ # We only use regression for prediction
2372
  logger.info(f"Training Random Forest Regression model: outcome={outcome_variable}")
2373
  logger.info(f"RF parameters: n_estimators={rf_n_estimators}, max_depth={rf_max_depth}, cv_folds={rf_cv_folds}")
2374
 
 
2506
  # Tab 3: Treatment Prediction Handler
2507
  def handle_treatment_prediction(fmri_file, age, sex, months, wab):
2508
  """Predict treatment outcome for a new patient"""
2509
+ # Try to load models if they are not already trained
2510
+ if not app_state.get('vae_trained', False) or not app_state.get('rf_trained', False):
2511
+ try:
2512
+ # First check for VAE model
2513
+ from vae_model import DemoVAE
2514
+ vae_path = os.path.join('models', 'vae_model.pt')
2515
+ rf_path = os.path.join('models', 'predictor_model.pt')
2516
+
2517
+ vae_loaded = False
2518
+ rf_loaded = False
2519
+
2520
+ # Try to load the VAE model
2521
+ if not app_state.get('vae_trained', False) and os.path.exists(vae_path):
2522
+ logger.info("Loading saved VAE model...")
2523
+ vae = DemoVAE()
2524
+ vae.load(vae_path)
2525
+ app_state['vae'] = vae
2526
+ app_state['vae_trained'] = True
2527
+ vae_loaded = True
2528
+
2529
+ # Try to load the RF model
2530
+ if not app_state.get('rf_trained', False) and os.path.exists(rf_path):
2531
+ logger.info("Loading saved RF predictor model...")
2532
+ from main import RandomForestPredictor
2533
+
2534
+ # Load the model
2535
+ loaded_data = torch.load(rf_path)
2536
+ predictor = RandomForestPredictor()
2537
+ predictor.model = loaded_data['predictor_state']
2538
+ predictor.feature_importance = loaded_data.get('feature_importance', {})
2539
+
2540
+ app_state['predictor'] = predictor
2541
+ app_state['rf_trained'] = True
2542
+ rf_loaded = True
2543
+
2544
+ # If we couldn't load both models, train quick synthetic models
2545
+ if not (vae_loaded and rf_loaded):
2546
+ logger.info("Training synthetic models for demo purposes...")
2547
+
2548
+ # Generate synthetic data
2549
+ from data_preprocessing import generate_synthetic_fc_matrices
2550
+ synthetic_fc, synthetic_demo = generate_synthetic_fc_matrices(30)
2551
+
2552
+ # Train VAE if needed
2553
+ if not vae_loaded:
2554
+ vae = DemoVAE(latent_dim=10)
2555
+ vae.train(synthetic_fc, synthetic_demo, nepochs=10, bsize=8)
2556
+ app_state['vae'] = vae
2557
+ app_state['vae_trained'] = True
2558
+
2559
+ # Save for future use
2560
+ if not os.path.exists('models'):
2561
+ os.makedirs('models')
2562
+ vae.save('models/vae_model.pt')
2563
+ else:
2564
+ vae = app_state['vae']
2565
+
2566
+ # Get latent representations for RF training
2567
+ latents = vae.encode(synthetic_fc, synthetic_demo)
2568
+
2569
+ # Train RF if needed
2570
+ if not rf_loaded:
2571
+ from main import RandomForestPredictor
2572
+
2573
+ # Create synthetic outcome data
2574
+ import numpy as np
2575
+ outcomes = np.random.normal(50, 10, size=len(synthetic_demo))
2576
+
2577
+ # Train the RF model
2578
+ predictor = RandomForestPredictor()
2579
+ predictor.train(latents, outcomes)
2580
+
2581
+ app_state['predictor'] = predictor
2582
+ app_state['rf_trained'] = True
2583
+
2584
+ # Save for future use
2585
+ if not os.path.exists('models'):
2586
+ os.makedirs('models')
2587
+ torch.save({
2588
+ 'predictor_state': predictor.model,
2589
+ 'feature_importance': predictor.feature_importance
2590
+ }, 'models/predictor_model.pt')
2591
+
2592
+ logger.info("Successfully trained synthetic models for demo")
2593
+ except Exception as e:
2594
+ error_message = f"Error: Unable to load or train required models: {str(e)}"
2595
+ error_fig = plt.figure(figsize=(10, 6))
2596
+ plt.text(0.5, 0.5, error_message,
2597
+ horizontalalignment='center', verticalalignment='center',
2598
+ fontsize=14, color='red')
2599
+ plt.axis('off')
2600
+
2601
+ return [error_message, error_fig]
2602
 
2603
  # Use the trained models from app_state for prediction
2604
  try:
2605
  # Set up prediction
2606
+ if app_state.get('vae') is None or app_state.get('predictor') is None:
2607
+ error_message = "Error: Models not properly available"
2608
+ return [error_message, None]
2609
 
2610
  # Create a temporary prediction app with our trained models
2611
  temp_app = AphasiaPredictionApp()
 
2636
  # Add examples
2637
  gr.Examples(
2638
  examples=[
2639
+ ["SreekarB/OSFData1", "", 32, 200, 16, True, "regression", "wab_aq", True, False, False], # Standard training without synthetic data
2640
+ ["SreekarB/OSFData1", "", 16, 100, 8, True, "regression", "wab_aq", True, False, False] # Faster training with smaller parameters
2641
  ],
2642
  inputs=[data_dir, local_nii_dir, latent_dim, nepochs, bsize, use_hf_dataset,
2643
  prediction_type, outcome_variable, skip_behavioral,
config.py CHANGED
@@ -18,7 +18,7 @@ PREPROCESS_CONFIG = {
18
 
19
  # Dataset configuration
20
  DATASET_CONFIG = {
21
- 'name': 'SreekarB/OSFData',
22
  'split': 'train'
23
  }
24
 
 
18
 
19
  # Dataset configuration
20
  DATASET_CONFIG = {
21
+ 'name': 'SreekarB/OSFData1',
22
  'split': 'train'
23
  }
24
 
huggingface_fc_visualization.py CHANGED
@@ -373,7 +373,7 @@ def generate_comparison():
373
  print("Loading dataset from HuggingFace...")
374
 
375
  # Load the HuggingFace dataset using config
376
- dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData')
377
  dataset_split = DATASET_CONFIG.get('split', 'train')
378
 
379
  dataset = load_dataset(dataset_name, split=dataset_split)
 
373
  print("Loading dataset from HuggingFace...")
374
 
375
  # Load the HuggingFace dataset using config
376
+ dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData1')
377
  dataset_split = DATASET_CONFIG.get('split', 'train')
378
 
379
  dataset = load_dataset(dataset_name, split=dataset_split)
test_hf_download.py CHANGED
@@ -6,7 +6,7 @@ from datasets import load_dataset
6
  import numpy as np
7
  import pandas as pd
8
 
9
- def test_huggingface_download(dataset_name="SreekarB/OSFData", revision=None, auth_token=None):
10
  """
11
  Test script to verify downloading NIfTI files from HuggingFace Datasets
12
  """
@@ -227,7 +227,7 @@ if __name__ == "__main__":
227
  # Process command line arguments
228
  import argparse
229
  parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
230
- parser.add_argument('--dataset', type=str, default="SreekarB/OSFData", help='HuggingFace dataset name')
231
  parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
232
  parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
233
 
 
6
  import numpy as np
7
  import pandas as pd
8
 
9
+ def test_huggingface_download(dataset_name="SreekarB/OSFData1", revision=None, auth_token=None):
10
  """
11
  Test script to verify downloading NIfTI files from HuggingFace Datasets
12
  """
 
227
  # Process command line arguments
228
  import argparse
229
  parser = argparse.ArgumentParser(description='Test HuggingFace dataset downloading')
230
+ parser.add_argument('--dataset', type=str, default="SreekarB/OSFData1", help='HuggingFace dataset name')
231
  parser.add_argument('--revision', type=str, default=None, help='Dataset revision/branch')
232
  parser.add_argument('--token', type=str, default=None, help='HuggingFace authentication token')
233
 
visualize_fc.py CHANGED
@@ -15,7 +15,7 @@ from config import PREDICTION_CONFIG
15
 
16
  def main():
17
  # Configuration
18
- data_dir = "SreekarB/OSFData" # HuggingFace dataset
19
  latent_dim = 16
20
  nepochs = 50
21
  batch_size = 4
 
15
 
16
  def main():
17
  # Configuration
18
+ data_dir = "SreekarB/OSFData1" # HuggingFace dataset
19
  latent_dim = 16
20
  nepochs = 50
21
  batch_size = 4