"""
Script to visualize FC matrices from HuggingFace dataset,
comparing original FC to VAE-generated FC.
"""

import os
import numpy as np

# Configure matplotlib for headless environment
import matplotlib
matplotlib.use('Agg')  # Use non-interactive backend
import matplotlib.pyplot as plt

from datasets import load_dataset
from fc_visualization import FCVisualizer
from pathlib import Path
import tempfile
import requests

from config import DATASET_CONFIG, PREPROCESS_CONFIG, MODEL_CONFIG
from data_preprocessing import process_single_fmri
# NOTE(review): this import is shadowed by the local VariationalAutoencoder
# class defined below; the local, self-contained class is the one this script uses.
from vae_model import VariationalAutoencoder


# Placeholder demographics used when a sample lacks a field (or stores None).
_DEMOGRAPHIC_DEFAULTS = (
    ('age', 65.0),
    ('sex', 'M'),
    ('months_post_onset', 12.0),
    ('wab_aq', 50.0),
)


def _extract_demographics(sample):
    """Return ``[age, sex, months_post_onset, wab_aq]`` for one sample, using placeholders for missing/None values."""
    row = []
    for field, default in _DEMOGRAPHIC_DEFAULTS:
        value = sample.get(field)
        row.append(default if value is None else value)
    return row


def download_sample_fmri(dataset, temp_dir, max_samples=5, timeout=60):
    """
    Download sample fMRI files from HuggingFace dataset.

    Args:
        dataset: HuggingFace dataset object
        temp_dir: Directory to save downloaded files
        max_samples: Maximum number of samples to download
        timeout: Per-request timeout in seconds passed to ``requests.get``

    Returns:
        Tuple ``(nifti_files, demo_data, nifti_keys)``: local paths of the
        downloaded files, one ``[age, sex, months_post_onset, wab_aq]`` row per
        downloaded file (placeholders where a field is missing), and the sample
        keys whose string values end in ``.nii``/``.nii.gz``.
    """
    # Look through the first few samples for string fields that name NIfTI files
    nifti_keys = []
    for i, sample in enumerate(dataset):
        if i >= 5:  # Check first 5 samples
            break
        for key, value in sample.items():
            if isinstance(value, str) and value.endswith(('.nii', '.nii.gz')):
                if key not in nifti_keys:
                    nifti_keys.append(key)

    print(f"Found {len(nifti_keys)} NIfTI file types in the dataset: {nifti_keys}")

    if not nifti_keys:
        print("No NIfTI files found in the dataset")
        return [], [], []

    # Collect nifti files and demographics (one demographics row per file)
    nifti_files = []
    demo_data = []

    num_samples = min(max_samples, len(dataset))
    for sample_idx in range(num_samples):
        sample = dataset[sample_idx]
        for key in nifti_keys:
            try:
                file_url = sample[key]
                if not file_url or not isinstance(file_url, str):
                    continue

                print(f"Processing sample {sample_idx+1}, file: {key}")

                # Keep an uncompressed source as .nii: NIfTI readers typically
                # pick gzip decoding from the suffix, so a plain .nii saved as
                # .nii.gz would not load.
                suffix = '.nii' if file_url.endswith('.nii') else '.nii.gz'
                local_file = os.path.join(temp_dir, f"sample_{sample_idx}_{key}{suffix}")
                print(f"Downloading {file_url} to {local_file}")

                # Stream to disk with a timeout and fail on HTTP errors instead
                # of silently saving an error page as the image file.
                with requests.get(file_url, stream=True, timeout=timeout) as response:
                    response.raise_for_status()
                    with open(local_file, 'wb') as f:
                        for chunk in response.iter_content(chunk_size=1 << 20):
                            f.write(chunk)

                nifti_files.append(local_file)
                demo_data.append(_extract_demographics(sample))

            except Exception as e:
                print(f"Error processing sample {sample_idx}, {key}: {e}")

    return nifti_files, demo_data, nifti_keys


class VariationalAutoencoder:
    """
    Simplified VAE implementation for the visualization script.

    The decoder is conditioned on demographics: continuous variables are
    standardized and categorical variables one-hot encoded (see
    ``_prepare_demographics``), then concatenated to the latent code.
    """

    def __init__(self, n_features, latent_dim, demo_data, demo_types, **kwargs):
        """
        Initialize the VAE.

        Args:
            n_features: Number of input features
            latent_dim: Dimension of latent space
            demo_data: Demographic data, one list of values per variable
            demo_types: Types of demographic variables ('continuous' / 'categorical')
            **kwargs: Additional parameters: ``nepochs`` (default 100),
                ``bsize`` (batch size, default 8), ``lr`` (default 1e-3)
        """
        import torch
        import torch.nn as nn

        self.n_features = n_features
        self.latent_dim = latent_dim
        self.demo_dim = self._calculate_demo_dim(demo_data, demo_types)
        self.nepochs = kwargs.get('nepochs', 100)
        self.batch_size = kwargs.get('bsize', 8)
        self.learning_rate = kwargs.get('lr', 1e-3)
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Encoder outputs 2 * latent_dim values: mu and logvar
        self.encoder = nn.Sequential(
            nn.Linear(n_features, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, latent_dim * 2)  # mu and logvar
        ).to(self.device)

        # Decoder input is the latent code concatenated with encoded demographics
        self.decoder = nn.Sequential(
            nn.Linear(latent_dim + self.demo_dim, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, n_features)
        ).to(self.device)

        self.optimizer = torch.optim.Adam(
            list(self.encoder.parameters()) + list(self.decoder.parameters()),
            lr=self.learning_rate
        )

        self.demo_stats = None  # Set on the first call to _prepare_demographics

    def _calculate_demo_dim(self, demo_data, demo_types):
        """Calculate dimension of demographic data after one-hot encoding."""
        demo_dim = 0
        for d, t in zip(demo_data, demo_types):
            if t == 'continuous':
                demo_dim += 1
            elif t == 'categorical':
                # One column per distinct category (string or numeric alike)
                demo_dim += len(set(d))
        return demo_dim

    def _encode(self, x):
        """Encode input data to latent space; returns ``(mu, logvar)``."""
        import torch
        # as_tensor avoids the copy/warning torch.tensor() emits for tensor inputs
        x_tensor = torch.as_tensor(x, dtype=torch.float32, device=self.device)
        h = self.encoder(x_tensor)
        mu, logvar = h[:, :self.latent_dim], h[:, self.latent_dim:]
        return mu, logvar

    def _reparameterize(self, mu, logvar):
        """Reparameterization trick for sampling from latent space."""
        import torch
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def _decode(self, z, demo):
        """Decode latent representation (concatenated with demographics) back to input space."""
        import torch
        z_concat = torch.cat([z, demo], dim=1)
        return self.decoder(z_concat)

    def _prepare_demographics(self, demo_data, demo_types):
        """
        Convert demographics to a tensor, one-hot encoding categorical variables.

        The first call records per-variable statistics (mean/std for continuous,
        sorted category list for categorical) in ``self.demo_stats``; later calls
        reuse them so training and inference use the same encoding. Unseen
        categories encode as all-zero rows.
        """
        import torch
        import numpy as np

        if self.demo_stats is None:
            # First time - compute stats
            self.demo_stats = []
            for d, t in zip(demo_data, demo_types):
                if t == 'continuous':
                    # Cast explicitly: values may arrive as numeric strings
                    values = np.asarray(d, dtype=np.float64)
                    self.demo_stats.append(('continuous', (values.mean(), values.std())))
                elif t == 'categorical':
                    self.demo_stats.append(('categorical', sorted(set(d))))

        demo_tensors = []
        for d, (dtype, stats) in zip(demo_data, self.demo_stats):
            if dtype == 'continuous':
                mean, std = stats
                standardized = (np.asarray(d, dtype=np.float64) - mean) / (std + 1e-10)
                demo_tensors.append(
                    torch.tensor(standardized, dtype=torch.float32).reshape(-1, 1)
                )
            else:  # categorical
                column_of = {value: col for col, value in enumerate(stats)}
                one_hot = np.zeros((len(d), len(stats)), dtype=np.float32)
                for row, value in enumerate(d):
                    col = column_of.get(value)
                    if col is not None:  # unseen categories stay all-zero
                        one_hot[row, col] = 1.0
                demo_tensors.append(torch.from_numpy(one_hot))

        return torch.cat(demo_tensors, dim=1).to(self.device)

    def fit(self, X, demo_data, demo_types):
        """
        Train the VAE model.

        Args:
            X: Input data (FC matrices), shape (n_samples, n_features)
            demo_data: List of demographic variables
            demo_types: Types of demographic variables

        Returns:
            List of mean training losses, one per epoch.

        Raises:
            ValueError: if fewer than 2 samples or a batch size below 2 is used;
                BatchNorm1d cannot train on single-sample batches.
        """
        import torch
        import torch.nn.functional as F
        import numpy as np
        from torch.utils.data import DataLoader, TensorDataset

        print(f"Training VAE on {len(X)} samples for {self.nepochs} epochs...")

        if len(X) < 2 or self.batch_size < 2:
            raise ValueError(
                "BatchNorm layers need batches of at least 2 samples: "
                f"got {len(X)} samples with batch size {self.batch_size}"
            )

        # reconstruct()/generate() switch the networks to eval mode; restore
        # training mode so BatchNorm uses batch statistics and updates.
        self.encoder.train()
        self.decoder.train()

        demo_tensor = self._prepare_demographics(demo_data, demo_types)
        X_tensor = torch.as_tensor(np.asarray(X), dtype=torch.float32, device=self.device)

        dataset = TensorDataset(X_tensor, demo_tensor)
        dataloader = DataLoader(dataset, batch_size=self.batch_size, shuffle=True)

        self.train_losses = []
        for epoch in range(self.nepochs):
            epoch_losses = []
            for batch_x, batch_demo in dataloader:
                # A trailing batch of one sample (len(X) % batch_size == 1)
                # would make BatchNorm1d raise in training mode; skip it.
                # Shuffling means a different sample is skipped each epoch.
                if batch_x.size(0) < 2:
                    continue

                mu, logvar = self._encode(batch_x)
                z = self._reparameterize(mu, logvar)
                x_recon = self._decode(z, batch_demo)

                recon_loss = F.mse_loss(x_recon, batch_x)
                kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
                kl_loss = kl_loss / batch_x.size(0)  # Normalize by batch size
                loss = recon_loss + 0.1 * kl_loss

                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()

                epoch_losses.append(loss.item())

            avg_loss = np.mean(epoch_losses)
            self.train_losses.append(avg_loss)

            if (epoch + 1) % 10 == 0:
                print(f"Epoch {epoch+1}/{self.nepochs}, Loss: {avg_loss:.6f}")

        print("VAE training complete!")
        return self.train_losses

    def reconstruct(self, X, demo_data=None, demo_types=None):
        """
        Reconstruct input data using the latent means (no sampling).

        Args:
            X: Input data
            demo_data: Demographic data (required in practice)
            demo_types: Types of demographic variables (required in practice)

        Returns:
            Reconstructed data as a numpy array

        Raises:
            ValueError: if demographics are not provided.
        """
        import torch

        self.encoder.eval()
        self.decoder.eval()

        with torch.no_grad():
            mu, _ = self._encode(X)

            if demo_data is not None and demo_types is not None:
                demo_tensor = self._prepare_demographics(demo_data, demo_types)
            else:
                raise ValueError("Demo data and types must be provided for reconstruction")

            recon = self._decode(mu, demo_tensor)

        return recon.cpu().numpy()

    def generate(self, n_samples, demo_data, demo_types):
        """
        Generate new samples by decoding draws from a standard normal prior.

        Args:
            n_samples: Number of samples to generate
            demo_data: Demographic data; if its row count differs from
                ``n_samples``, the first row is repeated for every sample
            demo_types: Types of demographic variables

        Returns:
            Generated samples as a numpy array
        """
        import torch

        self.decoder.eval()

        with torch.no_grad():
            z = torch.randn(n_samples, self.latent_dim).to(self.device)
            demo_tensor = self._prepare_demographics(demo_data, demo_types)

            if demo_tensor.shape[0] != n_samples and demo_tensor.shape[0] >= 1:
                demo_tensor = demo_tensor[0].unsqueeze(0).repeat(n_samples, 1)

            generated = self._decode(z, demo_tensor)

        return generated.cpu().numpy()


def _compute_fc_matrices(nifti_files, demo_samples):
    """Convert each fMRI file to an FC vector; returns (fc_vectors, demographics rows) for the files that succeeded."""
    fc_matrices = []
    demo_rows = []
    for file_idx, (file_path, demo_sample) in enumerate(zip(nifti_files, demo_samples)):
        try:
            print(f"Processing file {file_idx+1}/{len(nifti_files)}: {file_path}")
            fc_triu = process_single_fmri(file_path, allow_synthetic=False)
            fc_matrices.append(fc_triu)
            demo_rows.append(demo_sample)
        except Exception as e:
            print(f"Error processing file {file_path}: {e}")
    return fc_matrices, demo_rows


def _standardize_features(X):
    """Z-score each column of X; constant columns are centred but not scaled (avoids 0/0 NaNs)."""
    mean = np.mean(X, axis=0)
    std = np.std(X, axis=0)
    std = np.where(std == 0, 1.0, std)
    return (X - mean) / std


def _save_visualizations(X, reconstructed, generated, n_show=3):
    """Plot and save original/reconstructed/generated FC matrices (PNG + .npy) for the first n_show samples."""
    visualizer = FCVisualizer()
    gen_matrix = visualizer._triu_to_matrix(generated[0])

    for i in range(min(n_show, len(X))):
        # Convert upper triangular vectors to full matrices for visualization
        original_matrix = visualizer._triu_to_matrix(X[i])
        recon_matrix = visualizer._triu_to_matrix(reconstructed[i])

        if i == 0:
            # Visualize all three - original, reconstructed, generated
            fig = visualizer.plot_matrix_comparison(
                [original_matrix, recon_matrix, gen_matrix],
                titles=["Original FC", "Reconstructed FC", "Generated FC"]
            )
            output_file = "fc_comparison_with_generated.png"
            fig.savefig(output_file, dpi=300, bbox_inches='tight')
            plt.close(fig)  # avoid accumulating open figures
            print(f"Saved full comparison to {output_file}")

        fig = visualizer.plot_matrix_comparison(
            [original_matrix, recon_matrix],
            titles=[f"Original FC (Sample {i+1})", f"Reconstructed FC (Sample {i+1})"]
        )
        output_file = f"sample_{i}_original_vs_reconstructed.png"
        fig.savefig(output_file, dpi=300, bbox_inches='tight')
        plt.close(fig)
        print(f"Saved comparison to {output_file}")

        np.save(f"sample_{i}_original_fc.npy", original_matrix)
        np.save(f"sample_{i}_reconstructed_fc.npy", recon_matrix)

    np.save("generated_fc.npy", gen_matrix)


def generate_comparison():
    """Download, process and visualize FC matrices from the HuggingFace dataset,
    comparing original to VAE-generated matrices."""
    print("Loading dataset from HuggingFace...")

    dataset_name = DATASET_CONFIG.get('name', 'SreekarB/OSFData1')
    dataset_split = DATASET_CONFIG.get('split', 'train')
    dataset = load_dataset(dataset_name, split=dataset_split)
    print(f"Dataset loaded: {dataset}")

    # Downloaded NIfTI files are only needed until FC matrices are computed;
    # the directory is removed when the block exits.
    with tempfile.TemporaryDirectory(prefix="hf_nifti_") as temp_dir:
        print(f"Created temp directory for NIfTI files: {temp_dir}")

        nifti_files, demo_samples, nifti_keys = download_sample_fmri(
            dataset, temp_dir, max_samples=5
        )
        if not nifti_files:
            print("No valid fMRI files were found")
            return

        fc_matrices, demo_rows = _compute_fc_matrices(nifti_files, demo_samples)

    if not fc_matrices:
        print("No valid FC matrices were generated")
        return
    if len(fc_matrices) < 2:
        print("At least 2 FC matrices are needed to train the VAE")
        return

    X = _standardize_features(np.array(fc_matrices))

    # Transpose to [feature_type][sample]. zip keeps each column's own types;
    # np.array() on the mixed rows would coerce every value to a string.
    demo_data = [list(column) for column in zip(*demo_rows)]
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    print("Training VAE on the FC matrices...")
    vae = VariationalAutoencoder(
        n_features=X.shape[1],
        latent_dim=MODEL_CONFIG.get('latent_dim', 32),
        demo_data=demo_data,
        demo_types=demo_types,
        nepochs=100,  # Reduced for demo
        bsize=2,
        lr=1e-3
    )
    vae.fit(X, demo_data, demo_types)

    print("Generating reconstructed FC matrices...")
    reconstructed = vae.reconstruct(X, demo_data, demo_types)

    print("Generating a synthetic FC matrix...")
    # For generating a new sample, we'll use demographics from first patient
    first_demo_data = [[d[0]] for d in demo_data]
    generated = vae.generate(1, first_demo_data, demo_types)

    _save_visualizations(X, reconstructed, generated)

    print("Processing complete")


if __name__ == "__main__":
    generate_comparison()