"""Smoke test for DemoVAE: fit, get_latents/encode, save/load, and plotting.

Runs end-to-end on a tiny synthetic dataset (5 samples, 3 epochs) so it
finishes quickly. Artifacts are written relative to the current working
directory: ``models/test_vae.pt`` and ``test_learning.png``.
"""
import os

# Point the Hugging Face cache at a directory next to this script to avoid
# permission issues with the default cache location. This is done before the
# project modules below are imported, in case they (transitively) import
# transformers, which reads the variable at import time.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME -- confirm which variable the installed version honours.
os.environ['TRANSFORMERS_CACHE'] = os.path.join(
    os.path.dirname(os.path.abspath(__file__)), 'hf_cache'
)
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)

import numpy as np  # noqa: E402  (must follow the env-var setup above)
import torch  # noqa: E402
from vae_model import DemoVAE  # noqa: E402
import matplotlib.pyplot as plt  # noqa: E402
from visualization import plot_learning_curves  # noqa: E402

MODEL_DIR = 'models'
MODEL_PATH = os.path.join(MODEL_DIR, 'test_vae.pt')
PLOT_PATH = 'test_learning.png'
SEED = 0


def main():
    """Exercise the DemoVAE API end-to-end; raise AssertionError on failure."""
    # Seed both RNGs so the synthetic data and model initialisation are
    # reproducible between runs.
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    print("Creating synthetic test data...")
    # Small synthetic dataset with only 5 samples
    input_dim = 100
    n_samples = 5
    X = np.random.randn(n_samples, input_dim)
    demo_data = [
        np.random.normal(60, 10, n_samples),        # age
        np.random.choice(['M', 'F'], n_samples),    # sex
        np.random.normal(24, 12, n_samples),        # months post stroke
        np.random.normal(50, 15, n_samples),        # WAB score
    ]
    demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

    print("Testing DemoVAE initialization...")
    # nepochs=3 keeps the run fast
    vae = DemoVAE(latent_dim=16, nepochs=3, bsize=5)

    print("Testing DemoVAE fit method...")
    train_losses, val_losses = vae.fit(X, demo_data, demo_types)
    print(f"Train losses length: {len(train_losses)}")
    print(f"Val losses length: {len(val_losses)}")
    assert len(train_losses) > 0, "fit() returned no training losses"

    print("Testing get_latents method...")
    latents = vae.get_latents(X)
    print(f"Latents shape: {latents.shape}")
    # Presumably one latent row per input sample -- verify against DemoVAE.
    assert latents.shape[0] == n_samples, "get_latents() lost samples"

    print("Testing encode method...")
    latents2 = vae.encode(X)
    print(f"Latents from encode shape: {latents2.shape}")
    assert latents2.shape[0] == n_samples, "encode() lost samples"

    print("Testing model save...")
    os.makedirs(MODEL_DIR, exist_ok=True)
    vae.save(MODEL_PATH)

    print("Testing model load...")
    vae2 = DemoVAE()
    vae2.load(MODEL_PATH)

    print("Testing learning curve plotting...")
    fig = plot_learning_curves(vae2.train_losses, vae2.val_losses)
    # Save the figure the helper returned, not whatever pyplot considers
    # "current"; fall back to the current figure if it returned nothing usable.
    target = fig if hasattr(fig, 'savefig') else plt.gcf()
    target.savefig(PLOT_PATH)
    plt.close(target)  # release the figure's memory
    print("Learning curve saved to test_learning.png")

    print("All tests passed!")


if __name__ == "__main__":
    main()