"""Smoke test: train DemoVAE on a tiny synthetic dataset and save artifacts.

Builds a 5-sample synthetic dataset (random feature vectors plus four
demographic variables), trains ``DemoVAE`` with ``MODEL_CONFIG``, then saves
the model, a learning-curve plot, and an original / reconstructed / generated
matrix comparison plot. Intended as a quick end-to-end check that the
pipeline runs with a very small sample size.
"""
import os

# Point the Hugging Face cache at a directory next to this script to avoid
# permission issues with the default location. This runs before the remaining
# imports so the variables are already set if any imported module reads them
# at import time. TRANSFORMERS_CACHE is deprecated in recent `transformers`
# releases; HF_HUB_CACHE is the `huggingface_hub` equivalent, so set both.
_HF_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "hf_cache")
os.environ["TRANSFORMERS_CACHE"] = _HF_CACHE_DIR
os.environ["HF_HUB_CACHE"] = _HF_CACHE_DIR
os.makedirs(_HF_CACHE_DIR, exist_ok=True)

# Imports intentionally follow the cache setup above.
import math

import numpy as np
import torch
import matplotlib.pyplot as plt

from vae_model import DemoVAE
from visualization import plot_learning_curves, plot_fc_matrices
from config import MODEL_CONFIG

# Test-run parameters.
SEED = 0
# Length of each synthetic feature vector. Must be a perfect square so a
# sample can be reshaped into a square matrix for display.
INPUT_DIM = 100
N_SAMPLES = 5

# Output locations (relative to the current working directory).
MODEL_PATH = "models/vae_model_small.pt"
LEARNING_CURVE_PATH = "results/learning_curves_small.png"
FC_FIGURE_PATH = "results/fc_visualization_small.png"


def _make_synthetic_data(n_samples, input_dim, rng):
    """Return ``(X, demo_data, demo_types)`` filled with random values.

    Args:
        n_samples: Number of synthetic subjects.
        input_dim: Number of features per subject.
        rng: ``numpy.random.Generator`` used for all draws.

    Returns:
        X: ``(n_samples, input_dim)`` array of standard-normal values, standing
            in for vectorized FC (upper-triangle) features.
        demo_data: List of four length-``n_samples`` arrays
            (age, sex, months post stroke, WAB score).
        demo_types: ``'continuous'`` / ``'categorical'`` label for each
            entry of ``demo_data``, in the same order.
    """
    X = rng.standard_normal((n_samples, input_dim))
    demo_data = [
        rng.normal(60, 10, n_samples),  # age
        rng.choice(["M", "F"], n_samples),  # sex
        # Months post stroke cannot be negative.
        np.clip(rng.normal(24, 12, n_samples), 0, None),
        # WAB score, kept within the 0-100 scale.
        np.clip(rng.normal(50, 15, n_samples), 0, 100),
    ]
    demo_types = ["continuous", "categorical", "continuous", "continuous"]
    return X, demo_data, demo_types


def _as_square(vec):
    """Reshape a flat length-``k*k`` vector into a ``(k, k)`` matrix for plotting.

    Raises:
        ValueError: if the vector length is not a perfect square.
    """
    length = len(vec)
    side = math.isqrt(length)
    if side * side != length:
        raise ValueError(
            f"Cannot reshape a vector of length {length} into a square matrix"
        )
    return vec.reshape(side, side)


def main():
    """Run the small-sample train / save / visualize smoke test."""
    # Fail fast, before training, if samples cannot be shown as square matrices.
    if math.isqrt(INPUT_DIM) ** 2 != INPUT_DIM:
        raise ValueError(f"INPUT_DIM={INPUT_DIM} must be a perfect square")

    # Seed both RNGs so runs are reproducible. numpy drives the synthetic data;
    # torch.manual_seed covers model init/training, assuming DemoVAE draws from
    # torch's global RNG (presumably, given the .pt checkpoint) - confirm.
    rng = np.random.default_rng(SEED)
    torch.manual_seed(SEED)

    print(f"Creating test dataset with {N_SAMPLES} samples...")
    X, demo_data, demo_types = _make_synthetic_data(N_SAMPLES, INPUT_DIM, rng)

    print("Config settings:")
    print(f"- Epochs: {MODEL_CONFIG['nepochs']}")
    print(f"- Batch size: {MODEL_CONFIG['bsize']}")
    print(f"- Latent dim: {MODEL_CONFIG['latent_dim']}")

    print("Initializing model...")
    vae = DemoVAE(**MODEL_CONFIG)

    print(f"Training model with {N_SAMPLES} samples...")
    train_losses, val_losses = vae.fit(X, demo_data, demo_types)
    print("Training complete!")
    print(f"Final train loss: {train_losses[-1]:.4f}")
    # With so few samples a validation history may be empty; don't crash on it.
    if len(val_losses) > 0:
        print(f"Final validation loss: {val_losses[-1]:.4f}")
    else:
        print("No validation losses were recorded.")

    os.makedirs("models", exist_ok=True)
    os.makedirs("results", exist_ok=True)

    print("Saving model...")
    vae.save(MODEL_PATH)

    print("Generating learning curve visualization...")
    learning_fig = plot_learning_curves(train_losses, val_losses)
    learning_fig.savefig(LEARNING_CURVE_PATH)
    plt.close(learning_fig)  # only saved, never shown; release its memory
    print(f"Learning curve saved to {LEARNING_CURVE_PATH}")

    print("Generating reconstructions...")
    reconstructed = vae.transform(X, demo_data, demo_types)

    # Compare the first subject's input, its reconstruction, and a generated
    # sample. Passing the integer 1 in place of data presumably asks DemoVAE
    # to generate one new sample conditioned on the given demographics
    # (here, the first subject's) - confirm against vae_model.
    original = _as_square(X[0])
    recon = _as_square(reconstructed[0])
    first_subject_demo = [d[:1] for d in demo_data]
    generated = _as_square(vae.transform(1, first_subject_demo, demo_types)[0])

    print("Generating FC matrix visualization...")
    fc_fig = plot_fc_matrices(original, recon, generated)
    fc_fig.savefig(FC_FIGURE_PATH)
    plt.close(fc_fig)
    print(f"FC visualization saved to {FC_FIGURE_PATH}")

    print("Test with small sample size completed successfully!")


if __name__ == "__main__":
    main()