# Header residue from the hosting Spaces page (status shown: Sleeping).
"""Smoke test for DemoVAE: fit on synthetic data, save, reload, and transform.

Runs the train -> save -> load -> transform round trip on small random inputs
so that API breakage surfaces quickly. Nothing is compared against expected
values: the script only checks that each step runs without raising, and it
prints the resulting loss values and shapes.
"""
import os

import numpy as np
import torch
from vae_model import DemoVAE

SEED = 0
MODEL_PATH = 'models/vae_model.pt'

# Size of the synthetic problem.
INPUT_DIM = 100
N_SAMPLES = 20


def make_synthetic_data(n_samples=N_SAMPLES, input_dim=INPUT_DIM):
    """Return ``(X, demo_data, demo_types)`` random stand-in inputs for DemoVAE.

    Args:
        n_samples: Number of synthetic subjects.
        input_dim: Number of features per subject.

    Returns:
        X: ``(n_samples, input_dim)`` standard-normal array, standing in for
            upper-triangular functional-connectivity (FC) values.
        demo_data: List of four length-``n_samples`` arrays, one per
            demographic, in the same order as ``demo_types``.
        demo_types: Type tag for each entry of ``demo_data``.

    Draws from NumPy's global RNG, so seed it beforehand for reproducibility.
    """
    X = np.random.randn(n_samples, input_dim)
    demo_data = [
        np.random.normal(60, 10, n_samples),  # age
        np.random.choice([0, 1], n_samples),  # sex
        np.random.normal(24, 12, n_samples),  # months post stroke
        np.random.normal(50, 15, n_samples),  # WAB score
    ]
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
    return X, demo_data, demo_types


def main():
    """Run the DemoVAE train -> save -> load -> transform smoke test."""
    # Seed both libraries so the synthetic data is identical across runs.
    # Training is reproducible too, assuming DemoVAE draws from these global
    # RNGs (not visible from here).
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    X, demo_data, demo_types = make_synthetic_data()

    model_config = {
        'latent_dim': 16,
        'nepochs': 5,
        'bsize': 5,
        'use_cuda': False,
    }
    print("Initializing model...")
    vae = DemoVAE(**model_config)

    print("Training model...")
    train_losses, val_losses = vae.fit(X, demo_data, demo_types)
    print(f"Training complete! Train loss: {train_losses[-1]}, Val loss: {val_losses[-1]}")
    # These are len() of the returned loss histories, not array shapes.
    print(f"Train losses length: {len(train_losses)}")
    print(f"Val losses length: {len(val_losses)}")

    print("Saving model...")
    # The target directory may not exist yet (e.g. in a fresh checkout).
    os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)
    vae.save(MODEL_PATH)

    print("Loading model...")
    # Rebuild with the same config the saved model was trained with, so the
    # architecture matches even if DemoVAE.load only restores weights.
    # NOTE(review): confirm against vae_model whether load also restores config.
    vae2 = DemoVAE(**model_config)
    vae2.load(MODEL_PATH)

    print("Testing reconstruction...")
    reconstructed = vae2.transform(X, demo_data, demo_types)
    print(f"Reconstructed shape: {reconstructed.shape}")
    # Reaching this line only means no step raised; no outputs are validated.
    print("All tests passed!")


if __name__ == "__main__":
    main()