File size: 13,003 Bytes
ef677f1
1c47445
 
 
 
b32645b
ef677f1
 
1c47445
 
 
 
a7f7808
1c47445
ef677f1
1c47445
 
763369a
1c47445
c775b23
1c47445
c775b23
1c47445
 
9135a28
1c47445
67303f6
1c47445
 
 
 
ef677f1
1c47445
 
 
 
 
55c1385
1c47445
 
 
 
55c1385
1c47445
 
 
 
 
 
 
 
 
b32645b
55c1385
1c47445
 
9641510
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
dbe81c1
1c47445
 
 
 
 
 
 
 
 
 
 
dbe81c1
1c47445
 
 
 
 
 
dbe81c1
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b32645b
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7f7808
1c47445
 
 
37a1b01
 
 
 
 
 
1c47445
 
 
 
 
 
 
37a1b01
 
1c47445
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50c6714
ef677f1
 
 
1c47445
 
 
 
ef677f1
 
 
 
 
 
 
 
1c47445
 
 
ef677f1
 
 
1c47445
b32645b
 
 
 
 
1c47445
 
b32645b
1c47445
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
import os
import sys
# Add the src directory to the path so we can import from demovae
sys.path.append(os.path.join(os.path.dirname(__file__), 'src'))

import numpy as np
import torch
from pathlib import Path
import nibabel as nib
from data_preprocessing import preprocess_fmri_to_fc
from src.demovae.sklearn import DemoVAE
from analysis import analyze_fc_patterns
from visualization import plot_fc_matrices
from config import MODEL_CONFIG, DATASET_CONFIG
import pandas as pd
import io
from typing import List, Dict, Union, Tuple, Any

def train_fc_vae(X, demo_data, demo_types, model_config):
    """
    Train a DemoVAE model on vectorized functional connectivity matrices.

    Args:
        X: array-like of shape (n_subjects, n_features); coerced to a
            float32 numpy array if it is not already an ndarray.
        demo_data: list of per-subject demographic arrays; non-ndarray
            entries are converted to numpy arrays in place.
        demo_types: list of 'continuous'/'categorical' labels, one per
            demographic variable, passed through to the model.
        model_config: dict with required keys 'latent_dim', 'nepochs',
            'bsize' and optional 'loss_rec_mult' (default 100),
            'loss_decor_mult' (default 10), 'lr' (default 1e-4).

    Returns:
        Tuple (vae, X, demo_data, demo_types): the fitted model plus the
        (possibly sanitized/converted) training inputs.
    """
    print(f"Creating VAE with latent dim={model_config['latent_dim']}, epochs={model_config['nepochs']}")

    # Ensure X is a numpy array with correct data type
    if not isinstance(X, np.ndarray):
        print(f"Converting X from {type(X)} to numpy array")
        X = np.array(X, dtype=np.float32)

    # Ensure demo_data contains numpy arrays (converted in place so the
    # returned list reflects the sanitized values)
    for i, d in enumerate(demo_data):
        if not isinstance(d, np.ndarray):
            print(f"Converting demographic {i} from {type(d)} to numpy array")
            demo_data[i] = np.array(d)

    # Sanitize non-finite values before training.
    # NOTE(review): np.nan_to_num maps NaN to 0 but maps +/-Inf to large
    # finite values, not zero, despite what the warning message says.
    if np.isnan(X).any() or np.isinf(X).any():
        print("Warning: X contains NaN or Inf values. Replacing with zeros.")
        X = np.nan_to_num(X)

    # Create the VAE model; CUDA is used automatically when available
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config.get('loss_rec_mult', 100),
        loss_decor_mult=model_config.get('loss_decor_mult', 10),
        lr=model_config.get('lr', 1e-4),
        use_cuda=torch.cuda.is_available()
    )

    print("Fitting VAE model...")
    vae.fit(X, demo_data, demo_types)

    return vae, X, demo_data, demo_types

def load_data(data_dir="SreekarB/OSFData", demographic_file=None, use_hf_dataset=True):
    """
    Load fMRI data and demographics from a HuggingFace dataset or local files.

    Args:
        data_dir: HuggingFace dataset ID (when use_hf_dataset is True) or a
            local directory containing *.nii.gz files.
        demographic_file: path to a demographics CSV; used only for local
            loading.
        use_hf_dataset: when True, pull everything from HuggingFace Datasets.

    Returns:
        Tuple (data, demo_data, demo_types) where data is either a list of
        pre-computed FC matrix columns (if the dataset ships them) or a list
        of NIfTI file references/paths (possibly empty), demo_data is
        [age, gender, mpo, wab_aq] arrays, and demo_types labels each
        demographic as 'continuous' or 'categorical'.
    """
    # One type label per entry of demo_data, in order:
    # age, gender, months-post-onset, WAB-AQ.
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    if use_hf_dataset:
        # Load from HuggingFace Datasets (import deferred so local loading
        # does not require the `datasets` package)
        from datasets import load_dataset

        print(f"Loading dataset from HuggingFace: {data_dir}")
        dataset = load_dataset(data_dir)
        train = dataset['train']

        print(f"Dataset columns: {train.column_names}")

        # Build a demographics DataFrame directly from the dataset features
        demo_df = pd.DataFrame({
            'ID': train['ID'],
            'wab_aq': train['wab_aq'],
            'age': train['age'],
            'mpo': train['mpo'],
            'education': train['education'],
            'gender': train['gender'],
            'handedness': train['handedness']
        })

        print(f"Loaded demographic data with {len(demo_df)} subjects")

        # Map dataset columns to the expected [age, gender, mpo, wab_aq] order
        demo_data = [
            demo_df['age'].values,
            demo_df['gender'].values,
            demo_df['mpo'].values,
            demo_df['wab_aq'].values
        ]

        # Prefer pre-computed FC matrix columns when the dataset ships them
        fc_columns = [col for col in train.column_names
                      if col.startswith("fc_") or "_fc" in col]
        if fc_columns:
            print(f"Found {len(fc_columns)} FC matrix columns: {fc_columns}")
            fc_matrices = [train[fc_col] for fc_col in fc_columns]
            return fc_matrices, demo_data, demo_types

        # Otherwise fall back to raw NIfTI columns, if any
        nii_files = [train[col] for col in train.column_names
                     if col.endswith((".nii.gz", ".nii"))]
        if nii_files:
            print(f"Found {len(nii_files)} .nii files")
        else:
            print("No FC matrices or .nii files found in dataset. Will need to construct FC matrices.")
            # If no structured data is found, we can try to download raw files later

    else:
        # Local loading: demographics CSV plus *.nii.gz files in data_dir.
        # Accept either naming convention for the demographic columns.
        demo_df = pd.read_csv(demographic_file)

        demo_data = [
            demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else demo_df['age'].values,
            demo_df['sex'].values if 'sex' in demo_df.columns else demo_df['gender'].values,
            demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else demo_df['mpo'].values,
            demo_df['wab_score'].values if 'wab_score' in demo_df.columns else demo_df['wab_aq'].values
        ]

        # Load fMRI files (sorted for deterministic subject order)
        nii_files = sorted(Path(data_dir).glob('*.nii.gz'))

    return nii_files, demo_data, demo_types

def run_fc_analysis(data_dir="SreekarB/OSFData", 
                    demographic_file=None, 
                    latent_dim=32, 
                    nepochs=1000, 
                    bsize=16,
                    save_model=True,
                    use_hf_dataset=True,
                    return_data=False):
    """
    End-to-end pipeline: load data, build FC matrices, train a DemoVAE,
    analyze the latent space, generate/reconstruct FC matrices, and plot.

    Args:
        data_dir: HuggingFace dataset ID or local fMRI data directory.
        demographic_file: demographics CSV path (local loading only).
        latent_dim: VAE latent dimensionality.
        nepochs: number of training epochs.
        bsize: training batch size.
        save_model: when True, save the fitted model to models/vae_model.pth.
        use_hf_dataset: load from HuggingFace instead of local files.
        return_data: when True, also return a dict of intermediate results.

    Returns:
        A matplotlib figure, or (figure, results_dict) when return_data is
        True. On any error, returns a figure displaying the error message
        (paired with None when return_data is True).
    """
    # NOTE: mutates the shared module-level MODEL_CONFIG so downstream
    # helpers see the user-specified parameters
    MODEL_CONFIG.update({
        'latent_dim': latent_dim,
        'nepochs': nepochs,
        'bsize': bsize
    })
    
    try:
        # Load data
        print("Loading data...")
        nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)
        
        # For SreekarB/OSFData, directly generate synthetic FC matrices
        if data_dir == "SreekarB/OSFData" and use_hf_dataset:
            print("Using SreekarB/OSFData dataset with synthetic FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(data_dir, demo_data, demo_types)
        # Check if we got FC matrices directly (array-like entries have .shape)
        elif isinstance(nii_files, list) and len(nii_files) > 0 and hasattr(nii_files[0], 'shape'):
            print("Using pre-computed FC matrices...")
            # Convert list of FC matrices to numpy array
            X = np.stack([np.array(fc) for fc in nii_files])
        else:
            # Prepare data by converting fMRI to FC matrices
            print("Converting fMRI data to FC matrices...")
            X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)
        
        # Print shapes and data types for debugging
        print(f"X shape: {X.shape}, type: {type(X)}")
        for i, d in enumerate(demo_data):
            print(f"Demo data {i} shape: {d.shape if hasattr(d, 'shape') else len(d)}, type: {type(d)}")
        
        # Train VAE and get data
        print("Training VAE...")
        try:
            # Use the proper DemoVAE implementation from src/demovae/sklearn.py
            vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)
            
            if save_model:
                print("Saving model...")
                os.makedirs('models', exist_ok=True)
                # Use the save method from DemoVAE
                vae.save('models/vae_model.pth')
                print("Model saved successfully.")
        except Exception as e:
            print(f"Error during VAE training: {e}")
            raise
        
        # Get latent representations
        print("Getting latent representations...")
        latents = vae.get_latents(X)
        
        # Analyze relationships between latents and demographics
        # (demo_data order is [age, gender, mpo, wab_aq]; gender is skipped)
        print("Analyzing demographic relationships...")
        demographics = {
            'age': demo_data[0],
            'months_post_onset': demo_data[2],
            'wab_aq': demo_data[3]
        }
        analysis_results = analyze_fc_patterns(latents, demographics)
        
        # Generate new FC matrix for a fixed demographic profile
        print("Generating new FC matrices...")
        
        # Convert to numpy arrays to avoid "expected np.ndarray (got list)" error
        new_demographics = [
            np.array([60.0], dtype=np.float64),        # age
            np.array(['M'], dtype=np.str_),           # gender
            np.array([12.0], dtype=np.float64),       # months post onset
            np.array([80.0], dtype=np.float64)        # wab score
        ]
        
        # Verify the demographic data arrays match the expected types
        print("Demographic data types:")
        for i, (name, data) in enumerate(zip(['age', 'gender', 'mpo', 'wab'], new_demographics)):
            print(f"  {name}: shape={data.shape}, dtype={data.dtype}")
        
        print("Generating FC matrix with demographic values: age=60, gender=M, mpo=12, wab=80")
        try:
            generated_fc = vae.transform(1, new_demographics, demo_types)
        except Exception as e:
            print(f"Error generating new FC matrix: {e}")
            # Try with a fallback approach
            print("Trying alternative generation approach...")
            # If specific gender is causing issues, try the first gender from training data
            new_demographics[1] = np.array([demo_data[1][0]])
            generated_fc = vae.transform(1, new_demographics, demo_types)
        reconstructed_fc = vae.transform(X, demo_data, demo_types)
        
        # Visualize original vs reconstructed vs generated FC for subject 0
        print("Creating visualizations...")
        fig = plot_fc_matrices(X[0], reconstructed_fc[0], generated_fc[0])
        
        # If requested, return additional data for accuracy calculations
        if return_data:
            # Create a structured outcome measures dictionary
            outcome_measures = {
                'wab_aq': demo_data[3],  # WAB-AQ scores
                # Could add other outcome measures here
            }
            
            results = {
                'vae': vae,
                'X': X,
                'latents': latents,
                'demographics': demographics,
                'reconstructed_fc': reconstructed_fc,
                'generated_fc': generated_fc,
                'analysis_results': analysis_results,
                'outcome_measures': outcome_measures
            }
            return fig, results
        
        return fig
        
    except Exception as e:
        # Top-level boundary: report the failure and return a figure that
        # shows the error instead of crashing the caller
        import traceback
        print(f"Error in run_fc_analysis: {str(e)}")
        print(traceback.format_exc())
        
        # Create a dummy figure with error message
        import matplotlib.pyplot as plt
        fig = plt.figure(figsize=(10, 6))
        plt.text(0.5, 0.5, f"Error: {str(e)}", 
                 horizontalalignment='center', verticalalignment='center', 
                 fontsize=12, color='red')
        plt.axis('off')
        
        # Return the error figure and empty results if requested
        if return_data:
            return fig, None
        
        return fig

if __name__ == "__main__":
    import argparse
    
    parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
    parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
                        help='HuggingFace dataset ID or directory containing fMRI data')
    parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
                        help='Path to demographic data CSV file')
    parser.add_argument('--latent_dim', type=int, default=32,
                        help='Dimension of latent space')
    parser.add_argument('--nepochs', type=int, default=1000,
                        help='Number of training epochs')
    parser.add_argument('--bsize', type=int, default=16,
                        help='Batch size for training')
    parser.add_argument('--no_save', action='store_false',
                        help='Do not save the model')
    parser.add_argument('--use_local', action='store_true',
                        help='Use local data instead of HuggingFace dataset')
    
    args = parser.parse_args()
    
    fig = run_fc_analysis(
        data_dir=args.data_dir,
        demographic_file=args.demographic_file,
        latent_dim=args.latent_dim,
        nepochs=args.nepochs,
        bsize=args.bsize,
        save_model=args.no_save,
        use_hf_dataset=not args.use_local
    )
    fig.show()