# Spaces:
# Sleeping
# Sleeping
import os

# Redirect the Hugging Face transformers cache to an `hf_cache` folder that
# lives next to this script (avoids permission issues with the default cache
# location), and make sure the checkpoint output folder exists.
# NOTE(review): the remaining imports deliberately come after this, presumably
# so the cache setting is already in place if anything imports `transformers`.
_here = os.path.dirname(os.path.abspath(__file__))
os.environ['TRANSFORMERS_CACHE'] = os.path.join(_here, 'hf_cache')
for _folder in (os.environ['TRANSFORMERS_CACHE'], 'models'):
    # 'models' is resolved against the current working directory.
    os.makedirs(_folder, exist_ok=True)

import numpy as np
import torch
from vae_model import DemoVAE
import matplotlib.pyplot as plt
from visualization import plot_learning_curves
print("Creating synthetic test data...")
# Fix the random seeds so every run of this smoke test sees the same data and
# the same model initialisation; otherwise failures cannot be reproduced.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Small synthetic dataset: n_samples rows of input_dim standard-normal features.
input_dim = 100
n_samples = 5
X = np.random.randn(n_samples, input_dim)

# With only a handful of samples, drawing sex labels at random can produce a
# single category. Build an alternating M/F sequence and shuffle it so both
# levels are present whenever n_samples >= 2.
# NOTE(review): assumes DemoVAE needs more than one level to encode a
# categorical column — confirm in vae_model.
sex = np.random.permutation(np.resize(np.array(['M', 'F']), n_samples))

# One array of length n_samples per demographic variable; the order must match
# demo_types below.
demo_data = [
    np.random.normal(60, 10, n_samples),  # age
    sex,                                  # sex
    np.random.normal(24, 12, n_samples),  # months post stroke
    np.random.normal(50, 15, n_samples),  # WAB score
]
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
print("Testing DemoVAE initialization...")
# Small latent space and only 3 epochs so the smoke test finishes quickly;
# bsize=5 matches the 5 synthetic samples.
vae = DemoVAE(latent_dim=16, nepochs=3, bsize=5)

print("Testing DemoVAE fit method...")
# fit() returns the training and validation loss histories.
train_losses, val_losses = vae.fit(X, demo_data, demo_types)
# The histories are sequences, so report their length (not a "shape").
print(f"Train losses length: {len(train_losses)}")
print(f"Val losses length: {len(val_losses)}")
# Fail loudly here instead of reaching "All tests passed!" after a fit that
# recorded nothing.
if len(train_losses) == 0:
    raise RuntimeError("DemoVAE.fit returned an empty training-loss history")
# Latent extraction: run get_latents() and then encode() on the full input X
# and print the shape of each result. The two results are not compared.
print("Testing get_latents method...")
latents = vae.get_latents(X)
print("Latents shape: {}".format(latents.shape))

print("Testing encode method...")
latents2 = vae.encode(X)
print("Latents from encode shape: {}".format(latents2.shape))
# Persistence round-trip: write the fitted model under models/ and read it back
# into a freshly constructed DemoVAE.
_ckpt_path = 'models/test_vae.pt'

print("Testing model save...")
vae.save(_ckpt_path)

print("Testing model load...")
# NOTE(review): vae2 is constructed with DemoVAE()'s defaults, not
# latent_dim=16. This only works if load() restores the saved configuration
# as well as the weights — confirm in vae_model.
vae2 = DemoVAE()
vae2.load(_ckpt_path)
print("Testing learning curve plotting...")
# Plot the loss histories read from the reloaded model (vae2), presumably
# restored by load(), rather than the ones returned directly by fit().
fig = plot_learning_curves(vae2.train_losses, vae2.val_losses)
# Save the figure the helper returned rather than whatever pyplot currently
# considers "active". If the helper returns something else (None, a tuple,
# an Axes), fall back to the current figure, which is what was saved before.
target_fig = fig if hasattr(fig, "savefig") else plt.gcf()
target_fig.savefig('test_learning.png')
# Release the figure so pyplot's figure manager does not keep it alive.
plt.close(target_fig)
print("Learning curve saved to test_learning.png")
print("All tests passed!")