# NOTE: page header captured together with this file (not part of the script):
# Spaces status: Sleeping; file size: 2,682 bytes; commit 763369a.
import os
# Point the Hugging Face caches at a directory next to this script so that
# downloads do not depend on the default cache location being writable.
# This must run before any Hugging Face library is imported, since the cache
# location is taken from the environment.
_hf_cache_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'hf_cache')
# TRANSFORMERS_CACHE is the variable older transformers releases read; newer
# releases deprecate it in favour of HF_HOME, so set both.  setdefault() keeps
# an HF_HOME that the deployment has already configured.
os.environ['TRANSFORMERS_CACHE'] = _hf_cache_dir
os.environ.setdefault('HF_HOME', _hf_cache_dir)
os.makedirs(_hf_cache_dir, exist_ok=True)
import numpy as np
import torch
import matplotlib.pyplot as plt
from vae_model import DemoVAE
from visualization import plot_learning_curves, plot_fc_matrices
from config import MODEL_CONFIG
# --- Synthetic test data ---------------------------------------------------
# A deliberately tiny dataset: n_samples rows of input_dim random features.
n_samples = 5
input_dim = 100
demo_dim = 4  # count of demographic variables; not referenced again in the visible script
print(f"Creating test dataset with {n_samples} samples...")

# Stand-in for vectorised FC matrices: standard-normal noise.
X = np.random.randn(n_samples, input_dim)

# One array per demographic variable, each of length n_samples.
age = np.random.normal(60, 10, n_samples)
sex = np.random.choice(['M', 'F'], n_samples)
months_post_stroke = np.random.normal(24, 12, n_samples)
wab_score = np.random.normal(50, 15, n_samples)
demo_data = [age, sex, months_post_stroke, wab_score]

# Kind of each variable in demo_data, position by position.
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']
# Echo the hyper-parameters read from MODEL_CONFIG, one at a time, so a missing
# key fails right after the lines that were already printed.
print("Config settings:")
cfg_epochs = MODEL_CONFIG['nepochs']
print(f"- Epochs: {cfg_epochs}")
cfg_batch_size = MODEL_CONFIG['bsize']
print(f"- Batch size: {cfg_batch_size}")
cfg_latent_dim = MODEL_CONFIG['latent_dim']
print(f"- Latent dim: {cfg_latent_dim}")

# Build the VAE, passing every MODEL_CONFIG entry as a keyword argument.
print("Initializing model...")
vae = DemoVAE(**MODEL_CONFIG)

# Fit on the synthetic data; fit() returns two sequences that are unpacked as
# the training and validation loss histories.
print(f"Training model with {n_samples} samples...")
loss_histories = vae.fit(X, demo_data, demo_types)
train_losses, val_losses = loss_histories

# Report the last recorded value of each history.
final_train_loss = train_losses[-1]
print(f"Training complete! Final train loss: {final_train_loss:.4f}")
final_val_loss = val_losses[-1]
print(f"Final validation loss: {final_val_loss:.4f}")
# Make sure both output directories exist before anything is written.
for out_dir in ("models", "results"):
    os.makedirs(out_dir, exist_ok=True)

# Persist the trained model.
print("Saving model...")
vae.save('models/vae_model_small.pt')

# Plot the two loss histories and write the returned figure to disk.
print("Generating learning curve visualization...")
learning_fig = plot_learning_curves(train_losses, val_losses)
learning_fig.savefig('results/learning_curves_small.png')
print("Learning curve saved to results/learning_curves_small.png")
# Run the training inputs back through the trained model.
print("Generating reconstructions...")
reconstructed = vae.transform(X, demo_data, demo_types)

# The figure shows each feature vector as a square image.  Derive the side
# length from the data rather than hard-coding it, so the three reshape calls
# below stay in step with the feature count.
n_features = X.shape[1]
fc_side = round(n_features ** 0.5)
if fc_side * fc_side != n_features:
    raise ValueError(
        f"Cannot display {n_features} features as a square matrix; "
        "the feature count must be a perfect square."
    )
fc_shape = (fc_side, fc_side)

# First sample only: the original input and its reconstruction.
original = X[0].reshape(fc_shape)
recon = reconstructed[0].reshape(fc_shape)
# Passing an int as the first argument presumably asks the model to generate
# that many new samples for the given demographics (here: those of the first
# subject) -- TODO confirm against DemoVAE.transform.
generated = vae.transform(1, [d[:1] for d in demo_data], demo_types)[0].reshape(fc_shape)

# Plot the three matrices and write the returned figure to disk.
print("Generating FC matrix visualization...")
fc_fig = plot_fc_matrices(original, recon, generated)
fc_fig.savefig('results/fc_visualization_small.png')
print("FC visualization saved to results/fc_visualization_small.png")
print("Test with small sample size completed successfully!")