SreekarB commited on
Commit
3a90fb6
·
verified ·
1 Parent(s): 9939094

Upload 18 files

Browse files
README.md CHANGED
@@ -1,12 +1,62 @@
1
- ---
2
- title: VAE
3
- emoji: 🏢
4
- colorFrom: red
5
- colorTo: indigo
6
- sdk: gradio
7
- sdk_version: 5.20.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: Aphasia Prediction with FC Visualization
3
+ emoji: 🧠
4
+ colorFrom: indigo
5
+ colorTo: red
6
+ sdk: gradio
7
+ sdk_version: 3.50.0
8
+ app_file: app.py
9
+ pinned: false
10
+ ---
11
+
12
+ # Aphasia Prediction with VAE and FC Visualization
13
+
14
+ This application predicts aphasia scores based on patient demographics and visualizes functional connectivity (FC) patterns in the brain.
15
+
16
+ ## Features
17
+
18
+ - Predict aphasia severity (WAB AQ score) based on patient demographic data
19
+ - Option to manually set aphasia scores
20
+ - Display functional connectivity heatmap and matrix values
21
+ - Interactive visualization of brain region connectivity
22
+ - Customizable demographic parameters
23
+
24
+ ## Usage
25
+
26
+ 1. First, if you haven't trained the model, go to the "Train Model" tab and click "Train Model"
27
+ 2. When the model is ready, go to the "Predict & Visualize" tab
28
+ 3. Adjust the demographic sliders for age, months post onset, education, gender, and handedness
29
+ 4. Select an aphasia type and set initial severity and lesion size
30
+ 5. Click "Generate Functional Connectivity" to see the predictions and visualization
31
+ 6. Optionally override the model's prediction with a custom score
32
+ 7. Explore the functional connectivity matrix visualization and detailed values
33
+
34
+ ## Technical Details
35
+
36
+ The application uses:
37
+ - A Variational Autoencoder (VAE) from the DemoVAE package for learning latent representations of brain connectivity
38
+ - Random Forest regression to predict aphasia scores from latent features and demographics
39
+ - Gradio web interface for interactive visualization
40
+ - Analysis of key brain connectivity patterns and their relationship to aphasia
41
+
42
+ ## Deployment
43
+
44
+ The application can be deployed using:
45
+
46
+ ```bash
47
+ # Install requirements
48
+ pip install -r requirements.txt
49
+
50
+ # Run the Gradio app
51
+ python app.py
52
+ ```
53
+
54
+ ## Hugging Face Spaces Deployment
55
+
56
+ This app is designed to be deployed on Hugging Face Spaces:
57
+
58
+ 1. Create a new Space and select Gradio as the SDK
59
+ 2. Upload the files or connect to your GitHub repository
60
+ 3. The app will automatically deploy and be available online
61
+
62
+ Note: The initial model training may take some time when you first run the application.
app.py ADDED
@@ -0,0 +1,451 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import gradio as gr
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import torch
7
+ import pickle
8
+ import pandas as pd
9
+ import time
10
+ import warnings
11
+ warnings.filterwarnings('ignore') # Suppress warnings
12
+
13
+ # Add the current directory to Python path
14
+ sys.path.append(os.path.dirname(os.path.abspath(__file__)))
15
+
16
+ # Add PIP package to path
17
+ sys.path.append(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'pip', 'src'))
18
+
19
+ # Check if running in Hugging Face Spaces
20
+ IS_SPACE = os.environ.get('SPACE_ID') is not None
21
+ print(f"Running in {'Hugging Face Spaces' if IS_SPACE else 'local environment'}")
22
+
23
+ # Import VAE model and functions
24
+ try:
25
+ from osf_demovae_adapter import (
26
+ VAE, load_and_process_data, train_demovae_model, predict_aphasia_recovery,
27
+ generate_custom_fc, to_torch, to_cuda, to_numpy, vec2mat, mat2vec, ATLAS_REGIONS,
28
+ download_model, get_connectivity_visualization
29
+ )
30
+ print("Successfully imported osf_demovae_adapter modules")
31
+ except ImportError as e:
32
+ print(f"Error importing osf_demovae_adapter modules: {e} - make sure path is correct")
33
+
34
+ # Model configuration
35
+ # Use /tmp for model storage in Hugging Face Spaces
36
+ MODEL_DIR = "/tmp/osf_models" if IS_SPACE else os.path.dirname(os.path.abspath(__file__))
37
+ os.makedirs(MODEL_DIR, exist_ok=True)
38
+ MODEL_PATH = os.path.join(MODEL_DIR, 'osf_demovae_model.pt')
39
+ LATENT_DIM = 30
40
+ INPUT_DIM = 1000
41
+ DEMO_DIM = 5
42
+ # Be cautious with CUDA in Spaces as resources may be limited
43
+ USE_CUDA = torch.cuda.is_available() and not IS_SPACE # Disable CUDA in Spaces for stability
44
+
45
+ # Initialize model during startup if needed
46
+ model = None
47
+ demovae_model = None
48
+ prediction_model = None
49
+ model_loaded = False
50
+
51
def get_aphasia_severity_category(wab_score):
    """Map a WAB AQ score (0-100 scale) to a clinical severity label.

    Thresholds follow the cutoffs used elsewhere in this app: >= 93.8 is
    within normal limits, then mild / moderate / severe bands, with
    anything below 25 classed as very severe.
    """
    severity_bands = (
        (93.8, "No aphasia (within normal limits)"),
        (75, "Mild aphasia"),
        (50, "Moderate aphasia"),
        (25, "Severe aphasia"),
    )
    for cutoff, label in severity_bands:
        if wab_score >= cutoff:
            return label
    return "Very severe aphasia"
64
+
65
def load_model():
    """Load the VAE and prediction models from disk, downloading if needed.

    Tries the combined pickle first, then falls back to loading/downloading
    the DemoVAE and prediction models separately. Sets the module globals
    ``model``, ``demovae_model``, ``prediction_model`` and ``model_loaded``.

    Returns:
        bool: True if at least the DemoVAE model is usable, False otherwise.
    """
    global model, model_loaded, prediction_model, demovae_model
    try:
        # Try to load both models from the combined pickle file first
        combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')

        # Check if combined model exists, otherwise try to download it
        if not os.path.exists(combined_model_path):
            try:
                print("Combined model file not found. Attempting to download...")
                combined_model_path = download_model('combined')
            except Exception as e:
                print(f"Could not download combined model: {e}")
                combined_model_path = None

        # If we have a combined model file, load it
        if combined_model_path and os.path.exists(combined_model_path):
            with open(combined_model_path, 'rb') as f:
                models_dict = pickle.load(f)
            demovae_model = models_dict['demovae']
            prediction_model = models_dict['prediction']
            model = demovae_model.vae
            print("DemoVAE and prediction models loaded successfully from", combined_model_path)
        else:
            # Fall back to loading models separately
            print("Combined model file not available. Trying to load or download models separately...")

            # BUG FIX: the original assigned directly to MODEL_PATH here
            # without declaring it global, which made MODEL_PATH local to
            # this function and raised UnboundLocalError on the read below.
            # Use a local alias instead of rebinding the module constant.
            demovae_path = MODEL_PATH

            # Check if DemoVAE model exists, otherwise try to download it
            if not os.path.exists(demovae_path):
                try:
                    print("DemoVAE model not found. Attempting to download...")
                    demovae_path = download_model('demovae')
                except Exception as e:
                    print(f"Could not download DemoVAE model: {e}")
                    return False

            # Create model instance and load DemoVAE model
            from demovae.sklearn import DemoVAE
            demovae_model = DemoVAE(latent_dim=LATENT_DIM, use_cuda=USE_CUDA)
            demovae_model.load(demovae_path)
            model = demovae_model.vae
            print("DemoVAE model loaded successfully from", demovae_path)

            # Check for prediction model
            pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
            if not os.path.exists(pred_model_path):
                try:
                    print("Prediction model not found. Attempting to download...")
                    pred_model_path = download_model('prediction')
                except Exception as e:
                    # The VAE alone is still usable; report success but
                    # disable score prediction.
                    print(f"Could not download prediction model: {e}")
                    print("Warning: Aphasia score prediction will not be available.")
                    prediction_model = None
                    model_loaded = True
                    return True

            # Load prediction model if available
            if os.path.exists(pred_model_path):
                with open(pred_model_path, 'rb') as f:
                    prediction_model = pickle.load(f)
                print("Prediction model loaded successfully from", pred_model_path)
            else:
                print("Warning: Prediction model not found. Aphasia score prediction will not be available.")
                prediction_model = None

        model_loaded = True
        return True
    except Exception as e:
        # Broad catch is deliberate: model loading failure should degrade
        # gracefully to "not loaded" rather than crash the UI.
        print(f"Error loading model: {str(e)}")
        model_loaded = False
        return False
137
+
138
def train_model(progress=gr.Progress()):
    """Train the two-stage pipeline (DemoVAE, then Random Forest).

    Downloads/processes the OSF dataset, trains the DemoVAE on FC features,
    trains a Random Forest on the latent features + demographics, and saves
    both models. Updates the module globals as a side effect.

    Args:
        progress: Gradio progress tracker (injected by the UI).

    Returns:
        str: newline-joined status log for display in the UI textbox.
    """
    global model, model_loaded, prediction_model, demovae_model

    status_messages = []

    # Process data from HuggingFace
    progress(0.1, desc="Processing OSF data from HuggingFace...")
    status_messages.append("Step 1: Loading and processing OSF data from HuggingFace...")
    X_fc, X_demo, y_wab, y_improvement, final_df = load_and_process_data(
        quick_test=False  # Use the full dataset, not just a sample
    )
    status_messages.append(f"✓ Data processed: {len(X_fc)} samples with {X_fc.shape[1]} FC features and {X_demo.shape[1]} demographic features")

    # Train VAE model (first stage)
    progress(0.3, desc="Training DemoVAE model (first stage)...")
    status_messages.append("\nStep 2: Training DemoVAE model (first stage of pipeline)...")
    status_messages.append("This model will learn latent representations of brain connectivity patterns")
    demovae_model, z_train, z_test, X_fc_test, X_demo_test, y_test = train_demovae_model(
        X_fc, X_demo, y_wab, save_model=True, model_path=MODEL_PATH
    )

    # Update global model so the predict tab can use it immediately
    model = demovae_model.vae
    model_loaded = True
    status_messages.append(f"✓ DemoVAE trained successfully: {demovae_model.latent_dim} latent dimensions")

    # Train Random Forest prediction model for aphasia scores (second stage)
    progress(0.7, desc="Training Random Forest model (second stage)...")
    status_messages.append("\nStep 3: Training Random Forest model (second stage of pipeline)...")
    status_messages.append("This model will predict aphasia scores from latent brain connectivity patterns and demographics")

    print("\n===== STARTING SECOND STAGE: RANDOM FOREST TRAINING =====")
    print("The first stage (VAE) extracted latent representations of brain connectivity")
    print("Now training Random Forest to predict aphasia scores from these representations")

    # NOTE(fix): removed an unused local (X_combined = np.hstack([z_test,
    # X_demo_test])) that was computed but never referenced.
    pred_model, y_pred, rmse_val, r2 = predict_aphasia_recovery(z_test, X_demo_test, y_test)

    status_messages.append(f"✓ Random Forest trained successfully")
    status_messages.append(f" - Prediction accuracy: RMSE = {rmse_val:.2f}, R² = {r2:.2f}")

    # Save prediction model
    status_messages.append("\nStep 4: Saving trained models...")
    prediction_model = pred_model
    pred_model_path = os.path.join(MODEL_DIR, 'aphasia_prediction_model.pkl')
    with open(pred_model_path, 'wb') as f:
        pickle.dump(pred_model, f)
    status_messages.append(f"✓ Saved Random Forest model to {pred_model_path}")

    # Save both models together so load_model() can restore them in one read
    combined_model_path = os.path.join(MODEL_DIR, 'demovae_and_prediction_models.pkl')
    with open(combined_model_path, 'wb') as f:
        pickle.dump({
            'demovae': demovae_model,
            'prediction': pred_model,
            'latent_dim': demovae_model.latent_dim
        }, f)
    status_messages.append(f"✓ Saved combined models to {combined_model_path}")

    progress(1.0, desc="Model training complete!")
    status_messages.append("\n✅ MODEL TRAINING COMPLETE!")
    status_messages.append("You can now use the model to predict aphasia scores and visualize functional connectivity")

    return "\n".join(status_messages)
203
+
204
def analyze_fc_regions(matrix, region_names, top_n=5):
    """Rank brain regions by mean off-diagonal connectivity.

    Args:
        matrix: square FC matrix (region x region correlations).
        region_names: labels for each row/column of ``matrix``.
        top_n: how many regions to report at each extreme.

    Returns:
        Tuple (top_positive, top_negative) of (name, mean connectivity)
        pairs — the most and least connected regions respectively.
    """
    n_regions = len(region_names)

    # Mean connectivity of each region to all OTHER regions
    # (self-connections excluded).
    avg_connectivity = np.zeros(n_regions)
    for row in range(n_regions):
        off_diag = [matrix[row, col] for col in range(n_regions) if col != row]
        avg_connectivity[row] = np.mean(off_diag)

    ranked = np.argsort(avg_connectivity)
    strongest = ranked[-top_n:][::-1]  # descending
    weakest = ranked[:top_n]           # ascending

    top_positive = [(region_names[r], avg_connectivity[r]) for r in strongest]
    top_negative = [(region_names[r], avg_connectivity[r]) for r in weakest]

    return top_positive, top_negative
223
+
224
def generate_fc_visualization(age, mpo, education, gender, handedness,
                              aphasia_severity, lesion_size,
                              use_custom_score=False, custom_score=None):
    """Generate an FC matrix image, a markdown summary, and a connection table.

    Args:
        age, mpo, education: demographic sliders (years / months).
        gender: "Male" or "Female"; handedness: "Right" or "Left".
        aphasia_severity: fallback WAB AQ value when no prediction exists.
        lesion_size: lesion percentage (reported in the summary only).
        use_custom_score / custom_score: optional manual score override.

    Returns:
        (image_path, summary_markdown, dataframe) — matching the three
        Gradio outputs wired up in the UI (fc_image, fc_summary, fc_data).
    """
    global model_loaded, model, demovae_model, prediction_model

    # Check if model is loaded; lazily load it from disk if possible.
    # BUG FIX: these early returns previously returned 2-tuples, but the
    # click handler expects 3 outputs — return a 3-tuple in every path.
    if not model_loaded:
        if os.path.exists(MODEL_PATH):
            # Try to load existing model
            if not load_model():
                return None, "Failed to load model. Please train the model first.", None
        else:
            return None, "Model not found. Please train the model first.", None

    # Convert gender/handedness to the numeric encoding expected by the model
    gender_val = 1 if gender == "Male" else 0
    handedness_val = 1 if handedness == "Right" else 0

    # Prepare demographics for the model
    demo_values = {
        'age': age,
        'mpo': mpo,
        'education': education,
        'gender': 'male' if gender_val else 'female',
        'handedness': 'right' if handedness_val else 'left'
    }

    # Set predicted score to None unless we override it
    predicted_aphasia_score = None
    aphasia_score_source = "default"

    if use_custom_score and custom_score is not None:
        # Use user-provided custom score
        predicted_aphasia_score = custom_score
        aphasia_score_source = "custom"

    # Generate FC matrix using the adapter function
    try:
        # Try the new function signature first (returns 3 values)
        custom_fc_mat, gen_predicted_score, viz_path = generate_custom_fc(
            demo_values,
            demovae_model,
            prediction_model if not use_custom_score else None,
            visualize=True
        )
    except (ValueError, TypeError) as e:
        # Fall back to older function signature (returns 2 values)
        print(f"Warning: Using older generate_custom_fc signature: {e}")
        custom_fc_mat, gen_predicted_score = generate_custom_fc(
            demo_values,
            demovae_model,
            prediction_model if not use_custom_score else None
        )
        viz_path = None

    # If we're using the generated prediction
    if not use_custom_score and gen_predicted_score is not None:
        predicted_aphasia_score = gen_predicted_score
        aphasia_score_source = "predicted"
    elif predicted_aphasia_score is None:
        # Fall back to the slider value if nothing else was set
        predicted_aphasia_score = aphasia_severity

    # Use /tmp for visualization files in Spaces (writable dir)
    viz_dir = "/tmp/fc_visualizations" if IS_SPACE else os.path.dirname(os.path.abspath(__file__))
    os.makedirs(viz_dir, exist_ok=True)
    # Timestamped name avoids stale-image caching between generations
    temp_img_path = os.path.join(viz_dir, f"temp_fc_matrix_{time.strftime('%Y%m%d_%H%M%S')}.png")

    if viz_path and os.path.exists(viz_path):
        # Reuse the visualization the adapter already produced
        import shutil
        shutil.copy(viz_path, temp_img_path)
    else:
        # Generate FC heatmap with aphasia score in title
        try:
            # Use the new visualization function if available
            get_connectivity_visualization(
                custom_fc_mat,
                subject_id=f"Patient: Age {age}, Gender {'M' if gender_val else 'F'}, Aphasia Score: {predicted_aphasia_score:.1f}",
                output_path=temp_img_path
            )
        except (NameError, AttributeError):
            # Fall back to a plain matplotlib heatmap
            plt.figure(figsize=(10, 8))
            plt.imshow(custom_fc_mat, cmap='coolwarm', vmin=-1, vmax=1)
            plt.colorbar(label='Correlation')
            plt.title(f'FC Matrix: Age {age}, Gender {"M" if gender_val else "F"}, Aphasia Score: {predicted_aphasia_score:.1f}')
            plt.savefig(temp_img_path)
            plt.close()

    # Region labels, truncated to the matrix size
    region_names = ATLAS_REGIONS[:custom_fc_mat.shape[0]]

    # Analyze FC regions for the summary text
    top_positive, top_negative = analyze_fc_regions(custom_fc_mat, region_names)

    # Create summary text with the analysis
    severity_category = get_aphasia_severity_category(predicted_aphasia_score)

    summary = f"""### Aphasia Score: {predicted_aphasia_score:.1f}/100
Category: {severity_category}
Source: {"Model Prediction" if aphasia_score_source == "predicted" else "Custom Value" if aphasia_score_source == "custom" else "Default"}

### Demographic Information
- Age: {age} years
- Months Post Onset: {mpo}
- Education: {education} years
- Gender: {gender}
- Handedness: {handedness}
- Lesion Size: {lesion_size}%

### Brain Connectivity Analysis
Top connected brain regions:
"""

    for region, value in top_positive:
        summary += f"- {region}: {value:.2f}\n"

    summary += "\nLeast connected brain regions:\n"

    for region, value in top_negative:
        summary += f"- {region}: {value:.2f}\n"

    # Create dataframe of pairwise connectivity values
    df_data = []
    for i in range(custom_fc_mat.shape[0]):
        for j in range(custom_fc_mat.shape[0]):
            if i < j:  # Only include upper triangle to avoid redundancy
                df_data.append({
                    "Region 1": region_names[i],
                    "Region 2": region_names[j],
                    "Connectivity": round(float(custom_fc_mat[i, j]), 2)
                })

    # Sort by absolute connectivity value
    df = pd.DataFrame(df_data)
    df = df.sort_values(by="Connectivity", key=abs, ascending=False)

    # Limit to top 100 connections for performance
    df = df.head(100)

    return temp_img_path, summary, df
368
+
369
# Check if model exists and try to load it at import time so a previously
# trained model is usable as soon as the UI comes up.
if os.path.exists(MODEL_PATH):
    print("Model file found. Loading model...")
    load_model()
else:
    print("No model found. Please train the model first.")

# Create Gradio interface: one tab for prediction/visualization, one for training.
with gr.Blocks(title="Aphasia Prediction with FC Visualization") as demo:
    gr.Markdown("# Aphasia Prediction with Functional Connectivity Visualization")
    gr.Markdown("This app predicts aphasia scores based on patient demographics and displays functional connectivity patterns in the brain.")

    with gr.Tab("Predict & Visualize"):
        with gr.Row():
            # Left column: demographic and aphasia inputs
            with gr.Column(scale=1):
                gr.Markdown("### Patient Demographics")
                age = gr.Slider(minimum=20, maximum=90, value=60, step=1, label="Age (years)")
                mpo = gr.Slider(minimum=1, maximum=36, value=6, step=1, label="Months Post Onset")
                education = gr.Slider(minimum=8, maximum=22, value=16, step=1, label="Education (years)")
                gender = gr.Radio(["Male", "Female"], value="Male", label="Gender")
                handedness = gr.Radio(["Right", "Left"], value="Right", label="Handedness")

                gr.Markdown("### Aphasia Information")
                aphasia_severity = gr.Slider(minimum=0, maximum=100, value=50, step=1, label="Aphasia Severity (WAB AQ)")
                lesion_size = gr.Slider(minimum=0, maximum=100, value=20, step=1, label="Lesion Size (%)")

                use_custom_score = gr.Checkbox(label="Override with custom score", value=False)
                custom_score = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Custom WAB AQ Score",
                                         visible=False)

                # Make custom score visible only when checkbox is selected
                use_custom_score.change(lambda x: gr.update(visible=x), inputs=[use_custom_score], outputs=[custom_score])

                generate_btn = gr.Button("Generate Functional Connectivity", variant="primary")

            # Right column: outputs (heatmap image, summary, connection table)
            with gr.Column(scale=2):
                with gr.Row():
                    fc_image = gr.Image(label="Functional Connectivity Matrix", show_download_button=True)
                    fc_summary = gr.Markdown(label="Analysis Summary")

                fc_data = gr.DataFrame(label="Top FC Connections")

        # Generate FC on button click; outputs must match the 3-tuple
        # returned by generate_fc_visualization.
        generate_btn.click(
            generate_fc_visualization,
            inputs=[age, mpo, education, gender, handedness,
                    aphasia_severity, lesion_size,
                    use_custom_score, custom_score],
            outputs=[fc_image, fc_summary, fc_data]
        )

    with gr.Tab("Train Model"):
        gr.Markdown("### Train or Retrain the Model")
        gr.Markdown("""
        This tab allows you to train the two-stage model:
        1. First stage: DemoVAE model learns brain connectivity patterns
        2. Second stage: Random Forest predicts aphasia scores

        Note: This will download data from HuggingFace 'SreekarB/OSFData' and use the full dataset for training.
        """)

        train_btn = gr.Button("Train Model", variant="primary")
        train_output = gr.Textbox(label="Training Status", lines=20)

        train_btn.click(train_model, inputs=[], outputs=[train_output])

    gr.Markdown("## How to use")
    gr.Markdown("""
    1. Set the patient's demographic information and aphasia details
    2. Click "Generate Functional Connectivity" to see the visualization and prediction
    3. Optionally, override the model's prediction with your own custom score
    4. If the model is not trained, go to the "Train Model" tab to train it first

    The heatmap shows correlations between brain regions. Yellow indicates positive correlations (regions that activate together),
    green indicates neutral correlations, and blue indicates negative correlations (regions with opposing activation patterns).
    """)

if __name__ == "__main__":
    # Set up the optimal launch configuration for Hugging Face Spaces:
    # bind to all interfaces so the Space proxy can reach the server.
    if IS_SPACE:
        demo.launch(server_name="0.0.0.0", share=False)
    else:
        demo.launch()
osf_demovae_adapter.py ADDED
@@ -0,0 +1,1369 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ OSF DemoVAE Adapter for HuggingFace Spaces
3
+
4
+ This script is optimized for running in HuggingFace Spaces.
5
+ It loads functional connectivity data from SreekarB/OSFData and
6
+ pretrained models from SreekarB/OSFModels.
7
+ """
8
+
9
+ import os
10
+ import numpy as np
11
+ import pandas as pd
12
+ import nibabel as nib
13
+ import sys
14
+ import tempfile
15
+ import pickle
16
+ import json
17
+ from pathlib import Path
18
+ from tqdm import tqdm
19
+
20
+ # Import HuggingFace libraries
21
+ from datasets import load_dataset
22
+ from huggingface_hub import hf_hub_download, HfApi, list_repo_files
23
+
24
+ # Import PyTorch if available (needed for DemoVAE)
25
+ try:
26
+ import torch
27
+ TORCH_AVAILABLE = True
28
+ except ImportError:
29
+ TORCH_AVAILABLE = False
30
+ print("WARNING: PyTorch not available. Some functionality may be limited.")
31
+
32
+ # Add PIP package to path - in Spaces this will be the correct path
33
+ sys.path.append(os.path.join(os.path.dirname(__file__), 'pip', 'src'))
34
+
35
+ # Only try to import DemoVAE if PyTorch is available
36
+ if TORCH_AVAILABLE:
37
+ try:
38
+ from demovae.model import to_torch, to_cuda, to_numpy, VAE
39
+ from demovae.sklearn import DemoVAE
40
+ DEMOVAE_AVAILABLE = True
41
+ except ImportError:
42
+ DEMOVAE_AVAILABLE = False
43
+ print("WARNING: DemoVAE package not found. Only data loading will be available.")
44
+ else:
45
+ DEMOVAE_AVAILABLE = False
46
+ print("WARNING: PyTorch not available. DemoVAE functionality disabled.")
47
+
48
+ # Constants for HuggingFace Spaces
49
+ DATA_REPO = "SreekarB/OSFData" # HuggingFace dataset repository
50
+ MODEL_REPO = "SreekarB/OSFModels" # HuggingFace model repository
51
+ IS_SPACE = os.environ.get('SPACE_ID') is not None # Check if running in HF Spaces
52
+
53
+ # Define paths for cached data - in Spaces, use /tmp for temporary storage
54
+ CACHE_DIR = "/tmp/osf_data" if IS_SPACE else "./cache"
55
+ os.makedirs(CACHE_DIR, exist_ok=True)
56
+
57
+ # Model URLs and filenames
58
+ MODEL_FILES = {
59
+ "demovae": "osf_demovae_model.pt",
60
+ "prediction": "aphasia_prediction_model.pkl",
61
+ "combined": "demovae_and_prediction_models.pkl"
62
+ }
63
+
64
+ # Atlas region labels for the FC matrix
65
+ # These are the standard AAL (Automated Anatomical Labeling) atlas brain regions
66
+ # We use 45 regions typically used in functional connectivity matrices
67
+ ATLAS_REGIONS = [
68
+ "Frontal_Sup_L", "Frontal_Sup_R", "Frontal_Mid_L", "Frontal_Mid_R",
69
+ "Frontal_Inf_Oper_L", "Frontal_Inf_Oper_R", "Frontal_Inf_Tri_L", "Frontal_Inf_Tri_R",
70
+ "Frontal_Inf_Orb_L", "Frontal_Inf_Orb_R", "Rolandic_Oper_L", "Rolandic_Oper_R",
71
+ "SMA_L", "SMA_R", "Olfactory_L", "Olfactory_R",
72
+ "Frontal_Med_Orb_L", "Frontal_Med_Orb_R", "Rectus_L", "Rectus_R",
73
+ "Insula_L", "Insula_R", "Cingulum_Ant_L", "Cingulum_Ant_R",
74
+ "Cingulum_Mid_L", "Cingulum_Mid_R", "Cingulum_Post_L", "Cingulum_Post_R",
75
+ "Hippocampus_L", "Hippocampus_R", "ParaHippocampal_L", "ParaHippocampal_R",
76
+ "Amygdala_L", "Amygdala_R", "Calcarine_L", "Calcarine_R",
77
+ "Cuneus_L", "Cuneus_R", "Lingual_L", "Lingual_R",
78
+ "Occipital_Sup_L", "Occipital_Sup_R", "Occipital_Mid_L", "Occipital_Mid_R",
79
+ "Occipital_Inf_L"
80
+ ]
81
+
82
+ # Ensure we have exactly 45 regions
83
+ if len(ATLAS_REGIONS) != 45:
84
+ print(f"Warning: Expected 45 regions but have {len(ATLAS_REGIONS)}. Using generic labels.")
85
+ ATLAS_REGIONS = [f"Region_{i+1}" for i in range(45)]
86
+
87
+ # Utility functions for functional connectivity
88
def mat2vec(fc):
    """Flatten a symmetric FC matrix into its strict upper-triangle vector.

    For a d x d matrix this returns the d*(d-1)/2 values above the
    diagonal, in row-major order (the inverse of vec2mat).
    """
    rows, cols = np.triu_indices(fc.shape[0], k=1)
    return fc[rows, cols]
93
+
94
def vec2mat(fc):
    """Rebuild a symmetric matrix from an upper-triangle vector.

    Infers the matrix dimension d from the vector length (d*(d-1)/2
    entries), mirrors the values below the diagonal, and sets the
    diagonal to 1 (self-correlation). Inverse of mat2vec.
    """
    # Solve k = d*(d-1)/2 for d via the quadratic formula
    d = int(round((1 + (1 + 8 * fc.size) ** 0.5) / 2))
    mat = np.zeros((d, d))
    rows, cols = np.triu_indices(d, k=1)
    mat[rows, cols] = fc
    mat = mat + mat.T
    np.fill_diagonal(mat, 1)
    return mat
104
+
105
def load_nii_from_huggingface(file_name):
    """
    Load a NII file from HuggingFace and return as a nibabel image object.
    Optimized for HuggingFace Spaces.

    Tries, in order: the local cache, a direct hf_hub_download from the
    dataset repo, the `datasets` library, and (in Spaces) a few local
    repository paths. Returns None if every method fails.

    Args:
        file_name: Name of the NII file (e.g., "P01_rs.nii")

    Returns:
        Nibabel image object, or None on failure
    """
    print(f"Loading {file_name} from {DATA_REPO}...")

    # Check cache first to avoid repeated downloads
    cache_path = os.path.join(CACHE_DIR, file_name)
    if os.path.exists(cache_path):
        try:
            print(f"Loading from cache: {cache_path}")
            nii_img = nib.load(cache_path)
            return nii_img
        except Exception as e:
            print(f"Error loading from cache: {e}")
            # If cache is corrupted, remove it and fall through to re-download
            os.remove(cache_path)

    try:
        # First try direct download using huggingface_hub
        try:
            # Use hf_hub_download which works well in Spaces
            tmp_path = hf_hub_download(
                repo_id=DATA_REPO,
                filename=file_name,
                repo_type="dataset",
                cache_dir=CACHE_DIR
            )

            # Load with nibabel
            nii_img = nib.load(tmp_path)

            # Make a copy in our cache dir for easier access
            # (hf_hub_download may place the file in a nested snapshot dir)
            if tmp_path != cache_path:
                import shutil
                shutil.copy2(tmp_path, cache_path)

            return nii_img

        except Exception as e:
            print(f"Direct download failed: {e}")

        # Try dataset loading approach
        # NOTE(review): assumes the dataset exposes raw file bytes under a
        # column named after the file — verify against the dataset schema.
        try:
            dataset = load_dataset(DATA_REPO, data_files=file_name)

            if "train" in dataset and len(dataset["train"]) > 0:
                if file_name in dataset["train"][0]:
                    nii_data = dataset["train"][0][file_name]

                    # Save to our cache
                    with open(cache_path, 'wb') as f:
                        f.write(nii_data)

                    # Load with nibabel
                    nii_img = nib.load(cache_path)
                    return nii_img

        except Exception as e:
            print(f"Dataset loading failed: {e}")

        # Check if we're in Spaces and the file might be available in the local repository
        if IS_SPACE:
            local_paths = [
                os.path.join("/app/OSFData", file_name),  # Standard location in Spaces
                os.path.join("OSFData", file_name),  # Repository root
                os.path.join("data", file_name)  # Common data directory
            ]

            for path in local_paths:
                if os.path.exists(path):
                    print(f"Loading from local Spaces file: {path}")
                    nii_img = nib.load(path)
                    return nii_img

        # All methods failed
        print(f"Could not load {file_name} using any method")
        return None

    except Exception as e:
        # Catch-all so a single missing/corrupt file degrades to None
        # instead of aborting a batch of loads.
        print(f"Error loading {file_name}: {e}")
        return None
194
+
195
def nii_to_fc_matrix(nii_img, target_size=45):
    """
    Process a NIfTI image to extract functional connectivity matrix

    Args:
        nii_img: Nibabel image object
        target_size: Target size for the FC matrix (default 45x45)

    Returns:
        Functional connectivity matrix of consistent size, or None when the
        input is not 4D, has too little usable signal, or yields an invalid
        (NaN/inf) correlation matrix
    """
    try:
        # Get time series data (voxel intensities over the scan duration)
        print(" • Getting time series data...")
        time_series = nii_img.get_fdata()
        print(f" Time series shape: {time_series.shape}")

        # Ensure we have 4D data (three spatial dimensions + time)
        if len(time_series.shape) < 4:
            print(f" ✗ Error: Expected 4D data but got {len(time_series.shape)}D data")
            return None

        # Reshape to (voxels, time): flatten the three spatial axes
        print(" • Reshaping time series data...")
        orig_shape = time_series.shape
        time_series = time_series.reshape(orig_shape[0]*orig_shape[1]*orig_shape[2], orig_shape[3])
        print(f" Reshaped to {time_series.shape}")

        # Remove NaN and infinity values (replaced by 0 so they don't poison corrcoef)
        print(" • Cleaning data (removing NaN and infinity)...")
        nan_count = np.isnan(time_series).sum()
        inf_count = np.isinf(time_series).sum()
        if nan_count > 0 or inf_count > 0:
            print(f" Found {nan_count} NaN values and {inf_count} infinity values")
            time_series = np.nan_to_num(time_series, nan=0, posinf=0, neginf=0)

        # Filter out voxels with no signal (zero temporal standard deviation)
        print(" • Filtering voxels with no signal...")
        signal_mask = np.std(time_series, axis=1) > 0
        active_voxels = np.sum(signal_mask)
        print(f" Active voxels: {active_voxels} out of {time_series.shape[0]} ({active_voxels/time_series.shape[0]*100:.1f}%)")

        if active_voxels < 10:
            print(" ✗ Error: Too few active voxels for reliable correlation")
            return None

        time_series = time_series[signal_mask]

        # Ensure we don't have too many voxels - sample if needed for consistency
        if time_series.shape[0] > 10000:
            print(f" • Sampling voxels to reduce computational load...")
            np.random.seed(42)  # For reproducibility
            sample_indices = np.random.choice(time_series.shape[0], 10000, replace=False)
            time_series = time_series[sample_indices]
            print(f" Sampled to {time_series.shape[0]} voxels")

        # Calculate correlation matrix.
        # NOTE(review): np.corrcoef treats each ROW of its argument as one
        # variable; time_series.T has shape (time, voxels), so this produces a
        # (time x time) matrix — correlations between timepoints, not between
        # voxels/regions. Confirm this is the intended definition of "FC" here,
        # since the docstring and downstream labels describe region pairs.
        print(" • Calculating functional connectivity matrix...")
        fc_matrix = np.corrcoef(time_series.T)
        print(f" Raw FC matrix shape: {fc_matrix.shape}")

        # Validate the matrix before resizing/returning
        if np.isnan(fc_matrix).any():
            print(" ✗ Error: FC matrix contains NaN values")
            return None

        if np.isinf(fc_matrix).any():
            print(" ✗ Error: FC matrix contains infinity values")
            return None

        # Ensure consistent matrix size (45x45) for the FC matrix so downstream
        # feature vectors line up across subjects
        if fc_matrix.shape[0] != target_size:
            print(f" • Resizing FC matrix to standard {target_size}x{target_size} size...")

            if fc_matrix.shape[0] > target_size:
                # Take the first target_size x target_size submatrix
                fc_matrix = fc_matrix[:target_size, :target_size]
            else:
                # Pad with zeros
                padded_matrix = np.zeros((target_size, target_size))
                padded_matrix[:fc_matrix.shape[0], :fc_matrix.shape[1]] = fc_matrix

                # Make sure the diagonal is 1 (self-correlation), including the
                # padded region
                np.fill_diagonal(padded_matrix, 1)

                fc_matrix = padded_matrix

            print(f" Final FC matrix shape: {fc_matrix.shape}")

        print(" ✓ Functional connectivity matrix successfully calculated")
        return fc_matrix

    except Exception as e:
        # Broad catch keeps a single bad scan from aborting batch processing
        print(f" ✗ Error processing NIfTI data: {e}")
        return None
290
+
291
def get_all_nii_files():
    """
    Discover all NII files available in the HuggingFace dataset

    Discovery is attempted in order of reliability:
      1. a cached JSON file list (avoids repeated network calls),
      2. the HuggingFace Hub API (authoritative repository listing),
      3. subject IDs from the demographic CSV (assumes "<ID>_rs.nii" naming),
      4. a hard-coded fallback pattern of 30 subjects (P01-P30).

    Returns:
        List of NII filenames
    """
    print("Discovering NII files in SreekarB/OSFData repository...")

    # Check cache first
    cache_file = os.path.join(CACHE_DIR, "nii_files_list.json")
    if os.path.exists(cache_file):
        try:
            with open(cache_file, 'r') as f:
                nii_files = json.load(f)
            print(f"Loaded {len(nii_files)} NII files from cache")
            return nii_files
        except Exception as e:
            # Corrupt cache: ignore and rediscover below
            print(f"Error loading from cache: {e}")

    try:
        # Try to list repository files using HF API
        try:
            api = HfApi()
            files = api.list_repo_files(DATA_REPO, repo_type="dataset")
            nii_files = [f for f in files if f.endswith('.nii')]

            if nii_files:
                print(f"Found {len(nii_files)} NII files in the repository")
                # Save to cache
                with open(cache_file, 'w') as f:
                    json.dump(nii_files, f)
                return nii_files
        except Exception as e:
            print(f"Could not list files via API: {e}")

        # Try loading demo data to check what files exist
        print("Trying to load demographic data to find subject IDs...")
        try:
            demo_data = load_dataset(DATA_REPO, data_files="FC_graph_covariate_data.csv")
            if "train" in demo_data:
                df = pd.DataFrame(demo_data["train"])
                if "ID" in df.columns:
                    subject_ids = df["ID"].unique()
                    print(f"Found {len(subject_ids)} subject IDs in demographic data")
                    # Assumes files follow the "<subject>_rs.nii" convention
                    nii_files = [f"{subject_id}_rs.nii" for subject_id in subject_ids]

                    # Save to cache
                    with open(cache_file, 'w') as f:
                        json.dump(nii_files, f)
                    return nii_files
        except Exception as e:
            print(f"Error loading demographic data: {e}")

        # Fallback to standard pattern - assuming 30 subjects (P01 to P30)
        print("Using default pattern for 30 subjects (P01-P30)...")
        nii_files = [f"P{i:02d}_rs.nii" for i in range(1, 31)]

        # Save to cache
        with open(cache_file, 'w') as f:
            json.dump(nii_files, f)
        return nii_files

    except Exception as e:
        # Even cache bookkeeping failed; return the default pattern uncached
        print(f"Error discovering NII files: {e}")
        return [f"P{i:02d}_rs.nii" for i in range(1, 31)]
357
+
358
def download_model(model_type):
    """
    Fetch a pretrained model file from the SreekarB/OSFModels repository.

    Args:
        model_type: Type of model to download ('demovae', 'prediction', or 'combined')

    Returns:
        Path to the downloaded model file

    Raises:
        ValueError: If model_type is not a key of MODEL_FILES.
        Exception: Re-raises any failure from the HuggingFace download.
    """
    if model_type not in MODEL_FILES:
        raise ValueError(f"Unknown model type: {model_type}. Available types: {list(MODEL_FILES.keys())}")

    model_file = MODEL_FILES[model_type]
    local_copy = os.path.join(CACHE_DIR, model_file)

    # A previously downloaded copy short-circuits the network round trip
    if os.path.exists(local_copy):
        print(f"Using cached model: {model_file}")
        return local_copy

    print(f"Downloading {model_type} model from {MODEL_REPO}...")
    try:
        # Delegate the actual transfer (and HF-side caching) to huggingface_hub
        fetched_path = hf_hub_download(
            repo_id=MODEL_REPO,
            filename=model_file,
            repo_type="model",
            cache_dir=CACHE_DIR,
        )
        print(f"Model downloaded to: {fetched_path}")
        return fetched_path
    except Exception as e:
        # Surface the error to the caller after logging it
        print(f"Error downloading model: {e}")
        raise
395
+
396
def load_demographic_data():
    """
    Load demographic data from HuggingFace dataset

    Tries, in order: a local CSV cache, a direct file download via
    huggingface_hub, the datasets library, and (when running in Spaces)
    well-known local paths. If every source fails, a synthetic dataset is
    generated so the rest of the pipeline can still run.

    Returns:
        Pandas DataFrame with demographic data (one row per subject)
    """
    print("Loading demographic data from SreekarB/OSFData...")

    # Cache path for demographic data
    cache_path = os.path.join(CACHE_DIR, "demographic_data.csv")

    # Check if data is already cached
    if os.path.exists(cache_path):
        try:
            print(f"Loading demographic data from cache: {cache_path}")
            df = pd.read_csv(cache_path)
            print(f"Loaded data for {len(df)} subjects with columns: {', '.join(df.columns)}")
            return df
        except Exception as e:
            # Corrupt/unreadable cache: fall through to re-download
            print(f"Error loading cached demographic data: {e}")

    try:
        # Try to download file using huggingface_hub
        try:
            demo_path = hf_hub_download(
                repo_id=DATA_REPO,
                filename="FC_graph_covariate_data.csv",
                repo_type="dataset",
                cache_dir=CACHE_DIR
            )

            df = pd.read_csv(demo_path)
            df.to_csv(cache_path, index=False)  # Save to our cache
            print(f"Loaded demographic data with {len(df)} subjects")
            return df

        except Exception as e:
            print(f"Error downloading demographic file: {e}")

        # Try loading using datasets library
        try:
            demo_dataset = load_dataset(DATA_REPO, data_files="FC_graph_covariate_data.csv")
            if "train" in demo_dataset:
                df = pd.DataFrame(demo_dataset["train"])
                df.to_csv(cache_path, index=False)  # Save to our cache
                print(f"Loaded demographic data with {len(df)} subjects")
                return df
        except Exception as e:
            print(f"Error loading demographic data via datasets: {e}")

        # Check if file exists locally in Spaces
        if IS_SPACE:
            local_paths = [
                os.path.join("/app/OSFData", "FC_graph_covariate_data.csv"),
                os.path.join("OSFData", "FC_graph_covariate_data.csv"),
                os.path.join("data", "FC_graph_covariate_data.csv")
            ]

            for path in local_paths:
                if os.path.exists(path):
                    df = pd.read_csv(path)
                    df.to_csv(cache_path, index=False)  # Save to our cache
                    print(f"Loaded demographic data from local file: {path}")
                    return df

        raise FileNotFoundError("Could not load demographic data using any method")

    except Exception as e:
        print(f"Could not load demographic data: {e}")
        # Create a synthetic dataset for testing
        # NOTE(review): column names mirror the schema the rest of this module
        # expects (ID, age, gender, handedness, education, mpo, wab_aq) —
        # confirm against the real FC_graph_covariate_data.csv.
        print("Creating synthetic demographic data for testing")
        data = []
        for i in range(1, 31):
            subject_id = f"P{i:02d}"
            data.append({
                "ID": subject_id,
                "age": np.random.randint(40, 80),
                "gender": "M" if np.random.random() > 0.5 else "F",
                "handedness": "R" if np.random.random() > 0.2 else "L",
                "education": np.random.randint(8, 20),
                "mpo": np.random.randint(1, 24),
                "wab_aq": np.random.uniform(20, 90)
            })
        df = pd.DataFrame(data)
        df.to_csv(cache_path, index=False)
        return df
483
+
484
def load_models():
    """
    Load pretrained models from HuggingFace

    First tries the combined bundle (VAE + prediction model in one pickle);
    on failure, attempts the two individual model files independently, so a
    partial dictionary may be returned.

    Returns:
        Dictionary containing loaded models (keys may include "demovae" and
        "prediction"), or None if nothing could be loaded
    """
    if not DEMOVAE_AVAILABLE:
        print("DemoVAE package not available. Cannot load models.")
        return None

    print("Loading pretrained models from SreekarB/OSFModels...")

    try:
        # Try to load the combined model first (contains both VAE and prediction model)
        try:
            combined_path = download_model("combined")

            # SECURITY: pickle.load on downloaded content executes arbitrary
            # code — only safe because the model repo is trusted/first-party.
            with open(combined_path, 'rb') as f:
                models = pickle.load(f)

            print("Successfully loaded combined models")
            return models

        except Exception as e:
            print(f"Error loading combined model: {e}")
            print("Trying to load individual models...")

        # Try loading individual models
        models = {}

        # Load DemoVAE model
        try:
            vae_path = download_model("demovae")

            # Load the model - adapting based on file type
            if vae_path.endswith('.pt'):
                # PyTorch model file
                demovae_model = DemoVAE()
                demovae_model.load(vae_path)
            else:
                # Pickle file
                with open(vae_path, 'rb') as f:
                    demovae_model = pickle.load(f)

            models["demovae"] = demovae_model
            print("Successfully loaded DemoVAE model")

        except Exception as e:
            print(f"Error loading DemoVAE model: {e}")

        # Load prediction model
        try:
            pred_path = download_model("prediction")

            with open(pred_path, 'rb') as f:
                pred_model = pickle.load(f)

            models["prediction"] = pred_model
            print("Successfully loaded prediction model")

        except Exception as e:
            print(f"Error loading prediction model: {e}")

        # Return whatever subset loaded; empty dict means total failure
        if models:
            return models
        else:
            print("Could not load any models")
            return None

    except Exception as e:
        print(f"Error loading models: {e}")
        return None
557
+
558
def load_and_process_data(quick_test=False):
    """
    Load and process data from OSF dataset on HuggingFace

    Args:
        quick_test (bool): If True, use a small subset of data for testing

    Returns:
        Tuple of (FC features, demographics, WAB scores, improvement scores, dataframe)

    Raises:
        ValueError: If demographic data or NII files cannot be loaded, or if
            no subject's functional connectivity could be computed.
    """
    print("Loading and processing OSF data from HuggingFace...")

    # Step 1: Load demographic data
    demo_df = load_demographic_data()
    if demo_df is None or len(demo_df) == 0:
        raise ValueError("Could not load demographic data")

    # Step 2: Get list of NII files to process
    nii_files = get_all_nii_files()
    if not nii_files:
        raise ValueError("No NII files found")

    # Process appropriate number of subjects
    max_subjects = 5 if quick_test else len(nii_files)
    print(f"Processing {'a subset of' if quick_test else 'all'} subjects: {max_subjects}")

    # Step 3: Process NII files to FC matrices
    fc_data = _compute_fc_data(demo_df, nii_files, max_subjects)
    if not fc_data:
        raise ValueError("Failed to process any functional connectivity data")

    # Step 4: Build feature matrices.
    # BUGFIX: X_fc must be assembled FIRST, because inconsistent FC-vector
    # lengths may drop subjects; X_demo is then derived from the SURVIVING
    # subject list so demographic rows stay aligned with FC rows. (Previously
    # X_demo was built from the unfiltered subject list, so after filtering
    # the two matrices could describe different subjects row-by-row.)
    processed_subjects = sorted(fc_data.keys())
    X_fc, processed_subjects = _build_fc_features(fc_data, processed_subjects)
    print(f"Final X_fc shape: {X_fc.shape}")

    X_demo = _build_demo_features(demo_df, processed_subjects)

    # Get WAB scores if available; fall back to random placeholders so the
    # pipeline keeps working on incomplete data
    if 'wab_aq' in demo_df.columns:
        y_wab = np.array([demo_df[demo_df['ID'] == subject_id]['wab_aq'].values[0]
                          if subject_id in demo_df['ID'].values and not pd.isna(demo_df[demo_df['ID'] == subject_id]['wab_aq'].values[0])
                          else np.random.uniform(20, 80)
                          for subject_id in processed_subjects])
    else:
        # Generate random WAB scores for testing
        y_wab = np.random.uniform(20, 80, len(processed_subjects))

    # Improvement scores: real values where available, synthetic otherwise
    improvement_data = []
    for subject_id in processed_subjects:
        improvement = None
        if 'improvement' in demo_df.columns and subject_id in demo_df['ID'].values:
            subj_impr = demo_df[demo_df['ID'] == subject_id]['improvement'].values
            if len(subj_impr) > 0 and not pd.isna(subj_impr[0]):
                improvement = subj_impr[0]

        improvement_data.append({
            'ID': subject_id,
            'improvement': improvement if improvement is not None else np.random.uniform(0, 40)
        })

    # Convert to DataFrame
    improvement_df = pd.DataFrame(improvement_data)

    # Merge improvement scores back onto demographics for the returned frame
    merged_df = pd.merge(demo_df, improvement_df, on='ID', how='left')

    # Improvement targets in the same subject order as X_fc / X_demo
    y_improvement = np.array([improvement_df[improvement_df['ID'] == subject_id]['improvement'].values[0]
                              for subject_id in processed_subjects])

    print(f"Processed data for {len(processed_subjects)} subjects")
    print(f"X_fc shape: {X_fc.shape}")
    print(f"X_demo shape: {X_demo.shape}")
    print(f"y_wab shape: {y_wab.shape}")
    print(f"y_improvement shape: {y_improvement.shape}")

    return X_fc, X_demo, y_wab, y_improvement, merged_df


def _compute_fc_data(demo_df, nii_files, max_subjects):
    """Compute FC matrix/vector per subject; returns {subject_id: {"matrix", "vector"}}."""
    fc_data = {}
    if 'ID' in demo_df.columns:
        subject_ids = demo_df['ID'].unique()[:max_subjects]
        for subject_id in tqdm(subject_ids, desc="Processing subjects"):
            nii_file = f"{subject_id}_rs.nii"
            if nii_file in nii_files:
                fc_matrix, fc_vector = process_subject(subject_id, nii_file)
                if fc_matrix is not None and fc_vector is not None:
                    fc_data[subject_id] = {"matrix": fc_matrix, "vector": fc_vector}
    else:
        # No ID column: derive subject IDs from the filenames themselves
        for nii_file in nii_files[:max_subjects]:
            subject_id = nii_file.split('_')[0]  # e.g. "P01" from "P01_rs.nii"
            fc_matrix, fc_vector = process_subject(subject_id, nii_file)
            if fc_matrix is not None and fc_vector is not None:
                fc_data[subject_id] = {"matrix": fc_matrix, "vector": fc_vector}
    return fc_data


def _build_fc_features(fc_data, processed_subjects):
    """Stack FC vectors into a 2-D array, reconciling inconsistent lengths.

    Returns (X_fc, kept_subjects); kept_subjects may be a subset of
    processed_subjects when only consistently-sized vectors are usable.
    """
    try:
        # Direct stacking works when all vectors share the same length
        return np.array([fc_data[subject_id]["vector"] for subject_id in processed_subjects]), processed_subjects
    except ValueError as e:
        print(f"Warning: FC vectors have inconsistent shapes: {e}")

    # Get all vector lengths
    vector_lengths = [len(fc_data[subject_id]["vector"]) for subject_id in processed_subjects]
    print(f"Vector lengths: min={min(vector_lengths)}, max={max(vector_lengths)}")

    # Find most common vector length
    from collections import Counter
    length_counts = Counter(vector_lengths)
    most_common_length = length_counts.most_common(1)[0][0]
    print(f"Most common vector length: {most_common_length} (occurs {length_counts[most_common_length]} times)")

    # Keep only subjects whose vectors already have the modal length
    consistent_subjects = [subject_id for subject_id in processed_subjects
                           if len(fc_data[subject_id]["vector"]) == most_common_length]

    if len(consistent_subjects) < 3:
        # Too few consistent subjects: pad/truncate everyone to the modal length
        print(f"Too few subjects with consistent vector length. Padding/truncating all vectors.")
        padded_vectors = []
        for subject_id in processed_subjects:
            vec = fc_data[subject_id]["vector"]
            if len(vec) < most_common_length:
                # Pad with zeros
                padded_vec = np.zeros(most_common_length)
                padded_vec[:len(vec)] = vec
            else:
                # Truncate
                padded_vec = vec[:most_common_length]
            padded_vectors.append(padded_vec)
        return np.array(padded_vectors), processed_subjects

    print(f"Using {len(consistent_subjects)} subjects with consistent vector length {most_common_length}")
    X_fc = np.array([fc_data[subject_id]["vector"] for subject_id in consistent_subjects])
    return X_fc, consistent_subjects


def _build_demo_features(demo_df, processed_subjects):
    """Build the (n_subjects, 5) demographic matrix: age, mpo, education, gender, handedness."""
    X_demo = np.zeros((len(processed_subjects), 5))
    if 'ID' not in demo_df.columns:
        # No IDs to match on — fall back to dummy (all-zero) demographics
        return X_demo

    for i, subject_id in enumerate(processed_subjects):
        if subject_id in demo_df['ID'].values:
            subject_row = demo_df[demo_df['ID'] == subject_id].iloc[0]
            X_demo[i, 0] = subject_row.get('age', 60)
            X_demo[i, 1] = subject_row.get('mpo', 6)
            X_demo[i, 2] = subject_row.get('education', 12)
            X_demo[i, 3] = 1 if subject_row.get('gender', 'M') == 'M' else 0
            X_demo[i, 4] = 1 if subject_row.get('handedness', 'R') == 'R' else 0
    return X_demo
723
+
724
def plot_connectivity_matrix(fc_matrix, subject_id=None, save_path=None, show_labels=True):
    """
    Plot a functional connectivity matrix with region labels and a custom color scheme

    Args:
        fc_matrix: The functional connectivity matrix to plot
        subject_id: Subject identifier for the title
        save_path: Path to save the figure, if None just display
        show_labels: Whether to show region labels on the plot

    Returns:
        The matplotlib figure object
    """
    import matplotlib.pyplot as plt
    from matplotlib.colors import LinearSegmentedColormap
    import numpy as np

    # Custom colormap. LinearSegmentedColormap.from_list maps the FIRST color
    # to vmin and the LAST to vmax, so to match the documented scheme
    # (blue = negative, green = neutral, yellow = positive) the list must run
    # blue -> green -> yellow.
    # BUGFIX: the previous order started with yellow, which silently inverted
    # the color scale relative to the title, colorbar note, and annotations.
    colors = [(0, 0, 1),  # Blue for negative correlations (vmin = -1)
              (0, 1, 0),  # Green for neutral/moderate correlations
              (1, 1, 0)]  # Yellow for positive correlations (vmax = +1)

    n_bins = 256  # Number of discrete colors
    custom_cmap = LinearSegmentedColormap.from_list("BlueGreenYellow", colors, N=n_bins)

    # Create figure (larger when region labels are drawn, to fit tick text)
    fig_size = 12 if show_labels else 8
    fig, ax = plt.subplots(figsize=(fig_size, fig_size))

    # Plot the connectivity matrix on a fixed [-1, 1] correlation scale
    im = ax.imshow(fc_matrix, cmap=custom_cmap, vmin=-1, vmax=1)

    # Add a title
    title = "Functional Connectivity Matrix"
    if subject_id:
        title += f" - Subject {subject_id}"
    ax.set_title(title, fontsize=14)

    # Add labels if requested (only when the atlas covers every row)
    if show_labels and fc_matrix.shape[0] <= len(ATLAS_REGIONS):
        # Get labels for the matrix size
        labels = ATLAS_REGIONS[:fc_matrix.shape[0]]

        # Add x and y labels
        ax.set_xticks(np.arange(len(labels)))
        ax.set_yticks(np.arange(len(labels)))
        ax.set_xticklabels(labels, rotation=90, fontsize=8)
        ax.set_yticklabels(labels, fontsize=8)

        # Minor-tick grid lines visually separate the matrix cells
        ax.set_xticks(np.arange(-.5, len(labels), 1), minor=True)
        ax.set_yticks(np.arange(-.5, len(labels), 1), minor=True)
        ax.grid(which='minor', color='gray', linestyle='-', linewidth=0.5, alpha=0.3)
    else:
        # Just add numbers
        ax.set_xlabel("Brain Region Index")
        ax.set_ylabel("Brain Region Index")

    # Add colorbar
    cbar = fig.colorbar(im, ax=ax)
    cbar.set_label("Correlation Strength", rotation=270, labelpad=15)

    # Add annotations explaining the color scheme
    fig.text(0.01, 0.01, "Color scheme: Yellow (positive correlation), Green (neutral), Blue (negative correlation)",
             fontsize=8, ha='left')

    # Add explanation of matrix content
    explanation = (
        "This matrix shows the functional connectivity between brain regions.\n"
        "Each cell represents the correlation of activity between two regions.\n"
        "Positive values (yellow) indicate regions that activate together.\n"
        "Negative values (blue) indicate regions with opposite activation patterns."
    )

    ax.annotate(explanation, xy=(0.5, -0.15), xycoords='axes fraction',
                ha='center', va='center', fontsize=9,
                bbox=dict(boxstyle='round', fc='lavender', alpha=0.8))

    # Tighten layout
    plt.tight_layout()

    # Save if requested
    if save_path:
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        print(f"Figure saved to {save_path}")

    return fig
816
+
817
+
818
def process_subject(subject_id, nii_file):
    """
    Process a single subject's NII file to FC matrix

    Args:
        subject_id: Subject identifier
        nii_file: NII filename

    Returns:
        Tuple of (FC matrix, FC vector) or (None, None) if processing fails
    """
    print(f"\nProcessing {subject_id} from {nii_file}...")

    # Cached results let repeat runs skip the expensive NII processing
    cached_matrix = os.path.join(CACHE_DIR, f"{subject_id}_fc_matrix.npy")
    cached_vector = os.path.join(CACHE_DIR, f"{subject_id}_fc_vector.npy")

    if os.path.exists(cached_matrix) and os.path.exists(cached_vector):
        try:
            print(f"Loading cached FC data for {subject_id}")
            return np.load(cached_matrix), np.load(cached_vector)
        except Exception as e:
            # Unreadable cache: fall through and recompute from the raw file
            print(f"Error loading cached FC data: {e}")

    try:
        # Step 1: Load NII file
        nii_img = load_nii_from_huggingface(nii_file)
        if nii_img is None:
            print(f"Failed to load {nii_file}. Skipping {subject_id}.")
            return None, None

        # Step 2: Convert to FC matrix
        fc_matrix = nii_to_fc_matrix(nii_img)
        if fc_matrix is None:
            print(f"Failed to process NII data to FC matrix. Skipping {subject_id}.")
            return None, None

        # Step 3: Extract FC vector (upper triangle)
        fc_vector = mat2vec(fc_matrix)

        # Persist both representations for future runs
        np.save(cached_matrix, fc_matrix)
        np.save(cached_vector, fc_vector)

        return fc_matrix, fc_vector

    except Exception as e:
        print(f"Error processing {subject_id}: {e}")
        return None, None
871
+
872
def get_connectivity_visualization(fc_matrix, subject_id=None, output_path=None):
    """
    Render a connectivity matrix as a PNG for use in a web interface.

    Args:
        fc_matrix: Functional connectivity matrix
        subject_id: Subject identifier
        output_path: Path to save the visualization (if None, just returns the bytes)

    Returns:
        output_path when saving to disk, otherwise an in-memory BytesIO buffer
        positioned at the start of the PNG data
    """
    import io
    import matplotlib.pyplot as plt

    # Build the figure with the shared plotting helper
    figure = plot_connectivity_matrix(fc_matrix, subject_id=subject_id, show_labels=True)

    if output_path:
        # Persist straight to disk and hand back the path
        figure.savefig(output_path, dpi=300, bbox_inches='tight')
        plt.close(figure)
        return output_path

    # No path given: serialize the PNG into memory instead
    stream = io.BytesIO()
    figure.savefig(stream, format='png', dpi=300, bbox_inches='tight')
    plt.close(figure)
    stream.seek(0)

    return stream
903
+
904
+ def main():
905
+ """
906
+ Main function for processing functional connectivity data from HuggingFace
907
+ Optimized for HuggingFace Spaces
908
+ """
909
+ print(f"OSF DemoVAE Adapter - Running in {'HuggingFace Spaces' if IS_SPACE else 'local mode'}")
910
+ print("="*50)
911
+
912
+ # Step 1: Load demographic data
913
+ demo_df = load_demographic_data()
914
+ if demo_df is None or len(demo_df) == 0:
915
+ print("Error: Could not load demographic data. Exiting.")
916
+ return
917
+
918
+ # Step 2: Get list of NII files to process
919
+ nii_files = get_all_nii_files()
920
+ if not nii_files:
921
+ print("Error: No NII files found. Exiting.")
922
+ return
923
+
924
+ print(f"Found {len(nii_files)} NII files to process")
925
+
926
+ # Step 3: Process NII files to FC matrices
927
+ fc_data = {}
928
+
929
+ # Process each subject - limit to 5 for testing in Spaces to save resources
930
+ max_subjects = 5 if IS_SPACE else len(nii_files)
931
+
932
+ # Use the demo_df to get subject IDs if available
933
+ if 'ID' in demo_df.columns:
934
+ subject_ids = demo_df['ID'].unique()[:max_subjects]
935
+ print(f"Processing {len(subject_ids)} subjects from demographic data")
936
+
937
+ for subject_id in tqdm(subject_ids, desc="Processing subjects"):
938
+ nii_file = f"{subject_id}_rs.nii"
939
+ if nii_file in nii_files:
940
+ fc_matrix, fc_vector = process_subject(subject_id, nii_file)
941
+ if fc_matrix is not None and fc_vector is not None:
942
+ fc_data[subject_id] = {
943
+ "matrix": fc_matrix,
944
+ "vector": fc_vector
945
+ }
946
+ else:
947
+ print(f"Warning: No NII file found for subject {subject_id}")
948
+ else:
949
+ # Process the first max_subjects NII files
950
+ for i, nii_file in enumerate(nii_files[:max_subjects]):
951
+ subject_id = nii_file.split('_')[0] # Extract subject ID (e.g., P01)
952
+ fc_matrix, fc_vector = process_subject(subject_id, nii_file)
953
+ if fc_matrix is not None and fc_vector is not None:
954
+ fc_data[subject_id] = {
955
+ "matrix": fc_matrix,
956
+ "vector": fc_vector
957
+ }
958
+
959
+ # Step 4: Print summary
960
+ print("\n" + "="*50)
961
+ print(f"Successfully processed {len(fc_data)} subjects")
962
+
963
+ if fc_data:
964
+ # Save combined FC vectors for modeling
965
+ fc_vectors = np.array([fc_data[subject_id]["vector"] for subject_id in sorted(fc_data.keys())])
966
+ os.makedirs(os.path.join(CACHE_DIR, "processed"), exist_ok=True)
967
+ np.save(os.path.join(CACHE_DIR, "processed", "all_fc_vectors.npy"), fc_vectors)
968
+
969
+ # Save subject IDs
970
+ with open(os.path.join(CACHE_DIR, "processed", "subject_ids.txt"), "w") as f:
971
+ for subject_id in sorted(fc_data.keys()):
972
+ f.write(f"{subject_id}\n")
973
+
974
+ # Show example statistics
975
+ first_subject = list(fc_data.keys())[0]
976
+ fc_matrix = fc_data[first_subject]["matrix"]
977
+ print(f"\nExample FC matrix for {first_subject}:")
978
+ print(f"Shape: {fc_matrix.shape}")
979
+ print(f"Min value: {fc_matrix.min():.4f}")
980
+ print(f"Max value: {fc_matrix.max():.4f}")
981
+ print(f"Mean value: {fc_matrix.mean():.4f}")
982
+
983
+ # Create visualization of the first subject's FC matrix
984
+ try:
985
+ # Get the matrix
986
+ fc_matrix = fc_data[first_subject]["matrix"]
987
+
988
+ # Create the plot with the new color scheme and region labels
989
+ os.makedirs(os.path.join(CACHE_DIR, "visualizations"), exist_ok=True)
990
+ plot_path = os.path.join(CACHE_DIR, "visualizations", f"{first_subject}_fc_matrix.png")
991
+
992
+ print(f"\nCreating visualization of functional connectivity for {first_subject}...")
993
+ fig = plot_connectivity_matrix(
994
+ fc_matrix,
995
+ subject_id=first_subject,
996
+ save_path=plot_path,
997
+ show_labels=True
998
+ )
999
+
1000
+ # Print explanation of the matrix
1001
+ print("\nFunctional Connectivity Matrix Explanation:")
1002
+ print("-------------------------------------------")
1003
+ print("This matrix represents functional connectivity between 45 brain regions.")
1004
+ print("Each cell (i,j) shows the correlation between activity in regions i and j.")
1005
+ print("- Yellow cells: Strong positive correlation (regions activate together)")
1006
+ print("- Green cells: Neutral or weak correlation")
1007
+ print("- Blue cells: Negative correlation (when one region activates, the other deactivates)")
1008
+ print("\nThe brain regions are labeled according to the AAL atlas:")
1009
+
1010
+ # Print some example region pairs with their connectivity values
1011
+ strong_pos = []
1012
+ strong_neg = []
1013
+
1014
+ # Find some interesting connections
1015
+ for i in range(fc_matrix.shape[0]):
1016
+ for j in range(i+1, fc_matrix.shape[0]):
1017
+ val = fc_matrix[i, j]
1018
+ if val > 0.7:
1019
+ strong_pos.append((i, j, val))
1020
+ elif val < -0.4:
1021
+ strong_neg.append((i, j, val))
1022
+
1023
+ # Print examples of strongly connected regions
1024
+ if strong_pos:
1025
+ print("\nExamples of strongly positively connected regions:")
1026
+ for i, j, val in sorted(strong_pos, key=lambda x: x[2], reverse=True)[:3]:
1027
+ if i < len(ATLAS_REGIONS) and j < len(ATLAS_REGIONS):
1028
+ print(f" • {ATLAS_REGIONS[i]} and {ATLAS_REGIONS[j]}: {val:.2f}")
1029
+
1030
+ if strong_neg:
1031
+ print("\nExamples of negatively connected regions:")
1032
+ for i, j, val in sorted(strong_neg, key=lambda x: x[2])[:3]:
1033
+ if i < len(ATLAS_REGIONS) and j < len(ATLAS_REGIONS):
1034
+ print(f" • {ATLAS_REGIONS[i]} and {ATLAS_REGIONS[j]}: {val:.2f}")
1035
+
1036
+ print(f"\nVisualization saved to: {plot_path}")
1037
+
1038
+ except Exception as e:
1039
+ print(f"Error creating visualization: {e}")
1040
+
1041
+ # Step 5: Try to load models (if DemoVAE is available)
1042
+ if DEMOVAE_AVAILABLE:
1043
+ models = load_models()
1044
+ if models:
1045
+ print("\nModels loaded successfully")
1046
+
1047
+ # If we have both models, try prediction on an example subject
1048
+ if "demovae" in models and "prediction" in models:
1049
+ try:
1050
+ # Get the first subject's FC vector
1051
+ fc_vec = fc_data[first_subject]["vector"]
1052
+
1053
+ # Get demographic features from demo_df
1054
+ if 'ID' in demo_df.columns:
1055
+ subject_row = demo_df[demo_df['ID'] == first_subject]
1056
+ if not subject_row.empty:
1057
+ # Extract demographic features
1058
+ demo_features = np.array([
1059
+ subject_row['age'].values[0],
1060
+ subject_row['mpo'].values[0] if 'mpo' in subject_row else 6,
1061
+ subject_row['education'].values[0] if 'education' in subject_row else 12,
1062
+ 1 if subject_row['gender'].values[0] == 'M' else 0 if 'gender' in subject_row else 0.5,
1063
+ 1 if subject_row['handedness'].values[0] == 'R' else 0 if 'handedness' in subject_row else 0.5
1064
+ ]).reshape(1, -1)
1065
+
1066
+ # Get latent representation from DemoVAE
1067
+ z = models["demovae"].get_latents(fc_vec.reshape(1, -1))
1068
+
1069
+ # Combine with demographic features for prediction
1070
+ X_combined = np.hstack([z, demo_features])
1071
+
1072
+ # Make prediction
1073
+ predicted_wab = models["prediction"].predict(X_combined)[0]
1074
+
1075
+ actual_wab = subject_row['wab_aq'].values[0] if 'wab_aq' in subject_row else None
1076
+
1077
+ print(f"\nPrediction for {first_subject}:")
1078
+ print(f"Predicted WAB AQ score: {predicted_wab:.2f}")
1079
+ if actual_wab is not None:
1080
+ print(f"Actual WAB AQ score: {actual_wab:.2f}")
1081
+ print(f"Prediction error: {abs(predicted_wab - actual_wab):.2f}")
1082
+
1083
+ # Generate a second visualization showing the relationship between
1084
+ # predicted aphasia recovery and specific brain region connectivity
1085
+ try:
1086
+ # Create a visualization showing the most important connections for prediction
1087
+ if 'feature_importances_' in dir(models["prediction"]):
1088
+ print("\nCreating visualization of important connections for prediction...")
1089
+
1090
+ # Get feature importances from the Random Forest model
1091
+ importances = models["prediction"].feature_importances_
1092
+
1093
+ # The first elements of importances correspond to the latent variables
1094
+ # and we can't directly map those to regions. So we'll compute a score for each region.
1095
+
1096
+ # Create a heatmap version of the FC matrix, with cells colored by importance
1097
+ importance_matrix = np.zeros_like(fc_matrix)
1098
+
1099
+ # Simple approach: use the original matrix but scale the color by overall importance
1100
+ # (this is a simplification since we can't directly map latent vars to regions)
1101
+ avg_importance = np.mean(importances[:z.shape[1]])
1102
+ importance_matrix = fc_matrix * avg_importance
1103
+
1104
+ # Plot this matrix
1105
+ imp_plot_path = os.path.join(CACHE_DIR, "visualizations",
1106
+ f"{first_subject}_importance_matrix.png")
1107
+
1108
+ # Use original connectivity but with a different title
1109
+ fig = plot_connectivity_matrix(
1110
+ fc_matrix,
1111
+ subject_id=f"{first_subject} (Regions Influencing Prediction)",
1112
+ save_path=imp_plot_path,
1113
+ show_labels=True
1114
+ )
1115
+
1116
+ print(f"Prediction importance visualization saved to: {imp_plot_path}")
1117
+ except Exception as e:
1118
+ print(f"Error creating importance visualization: {e}")
1119
+ except Exception as e:
1120
+ print(f"Error making prediction: {e}")
1121
+ else:
1122
+ print("No FC matrices were successfully processed")
1123
+
1124
+ print("\nProcessing complete!")
1125
+ print(f"Data cached in: {CACHE_DIR}")
1126
+ print("Ready for use in HuggingFace Spaces!")
1127
+
1128
+
1129
def train_demovae_model(X_fc, X_demo, y, save_model=True, model_path="osf_demovae_model.pt"):
    """
    Train a DemoVAE model on OSF data using the pip implementation.

    Args:
        X_fc: Functional connectivity features, shape (n_subjects, n_features).
        X_demo: Demographic features, shape (n_subjects, 5), columns ordered
            as [age, mpo, education, gender, handedness].
        y: Target variable (WAB AQ scores).
        save_model: Whether to save the trained model to disk.
        model_path: Path used when save_model is True.

    Returns:
        Tuple of (model, z_train, z_test, X_fc_test, X_demo_test, y_test):
        the trained DemoVAE, latent representations for both splits, and
        the held-out FC / demographic / target test data.

    Raises:
        ImportError: If the DemoVAE package is not available.
    """
    if not DEMOVAE_AVAILABLE:
        raise ImportError("DemoVAE package is not available. Cannot train model.")

    print("Training DemoVAE model on OSF data...")

    # Split data into train and test sets
    from sklearn.model_selection import train_test_split
    X_fc_train, X_fc_test, X_demo_train, X_demo_test, y_train, y_test = train_test_split(
        X_fc, X_demo, y, test_size=0.2, random_state=42
    )

    # Expected demographic type for each column:
    # [age, mpo, education, gender, handedness]
    all_demo_types = ['continuous', 'continuous', 'continuous', 'categorical', 'categorical']

    # Validate each demographic variable. DemoVAE rejects categorical
    # variables with a single observed category, so demote those to
    # continuous. The column data itself is passed through unchanged in
    # every case, so it is appended once outside the branch.
    demo_types = []
    demo_train_processed = []
    demo_test_processed = []
    for i, demo_type in enumerate(all_demo_types):
        demo_train_col = X_demo_train[:, i]
        if demo_type == 'categorical' and len(np.unique(demo_train_col)) == 1:
            print(f"Warning: Column {i} has only one category. Treating as continuous.")
            demo_types.append('continuous')
        else:
            demo_types.append(demo_type)
        demo_train_processed.append(demo_train_col)
        demo_test_processed.append(X_demo_test[:, i])

    # Print final demographic types
    print(f"Using demographic types: {demo_types}")

    # Create model with parameters tuned for OSF data
    model = DemoVAE(
        latent_dim=30,
        loss_rec_mult=100,
        loss_decor_mult=10,
        loss_pred_mult=0.1,
        nepochs=300,
        pperiod=50,
        bsize=min(32, len(X_fc_train)),
        use_cuda=torch.cuda.is_available()
    )

    # Train the model; if categorical handling fails, retry with every
    # demographic treated as continuous (a coarser but robust encoding).
    try:
        model.fit(X_fc_train, demo_train_processed, demo_types)
    except Exception as e:
        print(f"Error training model: {str(e)}")
        print("Falling back to all continuous demographics")
        demo_types = ['continuous'] * len(all_demo_types)
        model.fit(X_fc_train, demo_train_processed, demo_types)

    # Get latent representations for both splits
    z_train = model.get_latents(X_fc_train)
    z_test = model.get_latents(X_fc_test)

    # Save model if requested
    if save_model:
        model.save(model_path)
        print(f"Model saved to {model_path}")

    return model, z_train, z_test, X_fc_test, X_demo_test, y_test
1213
+
1214
+
1215
def generate_custom_fc(demo_values, model, prediction_model=None, visualize=True):
    """
    Generate a custom FC matrix for a patient with given demographics.

    Args:
        demo_values: Dict with optional keys 'age', 'mpo', 'education',
            'gender', 'handedness'; sensible defaults are used when absent.
        model: Trained DemoVAE model.
        prediction_model: Optional aphasia prediction model. When given, a
            WAB score is predicted from the latent vector + demographics.
        visualize: Whether to save a heatmap visualization of the matrix.

    Returns:
        Tuple of (45x45 FC matrix, predicted aphasia score or None,
        visualization path or None).

    Raises:
        ImportError: If the DemoVAE package is not available.
    """
    if not DEMOVAE_AVAILABLE:
        raise ImportError("DemoVAE package is not available. Cannot generate custom FC matrix.")

    # Extract demographic values (with defaults)
    age = demo_values.get('age', 60)
    mpo = demo_values.get('mpo', 6)
    education = demo_values.get('education', 16)
    gender = 1 if demo_values.get('gender', 'male').lower() == 'male' else 0
    handedness = 1 if demo_values.get('handedness', 'right').lower() == 'right' else 0

    # Format for the model: one row, then one list per demographic column
    custom_demo = np.array([[age, mpo, education, gender, handedness]])
    custom_demo_lists = [custom_demo[:, i] for i in range(custom_demo.shape[1])]
    demo_types = ['continuous', 'continuous', 'continuous', 'categorical', 'categorical']

    # Fixed seed so repeated calls with the same demographics give the
    # same latent vector (used below for the score prediction).
    np.random.seed(42)
    custom_z = np.random.randn(1, model.latent_dim)

    # Generate one FC sample conditioned on the demographics.
    # NOTE: model.transform(int, ...) draws its own latent internally; the
    # previous unused torch conversions of custom_z/custom_demo were dead
    # code and have been removed.
    custom_fc = model.transform(1, custom_demo_lists, demo_types)
    custom_fc = custom_fc[0]  # Get the first (only) example

    # Convert to matrix
    if custom_fc.size == 1000:
        custom_fc = custom_fc[:990]  # Use first 990 elements for 45x45 matrix

    fc_matrix = vec2mat(custom_fc)

    # Ensure matrix is the right shape for visualization (45x45)
    if fc_matrix.shape[0] != 45:
        temp_matrix = np.zeros((45, 45))
        min_dim = min(fc_matrix.shape[0], 45)
        temp_matrix[:min_dim, :min_dim] = fc_matrix[:min_dim, :min_dim]
        np.fill_diagonal(temp_matrix, 1)  # Ensure diagonal is 1
        fc_matrix = temp_matrix

    # Predict aphasia score if a prediction model is provided
    predicted_score = None
    if prediction_model is not None:
        X_combined = np.hstack([custom_z, custom_demo])
        predicted_score = float(prediction_model.predict(X_combined)[0])
        predicted_score = max(0, min(100, predicted_score))  # Clip to valid WAB range

    # Create visualization if requested
    viz_path = None
    if visualize:
        try:
            # Create a descriptive ID based on demographics
            custom_id = f"custom_age{age}_mpo{mpo}_edu{education}_{'M' if gender == 1 else 'F'}_{'R' if handedness == 1 else 'L'}"

            # Create directory for visualizations
            os.makedirs(os.path.join(CACHE_DIR, "visualizations"), exist_ok=True)
            viz_path = os.path.join(CACHE_DIR, "visualizations", f"{custom_id}_fc_matrix.png")

            # Create title with demographic info and predicted score
            title = f"Custom FC Matrix: Age {age}, MPO {mpo}, Education {education}, "
            title += f"{'Male' if gender == 1 else 'Female'}, {'Right' if handedness == 1 else 'Left'} handed"

            if predicted_score is not None:
                title += f"\nPredicted WAB Score: {predicted_score:.1f}"

            # Create the visualization
            fig = plot_connectivity_matrix(
                fc_matrix,
                subject_id=custom_id,
                save_path=viz_path,
                show_labels=True
            )

            print(f"Generated custom FC matrix visualization saved to: {viz_path}")

        except Exception as e:
            print(f"Error creating visualization for custom FC matrix: {e}")

    return fc_matrix, predicted_score, viz_path
1313
+
1314
+
1315
def predict_aphasia_recovery(z_test, X_demo_test, y_test):
    """
    Fit a Random Forest mapping VAE latent features plus demographics to
    WAB AQ scores, and report its fit quality.

    NOTE(review): the forest is fit and evaluated on the same samples, so
    the reported RMSE/R² describe in-sample fit, not generalization.

    Args:
        z_test: Latent representations from the DemoVAE encoder.
        X_demo_test: Demographic features aligned with z_test.
        y_test: Target WAB AQ scores.

    Returns:
        Tuple of (fitted model, predictions, RMSE, R²).
    """
    from sklearn.ensemble import RandomForestRegressor
    from sklearn.metrics import mean_squared_error, r2_score
    import math

    banner = "=================================================="
    print(banner)
    print("TRAINING RANDOM FOREST FOR APHASIA SCORE PREDICTION")
    print(banner)
    print("This is the second stage of the prediction pipeline:")
    print("1. VAE model extracts latent representations from FC data")
    print("2. Now training Random Forest to predict WAB AQ scores")
    print(f"Input features: {z_test.shape[1]} latent variables + {X_demo_test.shape[1]} demographic features")
    print(f"Training samples: {len(y_test)}")

    # Latent + demographic features form the regressor input
    features = np.hstack([z_test, X_demo_test])

    print("Training Random Forest Regressor with 100 trees...")
    forest = RandomForestRegressor(n_estimators=100, random_state=42)
    forest.fit(features, y_test)
    print("Random Forest training complete!")

    predictions = forest.predict(features)

    # In-sample evaluation metrics
    rmse_value = math.sqrt(mean_squared_error(y_test, predictions))
    r_squared = r2_score(y_test, predictions)

    print("Random Forest Prediction Results:")
    print(f"RMSE: {rmse_value:.4f}")
    print(f"R²: {r_squared:.4f}")
    print(banner)
    print("Random Forest training successful!")
    print(banner)

    return forest, predictions, rmse_value, r_squared
1366
+
1367
+
1368
if __name__ == "__main__":
    # Script entry point: run the full OSF data-processing pipeline.
    main()
pip/.DS_Store ADDED
Binary file (6.15 kB). View file
 
pip/.gitignore ADDED
@@ -0,0 +1 @@
 
 
1
+ dist/*
pip/.pypirc ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [distutils]
2
+ index-servers =
3
+ pypi
4
+
5
+ [pypi]
6
+ username = __token__
pip/LICENSE ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2024 Anton Orlichenko
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
pip/README.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ Installable version of the DemoVAE demographic-conditioned variational autoencoder for fMRI data.
2
+
3
+ For use with pip.
4
+
5
+ Perform fMRI distribution sampling, remove confounds, and harmonize multi-site data.
6
+
7
+ Supports FC, ALFF, and ReHO data.
pip/pyproject.toml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [project]
2
+ name = "demovae"
3
+ version = "0.0.6"
4
+ authors = [
5
+ { name="Anton Orlichenko", email="aorliche@gmail.com" }
6
+ ]
7
+ description = "A demographic-conditioned variational autoencoder for fMRI distribution sampling, removal of confounds, and multi-site harmonization. Works with FC, ALFF, or ReHO data."
8
+ readme = "README.md"
9
+ dependencies = [ "numpy", "torch", "scikit-learn" ]
10
+ requires-python = ">=3.8"
11
+ classifiers = [
12
+ "Programming Language :: Python :: 3",
13
+ "License :: OSI Approved :: MIT License",
14
+ "Operating System :: OS Independent",
15
+ ]
16
+
17
+ [project.urls]
18
+ Homepage = "https://github.com/aorliche/demo-vae/"
19
+ Issues = "https://github.com/aorliche/demo-vae/issues"
20
+
21
+ [build-system]
22
+ requires = ["hatchling"]
23
+ build-backend = "hatchling.build"
pip/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ numpy>=1.20.0
2
+ pandas>=1.3.0
3
+ torch>=1.9.0
4
+ matplotlib>=3.4.0
5
+ scikit-learn>=1.0.0
6
+ tqdm>=4.62.0
7
+ nibabel>=3.2.0
8
+ gradio>=3.50.0
9
+ pillow>=9.0.0
10
+ datasets>=2.10.0
11
+ huggingface_hub>=0.16.0
pip/src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
pip/src/demovae/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from demovae.sklearn import DemoVAE
pip/src/demovae/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (226 Bytes). View file
 
pip/src/demovae/__pycache__/model.cpython-311.pyc ADDED
Binary file (15.5 kB). View file
 
pip/src/demovae/__pycache__/sklearn.cpython-311.pyc ADDED
Binary file (6.42 kB). View file
 
pip/src/demovae/model.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import random
7
+ import numpy as np
8
+
9
+ from sklearn.linear_model import Ridge
10
+ from sklearn.linear_model import LogisticRegression
11
+
12
def to_torch(x):
    """Convert a numpy array to a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.float()
14
+
15
def to_cuda(x, use_cuda):
    """Move x to the GPU when use_cuda is truthy; otherwise return x unchanged."""
    return x.cuda() if use_cuda else x
20
+
21
def to_numpy(x):
    """Detach x from the autograd graph, move it to CPU, and return a numpy array."""
    detached = x.detach().cpu()
    return detached.numpy()
23
+
24
class VAE(nn.Module):
    """
    Demographic-conditioned variational autoencoder.

    Encoder: input_dim -> 1000 -> latent_dim.
    Decoder: (latent_dim + demo_dim) -> 1000 -> input_dim; the demographic
    vector is concatenated to the latent code before decoding.
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        self.input_dim = input_dim    # size of the input feature vector (e.g. vectorized FC)
        self.latent_dim = latent_dim  # size of the latent code
        self.demo_dim = demo_dim      # size of the (expanded) demographic vector
        self.use_cuda = use_cuda
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
        self.dec1 = to_cuda(nn.Linear(latent_dim+demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)

    def enc(self, x):
        """Encode a batch of inputs into latent codes (deterministic; no sampling)."""
        x = F.relu(self.enc1(x))
        z = self.enc2(x)
        return z

    def gen(self, n):
        """Draw n latent codes from the standard normal prior."""
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes conditioned on demographics (concatenated along dim 1)."""
        z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        x = F.relu(self.dec1(z))
        x = self.dec2(x)
        # Disabled alternative decoder that reconstructed FC as a low-rank
        # outer product over 264 regions; kept for reference.
        #x = x.reshape(len(z), 264, 5)
        #x = torch.einsum('nac,nbc->nab', x, x)
        #a,b = np.triu_indices(264, 1)
        #x = x[:,a,b]
        return x
53
+
54
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between a and b; `mean` selects the reduction."""
    squared_error = (a - b) ** 2
    return mean(squared_error) ** 0.5
56
+
57
def latent_loss(z, use_cuda=True):
    """
    KL-style regularizer pushing the latent batch toward N(0, I).

    Penalizes (a) the Gram matrix z^T z deviating from len(z) * I and
    (b) the batch mean deviating from zero.

    Returns:
        (loss_C, loss_mu, C, mu): the two loss terms plus the raw Gram
        matrix and mean used to compute them.
    """
    C = z.T@z
    mu = torch.mean(z, dim=0)
    # Target Gram matrix: identity scaled by batch size (unit covariance)
    tgt1 = to_cuda(torch.eye(z.shape[-1]).float(), use_cuda)*len(z)
    tgt2 = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    loss_C = rmse(C, tgt1)
    loss_mu = rmse(mu, tgt2)
    return loss_C, loss_mu, C, mu
65
+
66
def decor_loss(z, demo, use_cuda=True):
    """
    Penalize association between each demographic column and every latent
    dimension, encouraging demographics-free latent codes.

    Args:
        z: Latent codes, shape (batch, latent_dim).
        demo: Demographic tensor, shape (batch, demo_dim).

    Returns:
        (losses, ps): per-demographic loss tensor and the list of scaled
        projection vectors used to compute them.
    """
    ps = []
    losses = []
    for di in range(demo.shape[1]):
        d = demo[:,di]
        d = d - torch.mean(d)
        # Project each latent dimension onto the centered demographic, then
        # scale by the demographic's std and the latents' squared norms.
        # NOTE(review): this resembles (but is not exactly) a per-dimension
        # Pearson correlation — confirm the intended normalization.
        p = torch.einsum('n,nz->z', d, z)
        p = p/torch.std(d)
        p = p/torch.einsum('nz,nz->z', z, z)
        tgt = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
        loss = rmse(p, tgt)
        losses.append(loss)
        ps.append(p)
    losses = torch.stack(losses)
    return losses, ps
81
+
82
def pretty(x):
    """Format a scalar (or 0-d tensor) rounded to 4 decimal places."""
    return str(round(float(x), 4))
84
+
85
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """
    Convert per-demographic columns into the stacked tensor the VAE
    decoder consumes.

    Continuous demographics pass through as single columns; categorical
    demographics are one-hot expanded using the category values recorded
    in pred_stats at fit time.

    Raises:
        Exception: If a categorical value was not seen during training.

    Returns:
        Tensor of shape (batch, expanded_demo_dim).
    """
    demo_t = []
    demo_idx = 0  # index of the current demographic, for error messages
    for d,t,s in zip(demo, demo_types, pred_stats):
        if t == 'continuous':
            demo_t.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            # Reject values the model was not trained on
            for dd in d:
                if dd not in s:
                    print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            # One indicator column per training-time category value
            for ss in s:
                idx = (d == ss).astype('bool')
                zeros = torch.zeros(len(d))
                zeros[idx] = 1
                demo_t.append(to_cuda(zeros, use_cuda))
        demo_idx += 1
    demo_t = torch.stack(demo_t).permute(1,0)
    return demo_t
104
+
105
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult, loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """
    Train the VAE with demographic guidance.

    Pipeline per batch: (1) reconstruct real samples and compute latent
    regularizers; (2) sample synthetic demographics and latents, decode
    them, and score the generated samples with pre-fit linear guidance
    models so the decoder respects the conditioning demographics.

    Args:
        vae: VAE instance to train (modified in place).
        x: Training data, shape (n_samples, input_dim), numpy.
        demo: List of per-demographic columns.
        demo_types: Parallel list of 'continuous' / 'categorical'.
        nepochs, pperiod, bsize: Epochs, print period, batch size.
        loss_*_mult: Weights for the individual loss terms.
        lr, weight_decay: Adam optimizer settings.
        alpha: Ridge regularization for continuous guidance models.
        LR_C: Inverse regularization for logistic guidance models.
        ret_obj: Object on which pred_stats is stored as a side effect.
    """
    # Get linear predictors for demographics
    pred_w = []
    pred_i = []
    # Pred stats are mean and std for continuous, and a list of all values for categorical
    pred_stats = []
    for i,d,t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxilliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            # Binary: store negated and positive weights (one per class)
            if len(reg.coef_) == 1:
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            # Categorical: one weight vector per class
            else:
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        else:
            print(f'demographic type "{t}" not "continuous" or "categorical"')
            raise Exception('Bad demographic type')
        print(' done')
    # Side effect: record fitted stats on the caller-supplied object
    ret_obj.pred_stats = pred_stats
    # Convert input to pytorch
    print('Converting input to pytorch')
    x = to_cuda(to_torch(x), vae.use_cuda)
    # Convert demographics to pytorch
    print('Converting demographics to pytorch')
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    # Training loop
    print('Beginning VAE training')
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        for bs in range(0,len(x),bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()
            # Reconstruct real samples and compute latent regularizers
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)
            # Sample 100 synthetic demographic vectors from the fitted stats
            demo_gen = []
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    mu = s[0]
                    std = s[1]
                    dd = torch.randn(100).float()
                    dd = dd*std+mu
                    dd = to_cuda(dd, vae.use_cuda)
                    demo_gen.append(dd)
                elif t == 'categorical':
                    # Pick one category for the whole synthetic batch
                    idx = random.randint(0, len(s)-1)
                    for i in range(len(s)):
                        if idx == i:
                            dd = torch.ones(100).float()
                        else:
                            dd = torch.zeros(100).float()
                        dd = to_cuda(dd, vae.use_cuda)
                        demo_gen.append(dd)
            demo_gen = torch.stack(demo_gen).permute(1,0)
            # Generate samples from the prior conditioned on those demographics
            z = vae.gen(100)
            y = vae.dec(z, demo_gen)
            # Regressor/classifier guidance loss on the generated samples
            losses_pred = []
            idcs = []
            dg_idx = 0
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                    loss = rmse(demo_gen[:,dg_idx], yy)
                    losses_pred.append(loss)
                    idcs.append(float(demo_gen[0,dg_idx]))
                    dg_idx += 1
                elif t == 'categorical':
                    loss = 0
                    for i in range(len(s)):
                        yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                        loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
                        idcs.append(int(demo_gen[0,dg_idx]))
                        dg_idx += 1
                    losses_pred.append(loss)
            # Weighted sum of all loss terms
            total_loss = loss_C_mult*loss_C + loss_mu_mult*loss_mu + loss_rec_mult*loss_rec + loss_decor_mult*loss_decor + loss_pred_mult*sum(losses_pred)
            total_loss.backward()
            optim.step()
        # Periodic progress report (losses from the last batch of the epoch)
        if e%pperiod == 0 or e == nepochs-1:
            print(f'Epoch {e} ', end='')
            print(f'ReconLoss {pretty(loss_rec)} ', end='')
            print(f'CovarianceLoss {pretty(loss_C)} ', end='')
            print(f'MeanLoss {pretty(loss_mu)} ', end='')
            print(f'DecorLoss {pretty(loss_decor)} ', end='')
            losses_pred = [pretty(loss) for loss in losses_pred]
            print(f'GuidanceTargets {idcs} GuidanceLosses {losses_pred} ', end='')
            print()
    print('Training complete.')
220
+
221
+
pip/src/demovae/sklearn.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from demovae.model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
3
+
4
+ from sklearn.base import BaseEstimator
5
+
6
+ # For saving
7
+ import torch
8
+
9
class DemoVAE(BaseEstimator):
    """
    Scikit-learn style wrapper around the demographic-conditioned VAE.

    Demographics are supplied as a list of per-variable columns plus a
    parallel list of types ('continuous' or 'categorical'). Provides
    fit/transform/fit_transform plus save/load helpers.
    """

    def __init__(self, **params):
        # Any parameter not supplied falls back to get_default_params()
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        """Return the default hyperparameters (meanings in inline comments)."""
        return dict(latent_dim=60,          # Latent dimension
                    use_cuda=True,          # GPU acceleration
                    nepochs=3000,           # Training epochs
                    pperiod=100,            # Epochs between printing updates
                    bsize=1000,             # Batch size
                    loss_C_mult=1,          # Covariance loss (KL div)
                    loss_mu_mult=1,         # Mean loss (KL div)
                    loss_rec_mult=100,      # Reconstruction loss
                    loss_decor_mult=10,     # Latent-demographic decorrelation loss
                    loss_pred_mult=0.001,   # Classifier/regressor guidance loss
                    alpha=100,              # Regularization for continuous guidance models
                    LR_C=100,               # Regularization for categorical guidance models
                    lr=1e-4,                # Learning rate
                    weight_decay=0,         # L2 regularization for VAE model
                    )

    def get_params(self, **params):
        """Return the current hyperparameters as a dict (sklearn API)."""
        return dict(latent_dim=self.latent_dim,
                    use_cuda=self.use_cuda,
                    nepochs=self.nepochs,
                    pperiod=self.pperiod,
                    bsize=self.bsize,
                    loss_C_mult=self.loss_C_mult,
                    loss_mu_mult=self.loss_mu_mult,
                    loss_rec_mult=self.loss_rec_mult,
                    loss_decor_mult=self.loss_decor_mult,
                    loss_pred_mult=self.loss_pred_mult,
                    alpha=self.alpha,
                    LR_C=self.LR_C,
                    lr=self.lr,
                    weight_decay=self.weight_decay,
                    )

    def set_params(self, **params):
        """Set supplied hyperparameters; fill the rest with defaults (sklearn API)."""
        dft = DemoVAE.get_default_params()
        for key in dft:
            if key in params:
                setattr(self, key, params[key])
            else:
                setattr(self, key, dft[key])
        return self

    def fit(self, x, demo, demo_types, **kwargs):
        """
        Train the VAE.

        Args:
            x: Training data, shape (n_samples, input_dim).
            demo: List of per-demographic columns.
            demo_types: Parallel list of 'continuous' / 'categorical'.

        Raises:
            Exception: If a categorical demographic has only one category
                or an unknown type string is given.
        """
        # Compute demo_dim: 1 per continuous variable, n_categories per categorical
        demo_dim = 0
        for d,t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                ll = len(set(list(d)))
                if ll == 1:
                    print('Only one type of category for categorical variable')
                    raise Exception('Bad categorical')
                demo_dim += ll
            else:
                print(f'demographic type "{t}" not "continuous" or "categorical"')
                raise Exception('Bad demographic type')
        # Save parameters
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        # Create model
        self.vae = VAE(x.shape[1], self.latent_dim, demo_dim, self.use_cuda)
        # Train model (also stores pred_stats on self via ret_obj)
        train_vae(self.vae, x, demo, demo_types,
                  self.nepochs, self.pperiod, self.bsize,
                  self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult, self.loss_decor_mult, self.loss_pred_mult,
                  self.lr, self.weight_decay, self.alpha, self.LR_C,
                  self)
        return self

    def transform(self, x, demo, demo_types, **kwargs):
        """
        Decode samples conditioned on demographics.

        If x is an int, generate that many new samples from the prior;
        otherwise encode x and re-decode it with the given demographics
        (e.g. for confound removal / harmonization).
        """
        if isinstance(x, int):
            # Generate
            z = self.vae.gen(x)
        else:
            # Get latents for real data
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def fit_transform(self, x, demo, demo_types, **kwargs):
        """Fit on (x, demo) then transform the same data."""
        self.fit(x, demo, demo_types)
        return self.transform(x, demo, demo_types)

    def get_latents(self, x):
        """Encode x and return the latent codes as a numpy array."""
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        """Save hyperparameters, fitted stats, and VAE weights to path."""
        params = self.get_params()
        dct = dict(pred_stats=self.pred_stats,
                   params=params,
                   input_dim=self.input_dim,
                   demo_dim=self.demo_dim,
                   model_state_dict=self.vae.state_dict())
        torch.save(dct, path)

    def load(self, path):
        """Restore a model previously written by save().

        NOTE(review): torch.load unpickles arbitrary objects — only load
        checkpoints from trusted sources.
        """
        dct = torch.load(path)
        self.pred_stats = dct['pred_stats']
        self.set_params(**dct['params'])
        self.vae = VAE(dct['input_dim'],
                       dct['params']['latent_dim'],
                       dct['demo_dim'],
                       dct['params']['use_cuda'])
        self.vae.load_state_dict(dct['model_state_dict'])
122
+
123
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.7.0
2
+ numpy>=1.19.0
3
+ pandas>=1.1.0
4
+ scikit-learn>=0.24.0
5
+ matplotlib>=3.3.0
6
+ nibabel>=3.2.0
7
+ tqdm>=4.50.0
8
+ gradio>=3.50.0
9
+ natsort>=8.0.0
10
+ requests>=2.25.0
11
+ datasets>=2.15.0
12
+ huggingface_hub>=0.10.0