Spaces:
Sleeping
Sleeping
Upload 13 files
Browse files- README.md +51 -12
- analysis.py +16 -0
- app.py +52 -0
- config.py +24 -0
- data_preprocessing.py +79 -0
- main.py +185 -0
- requirements.txt +12 -0
- src/.DS_Store +0 -0
- src/demovae/model.py +221 -0
- src/demovae/sklearn.py +123 -0
- utils.py +186 -0
- vae_model.py +150 -0
- visualization.py +44 -0
README.md
CHANGED
|
@@ -1,12 +1,51 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Aphasia fMRI to FC Analysis using VAE
|
| 2 |
+
|
| 3 |
+
This demo performs functional connectivity analysis on fMRI data using a Variational Autoencoder (VAE) approach. It's designed to work with aphasia patient data, analyzing brain connectivity patterns and their relationship to demographic variables.
|
| 4 |
+
|
| 5 |
+
## About the Model
|
| 6 |
+
|
| 7 |
+
This application implements a VAE model that:
|
| 8 |
+
1. Takes functional connectivity (FC) matrices derived from fMRI data
|
| 9 |
+
2. Learns a lower-dimensional latent representation of brain connectivity
|
| 10 |
+
3. Conditions the generation process on demographic variables (age, sex, time post-stroke, WAB scores)
|
| 11 |
+
4. Allows analysis of relationships between brain connectivity patterns and demographic variables
|
| 12 |
+
|
| 13 |
+
## Dataset
|
| 14 |
+
|
| 15 |
+
This demo uses the [SreekarB/OSFData](https://huggingface.co/datasets/SreekarB/OSFData) dataset from HuggingFace, which contains:
|
| 16 |
+
|
| 17 |
+
- Functional connectivity matrices from fMRI data
|
| 18 |
+
- Demographic information in `FC_graph_covariate_data.csv` including:
|
| 19 |
+
- Age at stroke
|
| 20 |
+
- Sex
|
| 21 |
+
- Months post-stroke
|
| 22 |
+
- WAB scores (aphasia severity)
|
| 23 |
+
|
| 24 |
+
## How to Use
|
| 25 |
+
|
| 26 |
+
1. **Data Source**: By default, it uses the HuggingFace dataset. You can change to a local directory if needed.
|
| 27 |
+
2. **Model Parameters**:
|
| 28 |
+
- Latent Dimensions: Controls the size of the latent space (default: 32)
|
| 29 |
+
- Number of Epochs: Training iterations (default: 1000)
|
| 30 |
+
- Batch Size: Training batch size (default: 16)
|
| 31 |
+
|
| 32 |
+
3. **Run the Analysis**: The model will:
|
| 33 |
+
- Load and process the data
|
| 34 |
+
- Train the VAE model
|
| 35 |
+
- Analyze relationships between latent variables and demographics
|
| 36 |
+
- Generate visualizations of original, reconstructed, and generated FC matrices
|
| 37 |
+
|
| 38 |
+
## Outputs
|
| 39 |
+
|
| 40 |
+
The application produces visualizations showing:
|
| 41 |
+
- Original FC matrix
|
| 42 |
+
- Reconstructed FC matrix
|
| 43 |
+
- Generated FC matrix (based on specific demographic inputs)
|
| 44 |
+
- Correlation plots between latent variables and demographic features
|
| 45 |
+
|
| 46 |
+
## Technical Details
|
| 47 |
+
|
| 48 |
+
- Framework: PyTorch
|
| 49 |
+
- Interface: Gradio
|
| 50 |
+
- Dataset: HuggingFace Datasets API
|
| 51 |
+
- Analysis: Custom implementation of conditional VAE with demographic conditioning
|
analysis.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from scipy.stats import pearsonr
|
| 2 |
+
|
| 3 |
+
def analyze_fc_patterns(latents, demographics):
    """Correlate every latent dimension with each continuous demographic.

    Parameters
    ----------
    latents : ndarray of shape (n_subjects, n_latent)
        Latent codes produced by the VAE encoder.
    demographics : dict[str, array-like]
        Maps demographic name to per-subject values. The 'sex' entry is
        treated as categorical and skipped (Pearson r is not meaningful).

    Returns
    -------
    dict mapping each continuous demographic name to
    {'correlations': [r per latent dim], 'p_values': [p per latent dim]}.
    """
    results = {}
    for name, values in demographics.items():
        if name == 'sex':  # categorical variable: no Pearson correlation
            continue
        stats = [pearsonr(latents[:, dim], values)
                 for dim in range(latents.shape[1])]
        results[name] = {
            'correlations': [r for r, _ in stats],
            'p_values': [p for _, p in stats],
        }
    return results
|
| 16 |
+
|
app.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from main import run_fc_analysis
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
def gradio_fc_analysis(data_source, demographic_file, latent_dim, nepochs, bsize, use_hf_dataset):
    """Gradio callback: forward the UI inputs to run_fc_analysis.

    Always saves the trained model (save_model=True) and returns the
    matplotlib figure produced by the pipeline for display in the UI.
    """
    return run_fc_analysis(
        data_dir=data_source,
        demographic_file=demographic_file,
        latent_dim=latent_dim,
        nepochs=nepochs,
        bsize=bsize,
        save_model=True,
        use_hf_dataset=use_hf_dataset,
    )
|
| 16 |
+
|
| 17 |
+
def create_interface():
    """Build and return the Gradio Interface for the FC-analysis demo.

    Inputs mirror the keyword arguments of run_fc_analysis; the output is
    the matplotlib figure the pipeline returns.

    Fix: Gradio 3.x+ names the initial-value argument ``value``; the old
    2.x keyword ``default`` was removed and raises TypeError on current
    releases (requirements.txt allows gradio>=2.0.0, which resolves to a
    modern version on a fresh install).
    """
    iface = gr.Interface(
        fn=gradio_fc_analysis,
        inputs=[
            gr.Textbox(label="Data Source (HF Dataset ID or Local Directory)",
                       value="SreekarB/OSFData"),
            gr.Textbox(label="Demographic File",
                       value="FC_graph_covariate_data.csv"),
            gr.Slider(minimum=8, maximum=64, step=8,
                      label="Latent Dimensions", value=32),
            gr.Slider(minimum=100, maximum=5000, step=100,
                      label="Number of Epochs", value=1000),
            gr.Slider(minimum=8, maximum=64, step=8,
                      label="Batch Size", value=16),
            gr.Checkbox(label="Use HuggingFace Dataset",
                        value=True),
        ],
        outputs="plot",
        title="Aphasia fMRI to FC Analysis using VAE",
        description="""
    Analysis pipeline: fMRI → FC matrices → VAE → Analysis

    This demo uses the SreekarB/OSFData dataset from HuggingFace by default.
    The demographic file FC_graph_covariate_data.csv contains age_at_stroke, sex, months_post_stroke, and wab_score.
    """,
        examples=[
            ["SreekarB/OSFData", "FC_graph_covariate_data.csv", 32, 500, 16, True],
        ],
        # Examples trigger a full training run; never cache them.
        cache_examples=False,
    )
    return iface
|
| 48 |
+
|
| 49 |
+
if __name__ == "__main__":
    # Build the Gradio app and expose it with a public share link.
    ui = create_interface()
    ui.launch(share=True)
|
| 52 |
+
|
config.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Model configuration
# Hyper-parameters passed to the VAE trainer; main.run_fc_analysis mutates
# this dict in place with the user's UI/CLI choices before training.
MODEL_CONFIG = {
    'latent_dim': 32,        # size of the VAE latent space
    'nepochs': 1000,         # number of training epochs
    'bsize': 16,             # mini-batch size
    'loss_rec_mult': 100,    # weight of the reconstruction loss term
    'loss_decor_mult': 10,   # weight of the latent/demographic decorrelation loss
    'lr': 1e-4               # optimizer learning rate
}

# Preprocessing configuration
# Values mirror the masker arguments hard-coded in data_preprocessing.py
# (they are not currently read by that module — TODO confirm intended wiring).
PREPROCESS_CONFIG = {
    't_r': 2.0,              # repetition time in seconds; must match acquisition
    'high_pass': 0.01,       # band-pass lower cutoff (Hz)
    'low_pass': 0.1,         # band-pass upper cutoff (Hz)
    'radius': 5              # sphere radius (mm) around each ROI coordinate
}

# Dataset configuration
# Default HuggingFace dataset used when no local data directory is supplied.
DATASET_CONFIG = {
    'name': 'SreekarB/OSFData',
    'split': 'train'
}
|
| 24 |
+
|
data_preprocessing.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from datasets import load_dataset
|
| 4 |
+
from nilearn import input_data, connectome
|
| 5 |
+
from nilearn.image import load_img
|
| 6 |
+
import nibabel as nib
|
| 7 |
+
|
| 8 |
+
def preprocess_fmri_to_fc(dataset_name, atlas_path=None):
    """Convert raw fMRI volumes from a HuggingFace dataset into vectorized,
    Fisher-z-transformed, z-scored functional-connectivity (FC) features.

    Parameters
    ----------
    dataset_name : str
        HuggingFace dataset ID; the 'train' split must expose
        'demographics' and 'nii_files' fields.
    atlas_path : str or Niimg, optional
        Label atlas image; when None, spherical ROIs are built from the
        Power 2011 264-coordinate set.

    Returns
    -------
    (X, demo_data, demo_types)
        X: (n_subjects, n_edges) upper-triangle FC features;
        demo_data: [age, sex, months_post_stroke, wab_score] arrays;
        demo_types: matching 'continuous'/'categorical' labels.

    NOTE(review): main.run_fc_analysis calls this as
    preprocess_fmri_to_fc(nii_files, demo_data, demo_types), which does not
    match this (dataset_name, atlas_path) signature — confirm which contract
    is intended before relying on that code path.
    NOTE(review): nilearn.input_data is deprecated in favor of
    nilearn.maskers in recent nilearn releases — verify the pinned version.
    """
    dataset = load_dataset(dataset_name, split="train")

    # Load Power 264 atlas or specified atlas
    if atlas_path is None:
        # Use Power 264 coordinates to create spherical ROIs
        from nilearn import datasets
        power = datasets.fetch_coords_power_2011()
        coords = np.vstack((power.rois['x'], power.rois['y'], power.rois['z'])).T

        # Spherical ROI extractor with detrending and 0.01-0.1 Hz band-pass.
        masker = input_data.NiftiSpheresMasker(
            coords, radius=5,
            standardize=True,
            memory='nilearn_cache', memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=0.1,
            high_pass=0.01,
            t_r=2.0  # Adjust TR according to your data
        )
    else:
        # Same filtering settings, but ROIs come from the supplied label atlas.
        masker = input_data.NiftiLabelsMasker(
            labels_img=atlas_path,
            standardize=True,
            memory='nilearn_cache', memory_level=1,
            verbose=0,
            detrend=True,
            low_pass=0.1,
            high_pass=0.01,
            t_r=2.0  # Adjust TR according to your data
        )

    # Load demographic data
    # assumes dataset['demographics'] is a mapping/record set with the four
    # columns below — TODO confirm against the actual dataset schema
    demo_df = pd.DataFrame(dataset['demographics'])

    demo_data = [
        demo_df['age_at_stroke'].values,
        demo_df['sex'].values,
        demo_df['months_post_stroke'].values,
        demo_df['wab_score'].values
    ]

    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    # Process fMRI data and compute FC matrices
    fc_matrices = []
    for nii_file in dataset['nii_files']:
        fmri_img = load_img(nii_file)
        # ROI time series: (n_timepoints, n_rois)
        time_series = masker.fit_transform(fmri_img)

        correlation_measure = connectome.ConnectivityMeasure(
            kind='correlation',
            vectorize=False,
            discard_diagonal=False
        )

        # Full ROI-by-ROI correlation matrix for this subject.
        fc_matrix = correlation_measure.fit_transform([time_series])[0]

        # Keep only the strict upper triangle (unique edges, no diagonal).
        triu_indices = np.triu_indices_from(fc_matrix, k=1)
        fc_triu = fc_matrix[triu_indices]

        fc_triu = np.arctanh(fc_triu)  # Fisher z-transform

        fc_matrices.append(fc_triu)

    X = np.array(fc_matrices)

    # Normalize the FC data
    # NOTE(review): per-edge z-scoring divides by the across-subject std;
    # a constant edge (std == 0) would produce inf/NaN — confirm inputs.
    X = (X - np.mean(X, axis=0)) / np.std(X, axis=0)

    return X, demo_data, demo_types
|
| 79 |
+
|
main.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
import nibabel as nib
|
| 6 |
+
from data_preprocessing import preprocess_fmri_to_fc
|
| 7 |
+
from vae_model import train_fc_vae, DemoVAE
|
| 8 |
+
from analysis import analyze_fc_patterns
|
| 9 |
+
from visualization import visualize_fc_analysis
|
| 10 |
+
from config import MODEL_CONFIG, DATASET_CONFIG
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import io
|
| 13 |
+
from typing import List, Dict, Union, Tuple, Any
|
| 14 |
+
|
| 15 |
+
def load_data(data_dir="SreekarB/OSFData", demographic_file="FC_graph_covariate_data.csv", use_hf_dataset=True):
    """
    Load fMRI data and demographics from HuggingFace dataset or local files.

    Parameters
    ----------
    data_dir : str
        HuggingFace dataset ID (when use_hf_dataset) or a local directory
        containing *.nii.gz files.
    demographic_file : str
        CSV path (local mode) or dataset feature / repo filename (HF mode)
        holding the demographic covariates.
    use_hf_dataset : bool
        If True, read from the HuggingFace Hub; otherwise read local files.

    Returns
    -------
    (files_or_matrices, demo_data, demo_types)
        files_or_matrices: list of nii paths / feature names, or a list of
        pre-computed FC matrices (HF fallback path);
        demo_data: [age_at_stroke, sex, months_post_stroke, wab_score];
        demo_types: matching 'continuous'/'categorical' labels.
    """
    # Fixed covariate layout used throughout the pipeline. Bound up front so
    # EVERY return path can use it — previously it was only assigned at the
    # end of the function, so the early FC-matrix return raised NameError.
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    if use_hf_dataset:
        # Load from HuggingFace Datasets (imported lazily so local mode
        # works without the `datasets` package).
        from datasets import load_dataset

        print(f"Loading dataset from HuggingFace: {data_dir}")
        dataset = load_dataset(data_dir)

        # Load demographics from the dataset
        if demographic_file in dataset["train"].features:
            demo_df = pd.DataFrame(dataset["train"][demographic_file])
        else:
            # Try to load from the dataset files
            try:
                demo_content = dataset["train"][demographic_file][0]
                demo_df = pd.read_csv(io.StringIO(demo_content))
            except Exception as e:
                print(f"Error loading demographics from dataset: {e}")
                # Download the CSV from the dataset repo
                import huggingface_hub
                csv_path = huggingface_hub.hf_hub_download(repo_id=data_dir, filename=demographic_file)
                demo_df = pd.read_csv(csv_path)

        # Extract demographic data; missing columns degrade to empty arrays
        # so downstream code can substitute zeros.
        demo_data = [
            demo_df['age_at_stroke'].values if 'age_at_stroke' in demo_df.columns else np.array([]),
            demo_df['sex'].values if 'sex' in demo_df.columns else np.array([]),
            demo_df['months_post_stroke'].values if 'months_post_stroke' in demo_df.columns else np.array([]),
            demo_df['wab_score'].values if 'wab_score' in demo_df.columns else np.array([])
        ]

        # Get fMRI/FC files from dataset
        nii_files = [f for f in dataset["train"].features
                     if f.endswith(".nii.gz") or f.endswith(".nii")]

        if not nii_files:
            print("No .nii/.nii.gz files found in dataset, checking for FC matrices")
            # Try to find FC matrices directly
            fc_matrices = [dataset["train"][f] for f in dataset["train"].features
                           if f.startswith("fc_") or f.endswith("_fc")]

            if fc_matrices:
                print(f"Found {len(fc_matrices)} FC matrices in dataset")
                return fc_matrices, demo_data, demo_types
    else:
        # Original local file loading: demographics come from a CSV on disk.
        demo_df = pd.read_csv(demographic_file)

        demo_data = [
            demo_df['age_at_stroke'].values,
            demo_df['sex'].values,
            demo_df['months_post_stroke'].values,
            demo_df['wab_score'].values
        ]

        # Load fMRI files
        nii_files = sorted(list(Path(data_dir).glob('*.nii.gz')))

    return nii_files, demo_data, demo_types
|
| 83 |
+
|
| 84 |
+
def run_fc_analysis(data_dir="SreekarB/OSFData",
                    demographic_file="FC_graph_covariate_data.csv",
                    latent_dim=32,
                    nepochs=1000,
                    bsize=16,
                    save_model=True,
                    use_hf_dataset=True):
    """End-to-end pipeline: load data, train the VAE, correlate latents with
    demographics, generate/reconstruct FC matrices, and return a figure.

    Parameters mirror the CLI/Gradio inputs; returns the matplotlib figure
    produced by visualize_fc_analysis.
    """

    # Update MODEL_CONFIG with user-specified parameters
    # NOTE(review): this mutates the shared module-level dict, so settings
    # leak across repeated calls within the same process.
    MODEL_CONFIG.update({
        'latent_dim': latent_dim,
        'nepochs': nepochs,
        'bsize': bsize
    })

    # Load data
    print("Loading data...")
    nii_files, demo_data, demo_types = load_data(data_dir, demographic_file, use_hf_dataset)

    # Add import for io module if it's missing
    # NOTE(review): redundant — io is already imported at module level.
    import io

    # Check if we got FC matrices directly
    # (load_data's HF path may return pre-computed FC arrays instead of paths)
    if isinstance(nii_files, list) and all(isinstance(item, np.ndarray) for item in nii_files):
        print("Using pre-computed FC matrices...")
        X = np.stack(nii_files)
    else:
        # Prepare data by converting fMRI to FC matrices
        # NOTE(review): preprocess_fmri_to_fc is declared as
        # (dataset_name, atlas_path=None) — this 3-argument call does not
        # match that signature; confirm the intended contract.
        print("Converting fMRI data to FC matrices...")
        X, demo_data, demo_types = preprocess_fmri_to_fc(nii_files, demo_data, demo_types)

    # Train VAE and get data
    print("Training VAE...")
    vae, X, demo_data, demo_types = train_fc_vae(X, demo_data, demo_types, MODEL_CONFIG)

    if save_model:
        print("Saving model...")
        os.makedirs('models', exist_ok=True)
        torch.save(vae.state_dict(), 'models/vae_model.pth')

    # Get latent representations
    print("Getting latent representations...")
    latents = vae.get_latents(X)

    # Analyze results
    print("Analyzing demographic relationships...")
    # Missing demographics (empty arrays) fall back to zeros so the
    # correlation analysis still runs.
    demographics = {
        'age_at_stroke': demo_data[0] if len(demo_data[0]) > 0 else np.zeros(len(X)),
        'months_post_stroke': demo_data[2] if len(demo_data[2]) > 0 else np.zeros(len(X)),
        'wab_score': demo_data[3] if len(demo_data[3]) > 0 else np.zeros(len(X))
    }
    analysis_results = analyze_fc_patterns(latents, demographics)

    # Generate new FC matrix
    print("Generating new FC matrices...")
    # One synthetic subject's covariates, in the [age, sex, months, wab] order.
    new_demographics = [
        [60.0], # age at stroke
        ['M'], # sex
        [12.0], # months post stroke
        [80.0] # wab score
    ]
    # presumably the first argument of transform is the number of samples to
    # generate when given raw demographics — TODO confirm against the VAE API
    generated_fc = vae.transform(1, new_demographics, demo_types)
    reconstructed_fc = vae.transform(X, demo_data, demo_types)

    # Visualize results
    print("Creating visualizations...")
    fig = visualize_fc_analysis(X[0], reconstructed_fc[0], generated_fc[0], analysis_results)

    return fig
|
| 153 |
+
|
| 154 |
+
if __name__ == "__main__":
    import argparse

    # CLI mirrors the keyword arguments of run_fc_analysis.
    parser = argparse.ArgumentParser(description='Run FC Analysis using VAE')
    parser.add_argument('--data_dir', type=str, default='SreekarB/OSFData',
                        help='HuggingFace dataset ID or directory containing fMRI data')
    parser.add_argument('--demographic_file', type=str, default='FC_graph_covariate_data.csv',
                        help='Path to demographic data CSV file')
    parser.add_argument('--latent_dim', type=int, default=32,
                        help='Dimension of latent space')
    parser.add_argument('--nepochs', type=int, default=1000,
                        help='Number of training epochs')
    parser.add_argument('--bsize', type=int, default=16,
                        help='Batch size for training')
    # NOTE: store_false into dest "no_save" means args.no_save defaults to
    # True (save) and becomes False when --no_save is passed; it is
    # forwarded directly as save_model below.
    parser.add_argument('--no_save', action='store_false',
                        help='Do not save the model')
    parser.add_argument('--use_local', action='store_true',
                        help='Use local data instead of HuggingFace dataset')

    args = parser.parse_args()

    fig = run_fc_analysis(
        data_dir=args.data_dir,
        demographic_file=args.demographic_file,
        latent_dim=args.latent_dim,
        nepochs=args.nepochs,
        bsize=args.bsize,
        save_model=args.no_save,
        use_hf_dataset=not args.use_local
    )
    # Display the resulting figure interactively.
    fig.show()
|
| 185 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch>=1.9.0
|
| 2 |
+
numpy>=1.19.2
|
| 3 |
+
pandas>=1.2.4
|
| 4 |
+
nilearn>=0.8.1
|
| 5 |
+
nibabel>=3.2.1
|
| 6 |
+
scikit-learn>=0.24.2
|
| 7 |
+
matplotlib>=3.4.2
|
| 8 |
+
gradio>=2.0.0
|
| 9 |
+
datasets>=1.11.0
|
| 10 |
+
huggingface_hub>=0.12.0
|
| 11 |
+
transformers>=4.15.0
|
| 12 |
+
|
src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/demovae/model.py
ADDED
|
@@ -0,0 +1,221 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
import random
|
| 7 |
+
import numpy as np
|
| 8 |
+
|
| 9 |
+
from sklearn.linear_model import Ridge
|
| 10 |
+
from sklearn.linear_model import LogisticRegression
|
| 11 |
+
|
| 12 |
+
def to_torch(x):
    """Convert a NumPy array into a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.float()
|
| 14 |
+
|
| 15 |
+
def to_cuda(x, use_cuda):
    """Move x to the GPU when use_cuda is True; otherwise return it unchanged."""
    return x.cuda() if use_cuda else x
|
| 20 |
+
|
| 21 |
+
def to_numpy(x):
    """Detach a tensor from the autograd graph and return it as a CPU NumPy array."""
    detached = x.detach().cpu()
    return detached.numpy()
|
| 23 |
+
|
| 24 |
+
class VAE(nn.Module):
    """Conditional autoencoder over vectorized FC matrices.

    NOTE(review): despite the name, there is no mu/log-var
    reparameterization here — enc() is deterministic and the Gaussian
    prior is enforced externally (see latent_loss), so this is a VAE only
    in spirit.
    """
    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        """Build encoder/decoder.

        input_dim: length of a vectorized FC matrix.
        latent_dim: width of the latent bottleneck.
        demo_dim: width of the demographic conditioning vector.
        use_cuda: place layers (and generated tensors) on the GPU.
        """
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda
        # Symmetric 2-layer MLPs with a 1000-unit hidden layer; the decoder
        # additionally consumes the demographic vector.
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)
        self.dec1 = to_cuda(nn.Linear(latent_dim+demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)

    def enc(self, x):
        """Encode a batch of inputs into deterministic latent codes."""
        x = F.relu(self.enc1(x))
        z = self.enc2(x)
        return z

    def gen(self, n):
        """Draw n latent vectors from the standard-normal prior."""
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        """Decode latents conditioned on demographics (concatenated along dim 1)."""
        z = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        x = F.relu(self.dec1(z))
        x = self.dec2(x)
        # Disabled experiment: low-rank symmetric reconstruction of the
        # 264x264 FC matrix followed by upper-triangle extraction.
        #x = x.reshape(len(z), 264, 5)
        #x = torch.einsum('nac,nbc->nab', x, x)
        #a,b = np.triu_indices(264, 1)
        #x = x[:,a,b]
        return x
|
| 53 |
+
|
| 54 |
+
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between a and b.

    The `mean` callable defaults to torch.mean, letting callers substitute
    a different reduction.
    """
    sq_err = (a - b) ** 2
    return mean(sq_err) ** 0.5
|
| 56 |
+
|
| 57 |
+
def latent_loss(z, use_cuda=True):
    """Penalize deviation of the latent batch from a standard normal.

    Compares the Gram matrix z.T @ z against n*I and the batch mean against
    zero. Returns (covariance loss, mean loss, Gram matrix, mean vector).
    """
    dim = z.shape[-1]
    C = z.T@z
    mu = torch.mean(z, dim=0)
    eye_tgt = to_cuda(torch.eye(dim).float(), use_cuda) * len(z)
    zero_tgt = to_cuda(torch.zeros(dim).float(), use_cuda)
    return rmse(C, eye_tgt), rmse(mu, zero_tgt), C, mu
|
| 65 |
+
|
| 66 |
+
def decor_loss(z, demo, use_cuda=True):
    """Penalize correlation between latent dimensions and each demographic.

    For every demographic column, computes a normalized projection of the
    (mean-centered) demographic onto each latent dimension and drives it
    toward zero. Returns (stacked per-demographic losses, list of
    projection vectors).
    """
    zero_tgt = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    losses, ps = [], []
    for di in range(demo.shape[1]):
        d = demo[:, di]
        d = d - torch.mean(d)
        p = torch.einsum('n,nz->z', d, z)
        p = p / torch.std(d)
        p = p / torch.einsum('nz,nz->z', z, z)
        losses.append(rmse(p, zero_tgt))
        ps.append(p)
    return torch.stack(losses), ps
|
| 81 |
+
|
| 82 |
+
def pretty(x):
    """Render a numeric (or 0-dim tensor) value rounded to 4 decimals as a string."""
    return str(round(float(x), 4))
|
| 84 |
+
|
| 85 |
+
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Convert per-demographic arrays into one (n_subjects, demo_dim) tensor.

    Continuous demographics become a single column; categorical demographics
    expand to one 0/1 column per training-time category (pred_stats holds
    the sorted category list). Raises if a categorical value was not seen
    during training.

    assumes each categorical entry d is a NumPy array (uses == broadcasting
    and .astype) — TODO confirm callers always pass arrays, not lists.
    """
    demo_t = []
    demo_idx = 0
    for d,t,s in zip(demo, demo_types, pred_stats):
        if t == 'continuous':
            demo_t.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            # Reject values the guidance models were never fitted on.
            for dd in d:
                if dd not in s:
                    print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            # One indicator column per known category, in sorted order.
            for ss in s:
                idx = (d == ss).astype('bool')
                zeros = torch.zeros(len(d))
                zeros[idx] = 1
                demo_t.append(to_cuda(zeros, use_cuda))
        demo_idx += 1
    # Stack columns then transpose to (n_subjects, demo_dim).
    demo_t = torch.stack(demo_t).permute(1,0)
    return demo_t
|
| 104 |
+
|
| 105 |
+
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize, loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult, loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train the conditional VAE with auxiliary demographic guidance.

    Pipeline per call:
      1. Fit one linear guidance model per demographic on the raw inputs
         (Ridge for continuous, LogisticRegression for categorical).
      2. Convert inputs/demographics to torch.
      3. For each mini-batch: reconstruction + latent-prior + decorrelation
         losses on real data, plus a guidance loss on 100 freshly generated
         samples with randomly drawn demographics.

    ret_obj receives the fitted pred_stats (per-demographic normalization /
    category info) as an attribute. Prints progress every pperiod epochs.
    """
    # Get linear predictors for demographics
    pred_w = []
    pred_i = []
    # Pred stats are mean and std for continuous, and a list of all values for categorical
    pred_stats = []
    for i,d,t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxilliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            # Binary
            # sklearn returns a single coefficient row for binary problems;
            # store negated and plain copies so each of the two one-hot
            # columns gets its own predictor.
            if len(reg.coef_) == 1:
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            # Categorical
            else:
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        else:
            print(f'demographic type "{t}" not "continuous" or "categorical"')
            raise Exception('Bad demographic type')
        print(' done')
    # Expose the fitted statistics to the caller (used by demo_to_torch).
    ret_obj.pred_stats = pred_stats
    # Convert input to pytorch
    print('Converting input to pytorch')
    x = to_cuda(to_torch(x), vae.use_cuda)
    # Convert demographics to pytorch
    print('Converting demographics to pytorch')
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)
    # Training loop
    print('Beginning VAE training')
    ce = nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)
    for e in range(nepochs):
        for bs in range(0,len(x),bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()
            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)
            # Sample demographics
            # 100 synthetic subjects: Gaussian draws for continuous fields,
            # a single random category (shared by all 100) for categoricals.
            demo_gen = []
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    mu = s[0]
                    std = s[1]
                    dd = torch.randn(100).float()
                    dd = dd*std+mu
                    dd = to_cuda(dd, vae.use_cuda)
                    demo_gen.append(dd)
                elif t == 'categorical':
                    idx = random.randint(0, len(s)-1)
                    for i in range(len(s)):
                        if idx == i:
                            dd = torch.ones(100).float()
                        else:
                            dd = torch.zeros(100).float()
                        dd = to_cuda(dd, vae.use_cuda)
                        demo_gen.append(dd)
            demo_gen = torch.stack(demo_gen).permute(1,0)
            # Generate
            z = vae.gen(100)
            y = vae.dec(z, demo_gen)
            # Regressor/classifier guidance loss
            # Generated samples must still be predictive of the demographics
            # they were conditioned on, per the pre-fitted linear models.
            losses_pred = []
            idcs = []
            dg_idx = 0
            for s,t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                    loss = rmse(demo_gen[:,dg_idx], yy)
                    losses_pred.append(loss)
                    idcs.append(float(demo_gen[0,dg_idx]))
                    dg_idx += 1
                elif t == 'categorical':
                    loss = 0
                    for i in range(len(s)):
                        yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                        loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
                        idcs.append(int(demo_gen[0,dg_idx]))
                        dg_idx += 1
                    losses_pred.append(loss)
            # Weighted sum of all objectives.
            total_loss = loss_C_mult*loss_C + loss_mu_mult*loss_mu + loss_rec_mult*loss_rec + loss_decor_mult*loss_decor + loss_pred_mult*sum(losses_pred)
            total_loss.backward()
            optim.step()
        # Progress report (values are from the last mini-batch of the epoch).
        if e%pperiod == 0 or e == nepochs-1:
            print(f'Epoch {e} ', end='')
            print(f'ReconLoss {pretty(loss_rec)} ', end='')
            print(f'CovarianceLoss {pretty(loss_C)} ', end='')
            print(f'MeanLoss {pretty(loss_mu)} ', end='')
            print(f'DecorLoss {pretty(loss_decor)} ', end='')
            losses_pred = [pretty(loss) for loss in losses_pred]
            print(f'GuidanceTargets {idcs} GuidanceLosses {losses_pred} ', end='')
            print()
    print('Training complete.')
|
| 220 |
+
|
| 221 |
+
|
src/demovae/sklearn.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
from demovae.model import VAE, train_vae, to_torch, to_cuda, to_numpy, demo_to_torch
|
| 3 |
+
|
| 4 |
+
from sklearn.base import BaseEstimator
|
| 5 |
+
|
| 6 |
+
# For saving
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
class DemoVAE(BaseEstimator):
|
| 10 |
+
    def __init__(self, **params):
        """Initialize the estimator; unspecified hyper-parameters fall back
        to get_default_params() via set_params()."""
        self.set_params(**params)
|
| 12 |
+
|
| 13 |
+
    @staticmethod
    def get_default_params():
        """Return the full default hyper-parameter dict for DemoVAE.

        Keys here define the estimator's complete parameter set; set_params
        iterates over them and get_params mirrors them.
        """
        return dict(latent_dim=60, # Latent dimension
                    use_cuda=True, # GPU acceleration
                    nepochs=3000, # Training epochs
                    pperiod=100, # Epochs between printing updates
                    bsize=1000, # Batch size
                    loss_C_mult=1, # Covariance loss (KL div)
                    loss_mu_mult=1, # Mean loss (KL div)
                    loss_rec_mult=100, # Reconstruction loss
                    loss_decor_mult=10, # Latent-demographic decorrelation loss
                    loss_pred_mult=0.001, # Classifier/regressor guidance loss
                    alpha=100, # Regularization for continuous guidance models
                    LR_C=100, # Regularization for categorical guidance models
                    lr=1e-4, # Learning rate
                    weight_decay=0, # L2 regularization for VAE model
                    )
|
| 30 |
+
|
| 31 |
+
def get_params(self, **params):
|
| 32 |
+
return dict(latent_dim=self.latent_dim,
|
| 33 |
+
use_cuda=self.use_cuda,
|
| 34 |
+
nepochs=self.nepochs,
|
| 35 |
+
pperiod=self.pperiod,
|
| 36 |
+
bsize=self.bsize,
|
| 37 |
+
loss_C_mult=self.loss_C_mult,
|
| 38 |
+
loss_mu_mult=self.loss_mu_mult,
|
| 39 |
+
loss_rec_mult=self.loss_rec_mult,
|
| 40 |
+
loss_decor_mult=self.loss_decor_mult,
|
| 41 |
+
loss_pred_mult=self.loss_pred_mult,
|
| 42 |
+
alpha=self.alpha,
|
| 43 |
+
LR_C=self.LR_C,
|
| 44 |
+
lr=self.lr,
|
| 45 |
+
weight_decay=self.weight_decay,
|
| 46 |
+
)
|
| 47 |
+
|
| 48 |
+
def set_params(self, **params):
|
| 49 |
+
dft = DemoVAE.get_default_params()
|
| 50 |
+
for key in dft:
|
| 51 |
+
if key in params:
|
| 52 |
+
setattr(self, key, params[key])
|
| 53 |
+
else:
|
| 54 |
+
setattr(self, key, dft[key])
|
| 55 |
+
return self
|
| 56 |
+
|
| 57 |
+
def fit(self, x, demo, demo_types, **kwargs):
|
| 58 |
+
# Get demo_dim
|
| 59 |
+
demo_dim = 0
|
| 60 |
+
for d,t in zip(demo, demo_types):
|
| 61 |
+
if t == 'continuous':
|
| 62 |
+
demo_dim += 1
|
| 63 |
+
elif t == 'categorical':
|
| 64 |
+
ll = len(set(list(d)))
|
| 65 |
+
if ll == 1:
|
| 66 |
+
print('Only one type of category for categorical variable')
|
| 67 |
+
raise Exception('Bad categorical')
|
| 68 |
+
demo_dim += ll
|
| 69 |
+
else:
|
| 70 |
+
print(f'demographic type "{t}" not "continuous" or "categorical"')
|
| 71 |
+
raise Exception('Bad demographic type')
|
| 72 |
+
# Save parameters
|
| 73 |
+
self.input_dim = x.shape[1]
|
| 74 |
+
self.demo_dim = demo_dim
|
| 75 |
+
# Create model
|
| 76 |
+
self.vae = VAE(x.shape[1], self.latent_dim, demo_dim, self.use_cuda)
|
| 77 |
+
# Train model
|
| 78 |
+
train_vae(self.vae, x, demo, demo_types,
|
| 79 |
+
self.nepochs, self.pperiod, self.bsize,
|
| 80 |
+
self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult, self.loss_decor_mult, self.loss_pred_mult,
|
| 81 |
+
self.lr, self.weight_decay, self.alpha, self.LR_C,
|
| 82 |
+
self)
|
| 83 |
+
return self
|
| 84 |
+
|
| 85 |
+
def transform(self, x, demo, demo_types, **kwargs):
|
| 86 |
+
if isinstance(x, int):
|
| 87 |
+
# Generate
|
| 88 |
+
z = self.vae.gen(x)
|
| 89 |
+
else:
|
| 90 |
+
# Get latents for real data
|
| 91 |
+
z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
|
| 92 |
+
demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
|
| 93 |
+
y = self.vae.dec(z, demo_t)
|
| 94 |
+
return to_numpy(y)
|
| 95 |
+
|
| 96 |
+
def fit_transform(self, x, demo, demo_types, **kwargs):
|
| 97 |
+
self.fit(x, demo, demo_types)
|
| 98 |
+
return self.transform(x, demo, demo_types)
|
| 99 |
+
|
| 100 |
+
def get_latents(self, x):
|
| 101 |
+
z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
|
| 102 |
+
return to_numpy(z)
|
| 103 |
+
|
| 104 |
+
def save(self, path):
|
| 105 |
+
params = self.get_params()
|
| 106 |
+
dct = dict(pred_stats=self.pred_stats,
|
| 107 |
+
params=params,
|
| 108 |
+
input_dim=self.input_dim,
|
| 109 |
+
demo_dim=self.demo_dim,
|
| 110 |
+
model_state_dict=self.vae.state_dict())
|
| 111 |
+
torch.save(dct, path)
|
| 112 |
+
|
| 113 |
+
def load(self, path):
|
| 114 |
+
dct = torch.load(path)
|
| 115 |
+
self.pred_stats = dct['pred_stats']
|
| 116 |
+
self.set_params(**dct['params'])
|
| 117 |
+
self.vae = VAE(dct['input_dim'],
|
| 118 |
+
dct['params']['latent_dim'],
|
| 119 |
+
dct['demo_dim'],
|
| 120 |
+
dct['params']['use_cuda'])
|
| 121 |
+
self.vae.load_state_dict(dct['model_state_dict'])
|
| 122 |
+
|
| 123 |
+
|
utils.py
ADDED
|
@@ -0,0 +1,186 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from sklearn.linear_model import Ridge, LogisticRegression
|
| 4 |
+
|
| 5 |
+
def to_torch(x):
    """Convert a NumPy array into a float32 torch tensor."""
    tensor = torch.from_numpy(x)
    return tensor.to(torch.float32)
|
| 7 |
+
|
| 8 |
+
def to_cuda(x, use_cuda):
    """Move *x* to the GPU when *use_cuda* is truthy; otherwise return it unchanged."""
    if not use_cuda:
        return x
    return x.cuda()
|
| 10 |
+
|
| 11 |
+
def to_numpy(x):
    """Detach a torch tensor from the autograd graph and return a NumPy copy on CPU."""
    detached = x.detach()
    return detached.cpu().numpy()
|
| 13 |
+
|
| 14 |
+
def fc_matrix_from_triu(triu_values, n_rois=264):
    """Rebuild a symmetric FC matrix from its strict upper-triangle values.

    The values are passed through tanh (mapping Fisher-z values back into
    [-1, 1]), mirrored across the diagonal, and the diagonal is set to 1.
    """
    rows, cols = np.triu_indices(n_rois, k=1)
    mat = np.zeros((n_rois, n_rois))
    mat[rows, cols] = np.tanh(triu_values)
    mat += mat.T
    np.fill_diagonal(mat, 1)
    return mat
|
| 22 |
+
|
| 23 |
+
def rmse(a, b, mean=torch.mean):
    """Root-mean-square error between *a* and *b* using a pluggable mean function."""
    squared_error = (a - b) ** 2
    return mean(squared_error) ** 0.5
|
| 25 |
+
|
| 26 |
+
def latent_loss(z, use_cuda=True):
    """Penalize deviation of the latent batch from a standard-normal prior.

    The Gram matrix z^T z is driven toward len(z) * I and the batch mean
    toward zero.  Returns (covariance loss, mean loss, Gram matrix, mean).
    """
    dim = z.shape[-1]
    gram = z.T @ z
    batch_mean = torch.mean(z, dim=0)
    eye_target = to_cuda(torch.eye(dim).float(), use_cuda) * len(z)
    zero_target = to_cuda(torch.zeros(dim).float(), use_cuda)
    return rmse(gram, eye_target), rmse(batch_mean, zero_target), gram, batch_mean
|
| 34 |
+
|
| 35 |
+
def decor_loss(z, demo, use_cuda=True):
    """Penalize correlation between each demographic column and the latents.

    For every column of *demo*, a normalized projection onto each latent
    dimension is driven toward zero.  Returns (stacked per-column losses,
    list of projection vectors, one per demographic column).
    """
    zero_target = to_cuda(torch.zeros(z.shape[-1]).float(), use_cuda)
    projections = []
    per_column_losses = []
    for col in range(demo.shape[1]):
        centered = demo[:, col] - torch.mean(demo[:, col])
        proj = torch.einsum('n,nz->z', centered, z)
        proj = proj / torch.std(centered)
        proj = proj / torch.einsum('nz,nz->z', z, z)
        per_column_losses.append(rmse(proj, zero_target))
        projections.append(proj)
    return torch.stack(per_column_losses), projections
|
| 50 |
+
def demo_to_torch(demo, demo_types, pred_stats, use_cuda):
    """Encode raw demographic arrays as a (n_subjects, demo_dim) float tensor.

    Continuous variables become one column each; categorical variables are
    one-hot encoded using the sorted category list recorded in *pred_stats*.
    Raises if a categorical value was never seen at training time.
    """
    columns = []
    for demo_idx, (d, t, s) in enumerate(zip(demo, demo_types, pred_stats)):
        if t == 'continuous':
            columns.append(to_cuda(to_torch(d), use_cuda))
        elif t == 'categorical':
            # Validate every value against the training-time category set.
            for dd in d:
                if dd not in s:
                    print(f'Model not trained with value {dd} for categorical demographic {demo_idx}')
                    raise Exception('Bad demographic')
            # One indicator column per category.
            for ss in s:
                onehot = torch.zeros(len(d))
                onehot[(d == ss).astype('bool')] = 1
                columns.append(to_cuda(onehot, use_cuda))
    return torch.stack(columns).permute(1, 0)
|
| 69 |
+
|
| 70 |
+
def train_vae(vae, x, demo, demo_types, nepochs, pperiod, bsize,
              loss_C_mult, loss_mu_mult, loss_rec_mult, loss_decor_mult,
              loss_pred_mult, lr, weight_decay, alpha, LR_C, ret_obj):
    """Train a demographic-conditioned VAE with auxiliary guidance models.

    First fits one linear guidance model per demographic (Ridge for
    continuous, LogisticRegression for categorical) directly on the input
    features, then optimizes the VAE with a weighted sum of: latent
    covariance/mean losses, reconstruction loss, latent-demographic
    decorrelation loss, and a guidance loss that pushes generated samples
    to predict their conditioning demographics.  Per-demographic statistics
    are stored on ``ret_obj.pred_stats`` (mean/std for continuous, sorted
    category list for categorical).
    """
    # Get linear predictors for demographics
    pred_w = []
    pred_i = []
    pred_stats = []

    for i, d, t in zip(range(len(demo)), demo, demo_types):
        print(f'Fitting auxiliary guidance model for demographic {i} {t}...', end='')
        if t == 'continuous':
            pred_stats.append([np.mean(d), np.std(d)])
            reg = Ridge(alpha=alpha).fit(x, d)
            reg_w = to_cuda(to_torch(reg.coef_), vae.use_cuda)
            reg_i = reg.intercept_
            pred_w.append(reg_w)
            pred_i.append(reg_i)
        elif t == 'categorical':
            pred_stats.append(sorted(list(set(list(d)))))
            reg = LogisticRegression(C=LR_C).fit(x, d)
            if len(reg.coef_) == 1:
                # Binary case: sklearn returns a single coefficient row;
                # store (-w, -i) and (w, i) so each one-hot column gets a predictor.
                reg_w = to_cuda(to_torch(reg.coef_[0]), vae.use_cuda)
                reg_i = reg.intercept_[0]
                pred_w.append(-reg_w)
                pred_i.append(-reg_i)
                pred_w.append(reg_w)
                pred_i.append(reg_i)
            else:
                # Multiclass: one coefficient row per category.
                # NOTE(review): this inner loop reuses the name `i`, shadowing
                # the outer demographic index (affects only the progress print).
                for i in range(len(reg.coef_)):
                    reg_w = to_cuda(to_torch(reg.coef_[i]), vae.use_cuda)
                    reg_i = reg.intercept_[i]
                    pred_w.append(reg_w)
                    pred_i.append(reg_i)
        print(' done')

    # Expose fitted stats to the caller (e.g. the DemoVAE wrapper).
    ret_obj.pred_stats = pred_stats

    # Convert input to pytorch
    x = to_cuda(to_torch(x), vae.use_cuda)

    # Convert demographics to pytorch
    demo_t = demo_to_torch(demo, demo_types, pred_stats, vae.use_cuda)

    # Training loop
    ce = torch.nn.CrossEntropyLoss()
    optim = torch.optim.Adam(vae.parameters(), lr=lr, weight_decay=weight_decay)

    for e in range(nepochs):
        for bs in range(0, len(x), bsize):
            xb = x[bs:(bs+bsize)]
            db = demo_t[bs:(bs+bsize)]
            optim.zero_grad()

            # Reconstruct
            z = vae.enc(xb)
            y = vae.dec(z, db)
            loss_C, loss_mu, _, _ = latent_loss(z, vae.use_cuda)
            loss_decor, _ = decor_loss(z, db, vae.use_cuda)
            loss_decor = sum(loss_decor)
            loss_rec = rmse(xb, y)

            # Sample demographics for a synthetic batch of 100 samples.
            # NOTE(review): the generated batch size is hard-coded to 100,
            # independent of bsize — confirm this is intended.
            demo_gen = []
            for s, t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    mu, std = s
                    dd = torch.randn(100).float()
                    dd = dd*std+mu
                    dd = to_cuda(dd, vae.use_cuda)
                    demo_gen.append(dd)
                elif t == 'categorical':
                    # NOTE(review): a single category is drawn for the whole
                    # generated batch (same one-hot for all 100 samples).
                    idx = np.random.randint(0, len(s))
                    for i in range(len(s)):
                        dd = torch.ones(100).float() if idx == i else torch.zeros(100).float()
                        dd = to_cuda(dd, vae.use_cuda)
                        demo_gen.append(dd)

            demo_gen = torch.stack(demo_gen).permute(1,0)

            # Generate
            z = vae.gen(100)
            y = vae.dec(z, demo_gen)

            # Regressor/classifier guidance loss: generated samples should
            # be predicted (by the frozen linear models) to have the
            # demographics they were conditioned on.
            losses_pred = []
            idcs = []
            dg_idx = 0

            for s, t in zip(pred_stats, demo_types):
                if t == 'continuous':
                    yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                    loss = rmse(demo_gen[:,dg_idx], yy)
                    losses_pred.append(loss)
                    idcs.append(float(demo_gen[0,dg_idx]))
                    dg_idx += 1
                elif t == 'categorical':
                    # Sum a cross-entropy term over the one-hot columns.
                    loss = 0
                    for i in range(len(s)):
                        yy = y@pred_w[dg_idx]+pred_i[dg_idx]
                        loss += ce(torch.stack([-yy, yy], dim=1), demo_gen[:,dg_idx].long())
                        idcs.append(int(demo_gen[0,dg_idx]))
                        dg_idx += 1
                    losses_pred.append(loss)

            # Weighted sum of all loss terms.
            total_loss = (loss_C_mult*loss_C + loss_mu_mult*loss_mu +
                          loss_rec_mult*loss_rec + loss_decor_mult*loss_decor +
                          loss_pred_mult*sum(losses_pred))

            total_loss.backward()
            optim.step()

        # Periodic progress report (values from the last batch of the epoch).
        if e%pperiod == 0 or e == nepochs-1:
            print(f'Epoch {e} ReconLoss {loss_rec:.4f} CovarianceLoss {loss_C:.4f} '
                  f'MeanLoss {loss_mu:.4f} DecorLoss {loss_decor:.4f}')
            print(f'GuidanceTargets {idcs}')
            print(f'GuidanceLosses {[f"{loss:.4f}" for loss in losses_pred]}')
|
| 186 |
+
|
vae_model.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import numpy as np
|
| 5 |
+
from utils import to_torch, to_cuda, to_numpy, demo_to_torch
|
| 6 |
+
from sklearn.base import BaseEstimator
|
| 7 |
+
|
| 8 |
+
class VAE(nn.Module):
    """Demographic-conditioned VAE.

    The encoder maps inputs to a latent code; the decoder reconstructs the
    input from the latent code concatenated with a demographic vector.
    """

    def __init__(self, input_dim, latent_dim, demo_dim, use_cuda=True):
        super(VAE, self).__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.use_cuda = use_cuda

        # Encoder: input -> 1000 -> latent
        self.enc1 = to_cuda(nn.Linear(input_dim, 1000).float(), use_cuda)
        self.enc2 = to_cuda(nn.Linear(1000, latent_dim).float(), use_cuda)

        # Decoder: (latent + demographics) -> 1000 -> input
        self.dec1 = to_cuda(nn.Linear(latent_dim+demo_dim, 1000).float(), use_cuda)
        self.dec2 = to_cuda(nn.Linear(1000, input_dim).float(), use_cuda)

        # One BatchNorm per hidden layer
        self.bn1 = to_cuda(nn.BatchNorm1d(1000), use_cuda)
        self.bn2 = to_cuda(nn.BatchNorm1d(1000), use_cuda)

    def enc(self, x):
        """Encode a batch of inputs into latent codes."""
        hidden = self.bn1(F.relu(self.enc1(x)))
        return self.enc2(hidden)

    def gen(self, n):
        """Sample *n* latent codes from the standard-normal prior."""
        return to_cuda(torch.randn(n, self.latent_dim).float(), self.use_cuda)

    def dec(self, z, demo):
        """Decode latent codes conditioned on a demographic tensor."""
        joint = to_cuda(torch.cat([z, demo], dim=1), self.use_cuda)
        hidden = self.bn2(F.relu(self.dec1(joint)))
        return self.dec2(hidden)
|
| 41 |
+
|
| 42 |
+
class DemoVAE(BaseEstimator):
    """Scikit-learn style wrapper for the demographic-conditioned VAE.

    ``fit`` trains the underlying :class:`VAE`; ``transform`` reconstructs
    real samples (or, given an int, generates new samples) conditioned on
    demographics; ``save``/``load`` round-trip the trained model.
    """

    def __init__(self, **params):
        self.set_params(**params)

    @staticmethod
    def get_default_params():
        """Default hyper-parameters; any subset may be overridden via set_params."""
        return dict(
            latent_dim=32,        # Latent dimension
            use_cuda=True,        # GPU acceleration
            nepochs=1000,         # Training epochs
            pperiod=100,          # Epochs between printing updates
            bsize=16,             # Batch size
            loss_C_mult=1,        # Covariance loss weight
            loss_mu_mult=1,       # Mean loss weight
            loss_rec_mult=100,    # Reconstruction loss weight
            loss_decor_mult=10,   # Latent-demographic decorrelation weight
            loss_pred_mult=0.001, # Guidance loss weight
            alpha=100,            # Ridge regularization (continuous guidance)
            LR_C=100,             # LogisticRegression C (categorical guidance)
            lr=1e-4,              # Learning rate
            weight_decay=0        # L2 regularization for the VAE
        )

    def get_params(self, deep=True):
        """Return hyper-parameters as a dict (sklearn estimator API)."""
        return {k: getattr(self, k) for k in self.get_default_params().keys()}

    def set_params(self, **params):
        """Set known hyper-parameters from *params*, falling back to defaults."""
        for k, v in self.get_default_params().items():
            setattr(self, k, params.get(k, v))
        return self

    def fit(self, x, demo, demo_types):
        """Train the VAE on features *x* conditioned on *demo*.

        Raises
        ------
        ValueError
            If a demographic type is unrecognized or a categorical variable
            has only one observed category (one-hot conditioning would be
            degenerate).
        """
        from utils import train_vae

        # Width of the conditioning vector: one column per continuous
        # variable, one one-hot column per observed category.
        demo_dim = 0
        for d, t in zip(demo, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                n_categories = len(set(d))
                if n_categories < 2:
                    # Consistent with the sibling implementation in
                    # src/demovae/sklearn.py, which rejects this case.
                    raise ValueError('Categorical demographic has only one category')
                demo_dim += n_categories
            else:
                raise ValueError(f'Demographic type "{t}" not supported')

        # Initialize VAE
        self.input_dim = x.shape[1]
        self.demo_dim = demo_dim
        self.vae = VAE(self.input_dim, self.latent_dim, demo_dim, self.use_cuda)

        # Train VAE; train_vae records self.pred_stats via the last argument.
        train_vae(
            self.vae, x, demo, demo_types,
            self.nepochs, self.pperiod, self.bsize,
            self.loss_C_mult, self.loss_mu_mult, self.loss_rec_mult,
            self.loss_decor_mult, self.loss_pred_mult,
            self.lr, self.weight_decay, self.alpha, self.LR_C,
            self
        )
        return self

    def transform(self, x, demo, demo_types):
        """Reconstruct (or, if *x* is an int, generate) conditioned samples."""
        if isinstance(x, int):
            # Sample x latent codes from the prior.
            z = self.vae.gen(x)
        else:
            # Encode real data into latents.
            z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        demo_t = demo_to_torch(demo, demo_types, self.pred_stats, self.vae.use_cuda)
        y = self.vae.dec(z, demo_t)
        return to_numpy(y)

    def get_latents(self, x):
        """Return the encoder's latent codes for *x* as a NumPy array."""
        z = self.vae.enc(to_cuda(to_torch(x), self.vae.use_cuda))
        return to_numpy(z)

    def save(self, path):
        """Serialize hyper-parameters, fitted stats, and model weights."""
        torch.save({
            'model_state_dict': self.vae.state_dict(),
            'params': self.get_params(),
            'pred_stats': self.pred_stats,
            'input_dim': self.input_dim,
            'demo_dim': self.demo_dim
        }, path)

    def load(self, path):
        """Restore a model saved by :meth:`save`.

        Uses ``map_location='cpu'`` so GPU-trained checkpoints load on
        CPU-only hosts, and only re-enables CUDA when actually available.
        """
        checkpoint = torch.load(path, map_location='cpu')
        self.set_params(**checkpoint['params'])
        self.pred_stats = checkpoint['pred_stats']
        self.input_dim = checkpoint['input_dim']
        self.demo_dim = checkpoint['demo_dim']
        # Don't request CUDA on machines without it, even if trained with it.
        self.use_cuda = self.use_cuda and torch.cuda.is_available()
        self.vae = VAE(self.input_dim, self.latent_dim, self.demo_dim, self.use_cuda)
        self.vae.load_state_dict(checkpoint['model_state_dict'])
|
| 132 |
+
|
| 133 |
+
def train_fc_vae(X, demo_data, demo_types, model_config):
    """Train a DemoVAE on vectorized functional-connectivity data.

    Parameters
    ----------
    X : array of shape (n_subjects, n_features)
        Upper-triangle FC values per subject.
    demo_data : list of arrays
        One array per demographic variable (each length n_subjects).
    demo_types : list of str
        'continuous' / 'categorical' marker per demographic.
    model_config : dict
        Must contain keys: latent_dim, nepochs, bsize, loss_rec_mult,
        loss_decor_mult, lr.

    Returns
    -------
    tuple
        (trained DemoVAE, X, demo_data, demo_types) — inputs are passed
        through for downstream analysis convenience.
    """
    # NOTE: previously computed n_rois / input_dim locals were dead code —
    # DemoVAE infers the input dimension from X in fit().
    vae = DemoVAE(
        latent_dim=model_config['latent_dim'],
        nepochs=model_config['nepochs'],
        bsize=model_config['bsize'],
        loss_rec_mult=model_config['loss_rec_mult'],
        loss_decor_mult=model_config['loss_decor_mult'],
        lr=model_config['lr'],
        # Only request the GPU when one is actually present.
        use_cuda=torch.cuda.is_available()
    )

    vae.fit(X, demo_data, demo_types)

    return vae, X, demo_data, demo_types
|
| 150 |
+
|
visualization.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import numpy as np
|
| 3 |
+
from utils import fc_matrix_from_triu
|
| 4 |
+
|
| 5 |
+
def visualize_fc_analysis(original_triu, reconstructed_triu, generated_triu, analysis_results=None):
    """Plot original / reconstructed / generated FC matrices side by side.

    Optionally adds a second row with per-latent-dimension correlation
    curves from *analysis_results* (dict mapping demographic name to a
    dict with 'correlations' and 'p_values').  Returns the Figure.
    """
    fig = plt.figure(figsize=(15, 10))
    grid = plt.GridSpec(2, 3)

    # Top row: the three FC matrices rebuilt from their upper triangles.
    panels = [
        ('Original FC', fc_matrix_from_triu(original_triu)),
        ('Reconstructed FC', fc_matrix_from_triu(reconstructed_triu)),
        ('Generated FC', fc_matrix_from_triu(generated_triu)),
    ]
    for col, (title, matrix) in enumerate(panels):
        ax = fig.add_subplot(grid[0, col])
        image = ax.imshow(matrix, cmap='RdBu_r', vmin=-1, vmax=1)
        ax.set_title(title)
        plt.colorbar(image, ax=ax)

    # Bottom row: demographic-vs-latent correlation curves, if provided.
    if analysis_results is not None:
        ax_corr = fig.add_subplot(grid[1, :])
        for demo_name, results in analysis_results.items():
            significant_dims = np.where(np.array(results['p_values']) < 0.05)[0]
            correlations = np.array(results['correlations'])
            ax_corr.plot(correlations, label=f'{demo_name} (sig. dims: {len(significant_dims)})')

        ax_corr.set_xlabel('Latent Dimension')
        ax_corr.set_ylabel('Correlation Strength')
        ax_corr.set_title('Demographic Correlations with Latent Dimensions')
        ax_corr.legend()

    plt.tight_layout()
    return fig
|
| 44 |
+
|