"""End-to-end smoke test for DemoVAE.

Trains a tiny model on synthetic data, saves it, reloads it into a fresh
instance, and runs ``transform`` on the training data. Prints progress and
loss values along the way.
"""
from pathlib import Path

import numpy as np
import torch

from vae_model import DemoVAE

# Fixed seeds so repeated runs of this smoke test are reproducible.
SEED = 0
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)

# Create synthetic data
input_dim = 100
n_samples = 20
demo_dim = 4

# Synthetic FC matrices (upper-triangular values), one row per sample.
X = rng.standard_normal((n_samples, input_dim))

# Synthetic demographics: one length-n_samples array per variable.
demo_data = [
    rng.normal(60, 10, n_samples),  # age
    rng.choice([0, 1], n_samples),  # sex
    # Months post stroke; a raw normal draw can go negative, so clip at 0.
    np.clip(rng.normal(24, 12, n_samples), 0, None),
    # WAB score; presumably the 0-100 Aphasia Quotient scale -- confirm.
    np.clip(rng.normal(50, 15, n_samples), 0, 100),
]

# Types of demographics, in the same order as demo_data.
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

# Catch edits that add or remove a demographic in only one of the lists.
assert len(demo_data) == len(demo_types) == demo_dim

# Initialize model
model_config = {
    'latent_dim': 16,
    'nepochs': 5,
    'bsize': 5,
    'use_cuda': False,
}

print("Initializing model...")
vae = DemoVAE(**model_config)

# Train model
print("Training model...")
train_losses, val_losses = vae.fit(X, demo_data, demo_types)
print(f"Training complete! Train loss: {train_losses[-1]}, Val loss: {val_losses[-1]}")

# Number of recorded loss values (these are sequences, not arrays with a shape).
print(f"Train losses length: {len(train_losses)}")
print(f"Val losses length: {len(val_losses)}")

# Save model. Create the target directory first: checkpoint writers such as
# torch.save do not create missing parent directories.
model_path = Path('models') / 'vae_model.pt'
model_path.parent.mkdir(parents=True, exist_ok=True)
print("Saving model...")
vae.save(str(model_path))

# Try loading the model. Build it with the same config as the saved model so
# the architecture matches, whether or not load() restores hyperparameters.
print("Loading model...")
vae2 = DemoVAE(**model_config)
vae2.load(str(model_path))

# Test reconstruction
print("Testing reconstruction...")
reconstructed = vae2.transform(X, demo_data, demo_types)
print(f"Reconstructed shape: {reconstructed.shape}")

print("All tests passed!")