SreekarB commited on
Commit
ef677f1
·
verified ·
1 Parent(s): 15df58d

Upload 13 files

Browse files
Files changed (13) hide show
  1. README.md +51 -12
  2. analysis.py +16 -0
  3. app.py +52 -0
  4. config.py +24 -0
  5. data_preprocessing.py +79 -0
  6. main.py +185 -0
  7. requirements.txt +12 -0
  8. src/.DS_Store +0 -0
  9. src/demovae/model.py +221 -0
  10. src/demovae/sklearn.py +123 -0
  11. utils.py +186 -0
  12. vae_model.py +150 -0
  13. visualization.py +44 -0
README.md CHANGED
@@ -1,12 +1,51 @@
1
- ---
2
- title: AphasiaPred
3
- emoji: 😻
4
- colorFrom: indigo
5
- colorTo: purple
6
- sdk: gradio
7
- sdk_version: 5.20.1
8
- app_file: app.py
9
- pinned: false
10
- ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Aphasia fMRI to FC Analysis using VAE
2
+
3
+ This demo performs functional connectivity analysis on fMRI data using a Variational Autoencoder (VAE) approach. It's designed to work with aphasia patient data, analyzing brain connectivity patterns and their relationship to demographic variables.
4
+
5
+ ## About the Model
6
+
7
+ This application implements a VAE model that:
8
+ 1. Takes functional connectivity (FC) matrices derived from fMRI data
9
+ 2. Learns a lower-dimensional latent representation of brain connectivity
10
+ 3. Conditions the generation process on demographic variables (age, sex, time post-stroke, WAB scores)
11
+ 4. Allows analysis of relationships between brain connectivity patterns and demographic variables
12
+
13
+ ## Dataset
14
+
15
+ This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
16
+
17
+ - Functional connectivity matrices from fMRI data
18
+ - Demographic information in `FC_graph_covariate_data.csv` including:
19
+ - Age at stroke
20
+ - Sex
21
+ - Months post-stroke
22
+ - WAB scores (aphasia severity)
23
+
24
+ ## How to Use
25
+
26
+ 1. **Data Source**: By default, it uses the HuggingFace dataset. You can change to a local directory if needed.
27
+ 2. **Model Parameters**:
28
+ - Latent Dimensions: Controls the size of the latent space (default: 32)
29
+ - Number of Epochs: Training iterations (default: 1000)
30
+ - Batch Size: Training batch size (default: 16)
31
+
32
+ 3. **Run the Analysis**: The model will:
33
+ - Load and process the data
34
+ - Train the VAE model
35
+ - Analyze relationships between latent variables and demographics
36
+ - Generate visualizations of original, reconstructed, and generated FC matrices
37
+
38
+ ## Outputs
39
+
40
+ The application produces visualizations showing:
41
+ - Original FC matrix
42
+ - Reconstructed FC matrix
43
+ - Generated FC matrix (based on specific demographic inputs)
44
+ - Correlation plots between latent variables and demographic features
45
+
46
+ ## Technical Details
47
+
48
+ - Framework: PyTorch
49
+ - Interface: Gradio
50
+ - Dataset: HuggingFace Datasets API
51
+ - Analysis: Custom implementation of conditional VAE with demographic conditioning
analysis.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from scipy.stats import pearsonr

def analyze_fc_patterns(latents, demographics):
    """Correlate each latent dimension with each continuous demographic.

    For every demographic except 'sex' (categorical, so Pearson r does not
    apply), computes the Pearson correlation and p-value between each latent
    column and the demographic values.

    Returns a dict keyed by demographic name, each entry holding
    'correlations' and 'p_values' lists (one value per latent dimension).
    """
    results = {}
    for demo_name, demo_values in demographics.items():
        if demo_name == 'sex':
            continue
        pairs = [pearsonr(latents[:, dim], demo_values)
                 for dim in range(latents.shape[1])]
        results[demo_name] = {
            'correlations': [r for r, _ in pairs],
            'p_values': [p for _, p in pairs],
        }
    return results
16
+
app.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
from main import run_fc_analysis
import os

def gradio_fc_analysis(data_source, demographic_file, latent_dim, nepochs, bsize, use_hf_dataset):
    """Gradio callback: run the full FC/VAE pipeline and return the figure."""
    fig = run_fc_analysis(
        data_dir=data_source,
        demographic_file=demographic_file,
        latent_dim=latent_dim,
        nepochs=nepochs,
        bsize=bsize,
        save_model=True,
        use_hf_dataset=use_hf_dataset
    )
    return fig

def create_interface():
    """Build the Gradio Interface for the analysis pipeline.

    Note: Gradio 3.x+ (including the 5.20.1 pinned in README.md) takes the
    initial component value via ``value=``; the legacy ``default=`` keyword
    was removed and raises a TypeError, so ``value=`` is used throughout.
    """
    iface = gr.Interface(
        fn=gradio_fc_analysis,
        inputs=[
            gr.Textbox(label="Data Source (HF Dataset ID or Local Directory)",
                       value="SreekarB/OSFData"),
            gr.Textbox(label="Demographic File",
                       value="FC_graph_covariate_data.csv"),
            gr.Slider(minimum=8, maximum=64, step=8,
                      label="Latent Dimensions", value=32),
            gr.Slider(minimum=100, maximum=5000, step=100,
                      label="Number of Epochs", value=1000),
            gr.Slider(minimum=8, maximum=64, step=8,
                      label="Batch Size", value=16),
            gr.Checkbox(label="Use HuggingFace Dataset",
                        value=True),
        ],
        outputs="plot",
        title="Aphasia fMRI to FC Analysis using VAE",
        description="""
        Analysis pipeline: fMRI → FC matrices → VAE → Analysis

        This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
        The demographic file FC_graph_covariate_data.csv contains age_at_stroke, sex, months_post_stroke, and wab_score.
        """,
        examples=[
            ["SreekarB/OSFData", "FC_graph_covariate_data.csv", 32, 500, 16, True],
        ],
        cache_examples=False,
    )
    return iface

if __name__ == "__main__":
    iface = create_interface()
    iface.launch(share=True)
52
+
config.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Hyper-parameters for the DemoVAE model and its training loop.
MODEL_CONFIG = {
    'latent_dim': 32,       # size of the latent space
    'nepochs': 1000,        # number of training epochs
    'bsize': 16,            # mini-batch size
    'loss_rec_mult': 100,   # reconstruction-loss weight
    'loss_decor_mult': 10,  # latent/demographic decorrelation weight
    'lr': 1e-4,             # Adam learning rate
}

# Signal-cleaning settings for extracting ROI time series from fMRI.
PREPROCESS_CONFIG = {
    't_r': 2.0,         # repetition time in seconds
    'high_pass': 0.01,  # band-pass lower edge (Hz)
    'low_pass': 0.1,    # band-pass upper edge (Hz)
    'radius': 5,        # sphere radius (mm) around each ROI coordinate
}

# Default HuggingFace dataset the demo pulls data from.
DATASET_CONFIG = {
    'name': 'SreekarB/OSFData',
    'split': 'train',
}
24
+
data_preprocessing.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import numpy as np
import pandas as pd
from datasets import load_dataset
# NOTE(review): `nilearn.input_data` is the deprecated alias of
# `nilearn.maskers` in recent nilearn releases — confirm the pinned
# nilearn version still exposes it.
from nilearn import input_data, connectome
from nilearn.image import load_img
import nibabel as nib

def preprocess_fmri_to_fc(dataset_name, atlas_path=None):
    """Convert each fMRI volume in a HuggingFace dataset to a normalized,
    Fisher-z-transformed upper-triangle FC vector.

    Parameters:
        dataset_name: HuggingFace dataset id (its 'train' split must expose
            'demographics' and 'nii_files' columns — presumably paths/images;
            verify against the dataset schema).
        atlas_path: optional labels atlas image; when None, spherical ROIs
            are built from the Power-2011 coordinates.

    Returns (X, demo_data, demo_types) where X is (n_subjects, n_edges).

    NOTE(review): main.py calls this as
    preprocess_fmri_to_fc(nii_files, demo_data, demo_types) — three
    positional args that do not match this signature; confirm which
    call convention is intended.
    """
    dataset = load_dataset(dataset_name, split="train")

    # Load Power 264 atlas or specified atlas
    if atlas_path is None:
        # Use Power 264 coordinates to create spherical ROIs
        from nilearn import datasets
        power = datasets.fetch_coords_power_2011()
        coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T

        # Band-pass 0.01–0.1 Hz, detrended, standardized 5 mm spheres.
        masker = input_data.NiftiSpheresMasker(
            coords, radius=5,
            standardize=True,
            memory='nilearn_cache', memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=0.1,
            high_pass=0.01,
            t_r=2.0  # Adjust TR according to your data
        )
    else:
        # Same cleaning settings, but ROIs come from the labels atlas.
        masker = input_data.NiftiLabelsMasker(
            labels_img=atlas_path,
            standardize=True,
            memory='nilearn_cache', memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=0.1,
            high_pass=0.01,
            t_r=2.0  # Adjust TR according to your data
        )

    # Load demographic data
    demo_df = pd.DataFrame(dataset['demographics'])

    demo_data = [
        demo_df['age_at_stroke'].values,
        demo_df['sex'].values,
        demo_df['months_post_stroke'].values,
        demo_df['wab_score'].values
    ]

    # One type tag per demo_data entry, in the same order.
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    # Process fMRI data and compute FC matrices
    fc_matrices = []
    for nii_file in dataset['nii_files']:
        fmri_img = load_img(nii_file)
        time_series = masker.fit_transform(fmri_img)

        correlation_measure = connectome.ConnectivityMeasure(
            kind='correlation',
            vectorize=False,
            discard_diagonal=False
        )

        fc_matrix = correlation_measure.fit_transform([time_series])[0]

        # Keep only the strict upper triangle (unique edges, no diagonal).
        triu_indices = np.triu_indices_from(fc_matrix, k=1)
        fc_triu = fc_matrix[triu_indices]

        fc_triu = np.arctanh(fc_triu)  # Fisher z-transform

        fc_matrices.append(fc_triu)

    X = np.array(fc_matrices)

    # Normalize the FC data
    # NOTE(review): divides by the per-edge std across subjects — a
    # zero-variance edge yields division by zero (inf/nan); confirm input
    # guarantees non-constant edges or add an epsilon.
    X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

    return X, demo_data, demo_types
79
+
main.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import torch
4
+ from pathlib import Path
5
+ import nibabel as nib
6
+ from data_preprocessing import preprocess_fmri_to_fc
7
+ from vae_model import train_fc_vae, DemoVAE
8
+ from analysis import analyze_fc_patterns
9
+ from visualization import visualize_fc_analysis
10
+ from config import MODEL_CONFIG, DATASET_CONFIG
11
+ import pandas as pd
12
+ import io
13
+ from typing import List, Dict, Union, Tuple, Any
14
+
15
def load_data(data_dir="SreekarB/OSFData", demographic_file="FC_graph_covariate_data.csv", use_hf_dataset=True):
    """
    Load fMRI data and demographics from HuggingFace dataset or local files.

    Returns a 3-tuple (files_or_matrices, demo_data, demo_types):
      - files_or_matrices: list of .nii/.nii.gz entries, or pre-computed
        FC matrices when the HF dataset exposes them directly;
      - demo_data: [age_at_stroke, sex, months_post_stroke, wab_score]
        arrays (empty arrays for columns missing in the HF branch);
      - demo_types: 'continuous'/'categorical' tag per demo_data entry.
    """
    # Defined up front so it is in scope on *every* return path; previously
    # it was assigned only at the bottom of the function, so the early
    # FC-matrix return in the HF branch raised NameError.
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    if use_hf_dataset:
        # Load from HuggingFace Datasets (imported lazily so the local
        # branch works without the `datasets` package).
        from datasets import load_dataset

        print(f"Loading dataset from HuggingFace: {data_dir}")
        dataset = load_dataset(data_dir)

        # Load demographics from the dataset
        if demographic_file in dataset["train"].features:
            demo_df = pd.DataFrame(dataset["train"][demographic_file])
        else:
            # Try to load from the dataset files
            try:
                demo_content = dataset["train"][demographic_file][0]
                demo_df = pd.read_csv(io.StringIO(demo_content))
            except Exception as e:
                print(f"Error loading demographics from dataset: {e}")
                # Fall back to downloading the CSV from the dataset repo
                import huggingface_hub
                csv_path = huggingface_hub.hf_hub_download(repo_id=data_dir, filename=demographic_file)
                demo_df = pd.read_csv(csv_path)

        # Extract demographic data; a missing column yields an empty array
        # (run_fc_analysis substitutes zeros for empty entries).
        demo_data = [
            demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else np.array([]),
            demo_df['sex'].values if 'sex' in demo_df.columns else np.array([]),
            demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else np.array([]),
            demo_df['wab_score'].values if 'wab_score' in demo_df.columns else np.array([])
        ]

        # Get fMRI/FC files from dataset (by feature-name extension)
        nii_files = []
        for f in dataset["train"].features:
            if f.endswith(".nii.gz") or f.endswith(".nii"):
                nii_files.append(f)

        if not nii_files:
            print("No .nii/.nii.gz files found in dataset, checking for FC matrices")
            # Try to find FC matrices directly
            fc_matrices = []
            for f in dataset["train"].features:
                if f.startswith("fc_") or f.endswith("_fc"):
                    fc_matrices.append(dataset["train"][f])

            if fc_matrices:
                print(f"Found {len(fc_matrices)} FC matrices in dataset")
                return fc_matrices, demo_data, demo_types
    else:
        # Original local file loading: demographics from a CSV on disk.
        demo_df = pd.read_csv(demographic_file)

        demo_data = [
            demo_df['age_at_stroke'].values,
            demo_df['sex'].values,
            demo_df['months_post_stroke'].values,
            demo_df['wab_score'].values
        ]

        # Load fMRI files
        nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))

    return nii_files, demo_data, demo_types
83
+
84
def run_fc_analysis(data_dir="SreekarB/OSFData",
                    demographic_file="FC_graph_covariate_data.csv",
                    latent_dim=32,
                    nepochs=1000,
                    bsize=16,
                    save_model=True,
                    use_hf_dataset=True):
    """End-to-end pipeline: load data, train the VAE, analyze latent /
    demographic relationships, generate FC matrices, and return a figure.

    NOTE(review): this mutates the module-level MODEL_CONFIG in place, so
    repeated calls (e.g. from the Gradio app) leave the last caller's
    settings behind.
    """

    # Update MODEL_CONFIG with user-specified parameters
    MODEL_CONFIG.update({
        'latent_dim': latent_dim,
        'nepochs': nepochs,
        'bsize': bsize
    })

    # Load data
    print("Loading data...")
    nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)

    # Add import for io module if it's missing
    import io

    # Check if we got FC matrices directly (load_data may return either
    # ndarray FC matrices or a list of file entries)
    if isinstance(nii_files, list) and all(isinstance(item, np.ndarray) for item in nii_files):
        print("Using pre-computed FC matrices...")
        X = np.stack(nii_files)
    else:
        # Prepare data by converting fMRI to FC matrices
        # NOTE(review): preprocess_fmri_to_fc in data_preprocessing.py is
        # declared as (dataset_name, atlas_path=None); this 3-argument call
        # does not match that signature — confirm the intended interface.
        print("Converting fMRI data to FC matrices...")
        X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)

    # Train VAE and get data
    # NOTE(review): train_fc_vae is imported from vae_model but not visible
    # here; assumed to return (fitted estimator, X, demo_data, demo_types).
    print("Training VAE...")
    vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)

    if save_model:
        print("Saving model...")
        os.makedirs('models', exist_ok=True)
        # NOTE(review): assumes the returned object exposes a torch-style
        # state_dict(); confirm train_fc_vae returns a module (not the
        # sklearn-style DemoVAE wrapper).
        torch.save(vae.state_dict(), 'models/vae_model.pth')

    # Get latent representations
    print("Getting latent representations...")
    latents = vae.get_latents(X)

    # Analyze results; empty demographic columns are replaced with zeros so
    # analyze_fc_patterns always gets same-length arrays.
    print("Analyzing demographic relationships...")
    demographics = {
        'age_at_stroke': demo_data[0] if len(demo_data[0]) > 0 else np.zeros(len(X)),
        'months_post_stroke': demo_data[2] if len(demo_data[2]) > 0 else np.zeros(len(X)),
        'wab_score': demo_data[3] if len(demo_data[3]) > 0 else np.zeros(len(X))
    }
    analysis_results = analyze_fc_patterns(latents, demographics)

    # Generate new FC matrix conditioned on a fixed example subject
    print("Generating new FC matrices...")
    new_demographics = [
        [60.0],  # age at stroke
        ['M'],   # sex
        [12.0],  # months post stroke
        [80.0]   # wab score
    ]
    # transform(int, ...) generates that many samples; transform(array, ...)
    # reconstructs the given samples (see DemoVAE.transform).
    generated_fc = vae.transform(1, new_demographics, demo_types)
    reconstructed_fc = vae.transform(X, demo_data, demo_types)

    # Visualize results
    print("Creating visualizations...")
    fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)

    return fig
153
+
154
if __name__ == "__main__":
    import argparse

    # Command-line front end mirroring the Gradio app's options.
    cli = argparse.ArgumentParser(description='Run FC Analysis using VAE')
    cli.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
                     help='HuggingFace dataset ID or directory containing fMRI data')
    cli.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
                     help='Path to demographic data CSV file')
    cli.add_argument('--latent_dim', type=int, default=32,
                     help='Dimension of latent space')
    cli.add_argument('--nepochs', type=int, default=1000,
                     help='Number of training epochs')
    cli.add_argument('--bsize', type=int, default=16,
                     help='Batch size for training')
    cli.add_argument('--no_save', action='store_false',
                     help='Do not save the model')
    cli.add_argument('--use_local', action='store_true',
                     help='Use local data instead of HuggingFace dataset')

    opts = cli.parse_args()

    # store_false makes opts.no_save default to True and flip to False when
    # --no_save is passed, so it carries the save_model flag directly.
    result_fig = run_fc_analysis(
        data_dir=opts.data_dir,
        demographic_file=opts.demographic_file,
        latent_dim=opts.latent_dim,
        nepochs=opts.nepochs,
        bsize=opts.bsize,
        save_model=opts.no_save,
        use_hf_dataset=not opts.use_local,
    )
    result_fig.show()
185
+
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch>=1.9.0
2
+ numpy>=1.19.2
3
+ pandas>=1.2.4
4
+ nilearn>=0.8.1
5
+ nibabel>=3.2.1
6
+ scikit-learn>=0.24.2
7
+ matplotlib>=3.4.2
8
+ gradio>=5.0.0
9
+ datasets>=1.11.0
10
+ huggingface_hub>=0.12.0
11
+ transformers>=4.15.0
12
+
src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
src/demovae/model.py ADDED
@@ -0,0 +1,221 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ import random
7
+ import numpy as np
8
+
9
+ from sklearn.linear_model import Ridge
10
+ from sklearn.linear_model import LogisticRegression
11
+
12
def to_torch(x):
    """Convert a numpy array to a float32 torch tensor."""
    return torch.from_numpy(x).float()

def to_cuda(x, use_cuda):
    """Move *x* to the GPU when use_cuda is set; otherwise return it unchanged."""
    return x.cuda() if use_cuda else x

def to_numpy(x):
    """Detach *x* from the autograd graph and return a CPU numpy array."""
    return x.detach().cpu().numpy()
23
+
24
class VAE(nn.Module):
    """Conditional VAE: a two-layer deterministic encoder to a latent code,
    and a two-layer decoder conditioned by concatenating demographic
    features onto the latent vector."""

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda
        # Encoder: input -> 1000 -> latent
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
        # Decoder: (latent + demographics) -> 1000 -> input
        self.dec1 = to_cuda(nn.Linear(latent_dim+demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)

    def enc(self, x):
        """Encode a batch into latent codes (deterministic; no sampling)."""
        hidden = F.relu(self.enc1(x))
        return self.enc2(hidden)

    def gen(self, n):
        """Draw n standard-normal latent samples on the model's device."""
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes concatenated with demographic conditioning."""
        conditioned = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        hidden = F.relu(self.dec1(conditioned))
        return self.dec2(hidden)
53
+
54
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between *a* and *b*; the reducer is
    injectable via *mean* (defaults to torch.mean)."""
    sq_err = (a - b) ** 2
    return mean(sq_err) ** 0.5
56
+
57
def latent_loss(z, use_cuda=True):
    """KL-like penalty pushing latent codes toward N(0, I).

    Penalizes deviation of z.T @ z from len(z) * I (unit covariance) and
    of the latent mean from zero.  Returns (loss_C, loss_mu, C, mu).
    """
    C = z.T@z
    mu = torch.mean(z, dim=0)
    cov_target = to_cuda(torch.eye(z.shape[-1]).float(), use_cuda)*len(z)
    mean_target = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    return rmse(C, cov_target), rmse(mu, mean_target), C, mu
65
+
66
def decor_loss(z, demo, use_cuda=True):
    """Penalty on correlation between latents and each demographic column.

    Returns (stacked tensor with one loss per demographic column, list of
    the normalized projection vectors that were penalized).
    """
    projections = []
    column_losses = []
    zero_target = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    for col in range(demo.shape[1]):
        d = demo[:, col]
        d = d - torch.mean(d)
        # Projection of each latent dimension onto the centered demographic,
        # normalized by the demographic's spread and the latent energy.
        p = torch.einsum('n,nz->z', d, z)
        p = p/torch.std(d)
        p = p/torch.einsum('nz,nz->z', z, z)
        column_losses.append(rmse(p, zero_target))
        projections.append(p)
    return torch.stack(column_losses), projections
81
+
82
def pretty(x):
    """Format a scalar (tensor or number) rounded to 4 decimal places."""
    value = round(float(x), 4)
    return f'{value}'
84
+
85
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Convert per-demographic arrays into one (n_samples, demo_dim) tensor.

    Continuous demographics become a single column; categorical ones are
    one-hot encoded with one column per category recorded in pred_stats at
    fit time.  Raises if a categorical value was never seen in training.
    """
    columns = []
    demo_idx = 0
    for d, t, s in zip(demo, demo_types, pred_stats):
        if t == 'continuous':
            columns.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            # Reject values the model was never trained on.
            for dd in d:
                if dd not in s:
                    print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            # One indicator column per known category.
            for ss in s:
                indicator = torch.zeros(len(d))
                indicator[(d == ss).astype('bool')] = 1
                columns.append(to_cuda(indicator, use_cuda))
        demo_idx += 1
    return torch.stack(columns).permute(1,0)
104
+
105
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult, loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train *vae* on inputs x with demographic conditioning.

    Stages: (1) fit per-demographic linear guidance models on the raw data
    (Ridge for continuous, LogisticRegression for categorical); (2) convert
    x and demographics to torch; (3) run the optimization loop combining
    reconstruction, latent-prior, decorrelation, and guidance losses.
    Fitted per-demographic statistics are stored on ret_obj.pred_stats so
    new demographics can be encoded later (see demo_to_torch).
    """
    # Get linear predictors for demographics
    pred_w = []
    pred_i = []
    # Pred stats are mean and std for continuous, and a list of all values for categorical
    pred_stats = []
    for i,d,t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxilliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            # Binary: one coef row; store negated weights for class 0 and
            # raw weights for class 1 so there is one predictor per category.
            if len(reg.coef_) == 1:
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            # Categorical: one coef row (and predictor) per class.
            else:
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        else:
            print(f'demographic type "{t}" not "continuous" or "categorical"')
            raise Exception('Bad demographic type')
        print(' done')
    ret_obj.pred_stats = pred_stats
    # Convert input to pytorch
    print('Converting input to pytorch')
    x = to_cuda(to_torch(x), vae.use_cuda)
    # Convert demographics to pytorch
    print('Converting demographics to pytorch')
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    # Training loop
    print('Beginning VAE training')
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        for bs in range(0,len(x),bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()
            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)
            # Sample demographics: 100 synthetic subjects per batch, drawn
            # from the fitted moments (continuous) or a random one-hot class
            # shared by the whole synthetic batch (categorical).
            demo_gen = []
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    mu = s[0]
                    std = s[1]
                    dd = torch.randn(100).float()
                    dd = dd*std+mu
                    dd = to_cuda(dd, vae.use_cuda)
                    demo_gen.append(dd)
                elif t == 'categorical':
                    idx = random.randint(0, len(s)-1)
                    for i in range(len(s)):
                        if idx == i:
                            dd = torch.ones(100).float()
                        else:
                            dd = torch.zeros(100).float()
                        dd = to_cuda(dd, vae.use_cuda)
                        demo_gen.append(dd)
            demo_gen = torch.stack(demo_gen).permute(1,0)
            # Generate
            z = vae.gen(100)
            y = vae.dec(z, demo_gen)
            # Regressor/classifier guidance loss: the frozen linear models
            # should recover the conditioning demographics from generated FC.
            losses_pred = []
            idcs = []
            dg_idx = 0
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                    loss = rmse(demo_gen[:,dg_idx], yy)
                    losses_pred.append(loss)
                    idcs.append(float(demo_gen[0,dg_idx]))
                    dg_idx += 1
                elif t == 'categorical':
                    loss = 0
                    for i in range(len(s)):
                        yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                        loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
                        idcs.append(int(demo_gen[0,dg_idx]))
                        dg_idx += 1
                    losses_pred.append(loss)
            total_loss = loss_C_mult*loss_C + loss_mu_mult*loss_mu + loss_rec_mult*loss_rec + loss_decor_mult*loss_decor + loss_pred_mult*sum(losses_pred)
            total_loss.backward()
            optim.step()
        # Periodic progress report (losses are from the last batch).
        if e%pperiod == 0 or e == nepochs-1:
            print(f'Epoch {e} ', end='')
            print(f'ReconLoss {pretty(loss_rec)} ', end='')
            print(f'CovarianceLoss {pretty(loss_C)} ', end='')
            print(f'MeanLoss {pretty(loss_mu)} ', end='')
            print(f'DecorLoss {pretty(loss_decor)} ', end='')
            losses_pred = [pretty(loss) for loss in losses_pred]
            print(f'GuidanceTargets {idcs} GuidanceLosses {losses_pred} ', end='')
            print()
    print('Training complete.')
220
+
221
+
src/demovae/sklearn.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from demovae.model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
3
+
4
+ from sklearn.base import BaseEstimator
5
+
6
+ # For saving
7
+ import torch
8
+
9
class DemoVAE(BaseEstimator):
    """Sklearn-style estimator wrapping the demographic-conditioned VAE.

    fit() trains the underlying torch VAE; transform() either reconstructs
    real samples or (when x is an int) generates that many new samples
    conditioned on the supplied demographics.
    """

    def __init__(self, **params):
        # Unspecified keys are filled from get_default_params().
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        """Default hyper-parameters for the VAE and its training loop."""
        return dict(latent_dim=60, # Latent dimension
                    use_cuda=True, # GPU acceleration
                    nepochs=3000, # Training epochs
                    pperiod=100, # Epochs between printing updates
                    bsize=1000, # Batch size
                    loss_C_mult=1, # Covariance loss (KL div)
                    loss_mu_mult=1, # Mean loss (KL div)
                    loss_rec_mult=100, # Reconstruction loss
                    loss_decor_mult=10, # Latent-demographic decorrelation loss
                    loss_pred_mult=0.001, # Classifier/regressor guidance loss
                    alpha=100, # Regularization for continuous guidance models
                    LR_C=100, # Regularization for categorical guidance models
                    lr=1e-4, # Learning rate
                    weight_decay=0, # L2 regularization for VAE model
                    )

    def get_params(self, **params):
        """Return the current hyper-parameters (sklearn API)."""
        return dict(latent_dim=self.latent_dim,
                    use_cuda=self.use_cuda,
                    nepochs=self.nepochs,
                    pperiod=self.pperiod,
                    bsize=self.bsize,
                    loss_C_mult=self.loss_C_mult,
                    loss_mu_mult=self.loss_mu_mult,
                    loss_rec_mult=self.loss_rec_mult,
                    loss_decor_mult=self.loss_decor_mult,
                    loss_pred_mult=self.loss_pred_mult,
                    alpha=self.alpha,
                    LR_C=self.LR_C,
                    lr=self.lr,
                    weight_decay=self.weight_decay,
                    )

    def set_params(self, **params):
        """Set hyper-parameters, falling back to defaults for missing keys."""
        dft = DemoVAE.get_default_params()
        for key in dft:
            if key in params:
                setattr(self, key, params[key])
            else:
                setattr(self, key, dft[key])
        return self

    def fit(self, x, demo, demo_types, **kwargs):
        """Train the VAE on x given demographics.

        x: (n_samples, n_features) array; demo: list of per-demographic
        arrays; demo_types: 'continuous'/'categorical' tag per entry.
        """
        # Get demo_dim: one column per continuous demographic, one column
        # per category for categorical demographics (one-hot width).
        demo_dim = 0
        for d,t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                ll = len(set(list(d)))
                if ll == 1:
                    print('Only one type of category for categorical variable')
                    raise Exception('Bad categorical')
                demo_dim += ll
            else:
                print(f'demographic type "{t}" not "continuous" or "categorical"')
                raise Exception('Bad demographic type')
        # Save parameters
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        # Create model
        self.vae = VAE(x.shape[1], self.latent_dim, demo_dim, self.use_cuda)
        # Train model (also stores self.pred_stats via the ret_obj argument)
        train_vae(self.vae, x, demo, demo_types,
                  self.nepochs, self.pperiod, self.bsize,
                  self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult, self.loss_decor_mult, self.loss_pred_mult,
                  self.lr, self.weight_decay, self.alpha, self.LR_C,
                  self)
        return self

    def transform(self, x, demo, demo_types, **kwargs):
        """Decode with the given demographics.

        When x is an int, generate that many new samples from the prior;
        otherwise encode x and reconstruct it.  Returns a numpy array.
        """
        if isinstance(x, int):
            # Generate
            z = self.vae.gen(x)
        else:
            # Get latents for real data
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def fit_transform(self, x, demo, demo_types, **kwargs):
        """fit() then transform() on the same data (sklearn API)."""
        self.fit(x, demo, demo_types)
        return self.transform(x, demo, demo_types)

    def get_latents(self, x):
        """Encode x and return the latent codes as a numpy array."""
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        """Serialize params, pred_stats and model weights with torch.save."""
        params = self.get_params()
        dct = dict(pred_stats=self.pred_stats,
                   params=params,
                   input_dim=self.input_dim,
                   demo_dim=self.demo_dim,
                   model_state_dict=self.vae.state_dict())
        torch.save(dct, path)

    def load(self, path):
        """Restore a model saved by save().

        NOTE(review): torch.load unpickles arbitrary objects — only load
        checkpoints from trusted sources (consider weights_only=True).
        """
        dct = torch.load(path)
        self.pred_stats = dct['pred_stats']
        self.set_params(**dct['params'])
        self.vae = VAE(dct['input_dim'],
                       dct['params']['latent_dim'],
                       dct['demo_dim'],
                       dct['params']['use_cuda'])
        self.vae.load_state_dict(dct['model_state_dict'])
122
+
123
+
utils.py ADDED
@@ -0,0 +1,186 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from sklearn.linear_model import Ridge, LogisticRegression
4
+
5
def to_torch(x):
    """numpy array -> float32 torch tensor."""
    return torch.from_numpy(x).float()

def to_cuda(x, use_cuda):
    """Move *x* to the GPU only when use_cuda is truthy."""
    if use_cuda:
        return x.cuda()
    return x

def to_numpy(x):
    """Detached, CPU-side numpy copy of a tensor."""
    return x.detach().cpu().numpy()
13
+
14
def fc_matrix_from_triu(triu_values, n_rois=264):
    """Rebuild a symmetric FC matrix from Fisher-z upper-triangle values.

    Applies the inverse Fisher transform (tanh), mirrors the upper triangle
    into the lower one, and sets the diagonal to 1 (self-correlation).
    """
    rows, cols = np.triu_indices(n_rois, k=1)
    fc = np.zeros((n_rois, n_rois))
    fc[rows, cols] = np.tanh(triu_values)
    fc = fc + fc.T
    np.fill_diagonal(fc, 1)
    return fc
22
+
23
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between *a* and *b* (custom reducer via *mean*)."""
    mse = mean((a - b) ** 2)
    return mse ** 0.5
25
+
26
def latent_loss(z, use_cuda=True):
    """Push latent codes toward N(0, I): penalize z.T@z deviating from
    len(z)*I and the latent mean deviating from zero.
    Returns (loss_C, loss_mu, C, mu)."""
    n_latent = z.shape[-1]
    C = z.T@z
    mu = torch.mean(z, dim=0)
    eye_scaled = to_cuda(torch.eye(n_latent).float(), use_cuda) * len(z)
    zeros = to_cuda(torch.zeros(n_latent).float(), use_cuda)
    loss_C = rmse(C, eye_scaled)
    loss_mu = rmse(mu, zeros)
    return loss_C, loss_mu, C, mu
34
+
35
def decor_loss(z, demo, use_cuda=True):
    """One decorrelation loss per demographic column; also returns the
    normalized latent/demographic projection vectors."""
    ps = []
    losses = []
    target = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    for di in range(demo.shape[1]):
        centered = demo[:, di] - torch.mean(demo[:, di])
        # Project latents onto the centered demographic, then normalize by
        # the demographic spread and the per-dimension latent energy.
        p = torch.einsum('n,nz->z', centered, z)
        p = p/torch.std(centered)
        p = p/torch.einsum('nz,nz->z', z, z)
        losses.append(rmse(p, target))
        ps.append(p)
    return torch.stack(losses), ps
50
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Stack demographics into an (n_samples, demo_dim) float tensor.

    Continuous columns pass through unchanged; categorical ones are one-hot
    encoded against the category list recorded in pred_stats at fit time.
    Raises for categorical values never seen during training.
    """
    demo_idx = 0
    cols = []
    for d, t, s in zip(demo, demo_types, pred_stats):
        if t == 'continuous':
            cols.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            unknown = [dd for dd in d if dd not in s]
            if unknown:
                print(f'Model not trained with value {unknown[0]} for categorical demographic {demo_idx}')
                raise Exception('Bad demographic')
            for ss in s:
                onehot = torch.zeros(len(d))
                onehot[(d == ss).astype('bool')] = 1
                cols.append(to_cuda(onehot, use_cuda))
        demo_idx += 1
    return torch.stack(cols).permute(1, 0)
69
+
70
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
              loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
              loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train *vae* with demographic conditioning (utils.py variant).

    Stages: fit per-demographic linear guidance models (Ridge /
    LogisticRegression) on the raw data, convert everything to torch, then
    optimize a combined reconstruction + latent-prior + decorrelation +
    guidance objective.  Per-demographic statistics are stored on
    ret_obj.pred_stats for later use by demo_to_torch.
    """
    # Get linear predictors for demographics
    pred_w = []
    pred_i = []
    # mean/std for continuous; sorted category list for categorical
    pred_stats = []

    for i, d, t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            # Binary classifier: store negated weights for class 0 and raw
            # weights for class 1 — one predictor per category.
            if len(reg.coef_) == 1:
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            else:
                # Multiclass: one coef row (and predictor) per class.
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        print(' done')

    ret_obj.pred_stats = pred_stats

    # Convert input to pytorch
    x = to_cuda(to_torch(x), vae.use_cuda)

    # Convert demographics to pytorch
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)

    # Training loop
    ce = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)

    for e in range(nepochs):
        for bs in range(0, len(x), bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()

            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)

            # Sample demographics: 100 synthetic subjects drawn from the
            # fitted moments (continuous) or a random shared one-hot class
            # (categorical).
            demo_gen = []
            for s, t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    mu, std = s
                    dd = torch.randn(100).float()
                    dd = dd*std+mu
                    dd = to_cuda(dd, vae.use_cuda)
                    demo_gen.append(dd)
                elif t == 'categorical':
                    idx = np.random.randint(0, len(s))
                    for i in range(len(s)):
                        dd = torch.ones(100).float() if idx == i else torch.zeros(100).float()
                        dd = to_cuda(dd, vae.use_cuda)
                        demo_gen.append(dd)

            demo_gen = torch.stack(demo_gen).permute(1,0)

            # Generate
            z = vae.gen(100)
            y = vae.dec(z, demo_gen)

            # Regressor/classifier guidance loss: the frozen linear models
            # should recover the conditioning demographics from generated data.
            losses_pred = []
            idcs = []
            dg_idx = 0

            for s, t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                    loss = rmse(demo_gen[:,dg_idx], yy)
                    losses_pred.append(loss)
                    idcs.append(float(demo_gen[0,dg_idx]))
                    dg_idx += 1
                elif t == 'categorical':
                    loss = 0
                    for i in range(len(s)):
                        yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                        loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
                        idcs.append(int(demo_gen[0,dg_idx]))
                        dg_idx += 1
                    losses_pred.append(loss)

            total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
                          loss_rec_mult*loss_rec + loss_decor_mult*loss_decor +
                          loss_pred_mult*sum(losses_pred))

            total_loss.backward()
            optim.step()

        # Periodic progress report (losses are from the last batch).
        if e%pperiod == 0 or e == nepochs-1:
            print(f'Epoch {e} ReconLoss {loss_rec:.4f} CovarianceLoss {loss_C:.4f} '
                  f'MeanLoss {loss_mu:.4f} DecorLoss {loss_decor:.4f}')
            print(f'GuidanceTargets {idcs}')
            print(f'GuidanceLosses {[f"{loss:.4f}" for loss in losses_pred]}')
186
+
vae_model.py ADDED
@@ -0,0 +1,150 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from utils import to_torch, to_cuda, to_numpy, demo_to_torch
6
+ from sklearn.base import BaseEstimator
7
+
8
class VAE(nn.Module):
    """Conditional VAE over flattened FC feature vectors.

    The encoder maps an input vector to a latent code; the decoder is
    conditioned on a demographic vector concatenated to the latent code,
    so new samples can be generated for chosen demographics.
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda

        hidden = 1000  # width of the single hidden layer on each side

        # Encoder: input -> hidden -> latent
        self.enc1 = to_cuda(nn.Linear(input_dim, hidden).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(hidden, latent_dim).float(), use_cuda)

        # Decoder: (latent + demographics) -> hidden -> input
        self.dec1 = to_cuda(nn.Linear(latent_dim + demo_dim, hidden).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(hidden, input_dim).float(), use_cuda)

        # One batch-norm per hidden activation (encoder side / decoder side).
        self.bn1 = to_cuda(nn.BatchNorm1d(hidden), use_cuda)
        self.bn2 = to_cuda(nn.BatchNorm1d(hidden), use_cuda)

    def enc(self, x):
        """Encode a batch of inputs to latent codes (deterministic pass)."""
        hidden = self.bn1(F.relu(self.enc1(x)))
        return self.enc2(hidden)

    def gen(self, n):
        """Draw n latent samples from the standard-normal prior."""
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes conditioned on a demographic tensor.

        z and demo are concatenated along the feature dimension before
        being passed through the decoder network.
        """
        conditioned = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        hidden = self.bn2(F.relu(self.dec1(conditioned)))
        return self.dec2(hidden)
41
+
42
class DemoVAE(BaseEstimator):
    """sklearn-style wrapper around the conditional VAE.

    Hyper-parameters are declared in get_default_params(); fit() trains
    the underlying VAE on FC vectors plus demographic covariates.
    """

    def __init__(self, **params):
        # Populate every hyper-parameter with its default, then apply
        # the caller's overrides.
        for k, v in self.get_default_params().items():
            setattr(self, k, v)
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        """Default hyper-parameters for training and loss weighting."""
        return dict(
            latent_dim=32,
            use_cuda=True,
            nepochs=1000,
            pperiod=100,
            bsize=16,
            loss_C_mult=1,
            loss_mu_mult=1,
            loss_rec_mult=100,
            loss_decor_mult=10,
            loss_pred_mult=0.001,
            alpha=100,
            LR_C=100,
            lr=1e-4,
            weight_decay=0
        )

    def get_params(self, deep=True):
        """Return the current hyper-parameters (sklearn API)."""
        return {k: getattr(self, k) for k in self.get_default_params().keys()}

    def set_params(self, **params):
        """Update only the supplied hyper-parameters (sklearn contract).

        BUGFIX: the previous implementation reset every parameter that was
        not mentioned in `params` back to its default, silently discarding
        tuned settings. Parameters not supplied now keep their current
        values; unknown keys are ignored for backward compatibility.
        """
        defaults = self.get_default_params()
        for k, v in params.items():
            if k in defaults:
                setattr(self, k, v)
        # Guarantee every declared parameter exists even if set_params is
        # called on a bare instance.
        for k, v in defaults.items():
            if not hasattr(self, k):
                setattr(self, k, v)
        return self

    def fit(self, x, demo, demo_types):
        """Train the VAE on feature matrix x with demographic covariates.

        demo is a sequence of per-demographic value arrays; demo_types
        marks each entry as 'continuous' or 'categorical'.
        """
        from utils import train_vae

        # One conditioning input per continuous demographic, one per
        # category level for categorical demographics.
        demo_dim = 0
        for d, t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                demo_dim += len(set(d))
            else:
                raise ValueError(f'Demographic type "{t}" not supported')

        # Initialize VAE
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)

        # Train VAE (train_vae also populates self.pred_stats, used by
        # transform/save — TODO confirm against utils.train_vae).
        train_vae(
            self.vae, x, demo, demo_types,
            self.nepochs, self.pperiod, self.bsize,
            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
            self.loss_decor_mult, self.loss_pred_mult,
            self.lr, self.weight_decay, self.alpha, self.LR_C,
            self
        )
        return self

    def transform(self, x, demo, demo_types):
        """Reconstruct (or generate, if x is an int sample count) FC data
        conditioned on the supplied demographics."""
        if isinstance(x, int):
            z = self.vae.gen(x)
        else:
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def get_latents(self, x):
        """Encode x and return the latent codes as a numpy array."""
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        """Persist model weights, hyper-parameters and fitted statistics."""
        torch.save({
            'model_state_dict': self.vae.state_dict(),
            'params': self.get_params(),
            'pred_stats': self.pred_stats,
            'input_dim': self.input_dim,
            'demo_dim': self.demo_dim
        }, path)

    def load(self, path):
        """Restore a model saved with save().

        BUGFIX: torch.load now passes map_location so checkpoints trained
        on a GPU can be loaded on CPU-only hosts, and use_cuda falls back
        to CPU when CUDA is unavailable at load time.
        """
        map_location = None if torch.cuda.is_available() else 'cpu'
        checkpoint = torch.load(path, map_location=map_location)
        self.set_params(**checkpoint['params'])
        self.use_cuda = self.use_cuda and torch.cuda.is_available()
        self.pred_stats = checkpoint['pred_stats']
        self.input_dim = checkpoint['input_dim']
        self.demo_dim = checkpoint['demo_dim']
        self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
        self.vae.load_state_dict(checkpoint['model_state_dict'])
132
+
133
def train_fc_vae(X, demo_data, demo_types, model_config):
    """Fit a DemoVAE on FC feature vectors using settings from model_config.

    Parameters
    ----------
    X : array of shape (n_subjects, n_features) — upper-triangle FC values.
    demo_data : sequence of per-demographic value arrays.
    demo_types : sequence of 'continuous'/'categorical' markers matching
        demo_data entry-for-entry.
    model_config : dict with keys 'latent_dim', 'nepochs', 'bsize',
        'loss_rec_mult', 'loss_decor_mult', 'lr'.

    Returns
    -------
    (vae, X, demo_data, demo_types) — the fitted model plus the inputs,
    passed through unchanged for downstream analysis code.

    Note: the previous version computed n_rois / input_dim locals that
    were never used; they have been removed (DemoVAE derives the input
    dimension from X itself).
    """
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config['loss_rec_mult'],
        loss_decor_mult=model_config['loss_decor_mult'],
        lr=model_config['lr'],
        use_cuda=torch.cuda.is_available()
    )

    vae.fit(X, demo_data, demo_types)

    return vae, X, demo_data, demo_types
150
+
visualization.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+ from utils import fc_matrix_from_triu
4
+
5
def visualize_fc_analysis(original_triu, reconstructed_triu, generated_triu, analysis_results=None):
    """Render original/reconstructed/generated FC matrices side by side.

    If analysis_results is given (mapping demographic name ->
    {'p_values', 'correlations'}), a fourth panel plots each
    demographic's correlation with every latent dimension.
    Returns the matplotlib Figure.
    """
    fig = plt.figure(figsize=(15, 10))
    grid = plt.GridSpec(2, 3)

    # Top row: one heatmap per FC variant, shared color scale in [-1, 1].
    panels = [
        ('Original FC', original_triu),
        ('Reconstructed FC', reconstructed_triu),
        ('Generated FC', generated_triu),
    ]
    for col, (title, triu) in enumerate(panels):
        ax = fig.add_subplot(grid[0, col])
        matrix = fc_matrix_from_triu(triu)
        image = ax.imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_title(title)
        plt.colorbar(image, ax=ax)

    # Bottom row: latent-dimension correlations per demographic.
    if analysis_results is not None:
        ax_corr = fig.add_subplot(grid[1, :])
        for demo_name, results in analysis_results.items():
            significant_dims = np.where(np.array(results['p_values']) < 0.05)[0]
            correlations = np.array(results['correlations'])
            ax_corr.plot(correlations, label=f'{demo_name} (sig. dims: {len(significant_dims)})')

        ax_corr.set_xlabel('Latent Dimension')
        ax_corr.set_ylabel('Correlation Strength')
        ax_corr.set_title('Demographic Correlations with Latent Dimensions')
        ax_corr.legend()

    plt.tight_layout()
    return fig
+