File size: 2,682 Bytes
763369a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import os
# Redirect the Hugging Face (transformers) cache into an ``hf_cache`` folder that
# sits next to this script instead of the default user-level location, avoiding
# permission problems there. It runs before the third-party/project imports below,
# presumably so anything reading TRANSFORMERS_CACHE at import time sees it.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in favour
# of HF_HOME / HF_HUB_CACHE — consider setting those too.
_script_dir = os.path.dirname(os.path.abspath(__file__))
_hf_cache_dir = os.path.join(_script_dir, 'hf_cache')
os.environ['TRANSFORMERS_CACHE'] = _hf_cache_dir
os.makedirs(_hf_cache_dir, exist_ok=True)

import numpy as np
import torch
import matplotlib.pyplot as plt
from vae_model import DemoVAE
from visualization import plot_learning_curves, plot_fc_matrices
from config import MODEL_CONFIG

# Seed the global NumPy and torch RNGs so a failing smoke run can be reproduced
# exactly. This fixes the synthetic data below and — presumably, if DemoVAE draws
# from these global generators — its weight init and batching too.
SEED = 0
np.random.seed(SEED)
torch.manual_seed(SEED)

# Create small synthetic dataset with only 5 samples
input_dim = 100  # each sample is later reshaped to a square matrix for display
n_samples = 5
demo_dim = 4  # number of demographic variables in demo_data; not used below

print(f"Creating test dataset with {n_samples} samples...")

# Random stand-ins for vectorized FC values (e.g. an upper-triangular vector);
# plain Gaussian noise with no real connectivity structure.
X = np.random.randn(n_samples, input_dim)

# Synthetic demographics: one array of length n_samples per variable, in the same
# order as demo_types. The clinical measures are clipped to meaningful ranges:
# time since stroke cannot be negative, and WAB scores lie on a 0-100 scale.
demo_data = [
    np.random.normal(60, 10, n_samples),  # age
    np.random.choice(['M', 'F'], n_samples),  # sex
    np.clip(np.random.normal(24, 12, n_samples), 0, None),  # months post stroke
    np.clip(np.random.normal(50, 15, n_samples), 0, 100),  # WAB score
]

# Types of demographics, aligned index-by-index with demo_data
demo_types = ['continuous', 'categorical', 'continuous', 'continuous']

# Echo the key hyper-parameters, then build the model from MODEL_CONFIG and train
# it on the synthetic data. Each (label, config key) pair prints one settings line.
_config_lines = (
    ("Epochs", "nepochs"),
    ("Batch size", "bsize"),
    ("Latent dim", "latent_dim"),
)
print("Config settings:")
for _label, _key in _config_lines:
    print(f"- {_label}: {MODEL_CONFIG[_key]}")

print("Initializing model...")
vae = DemoVAE(**MODEL_CONFIG)

# fit() returns per-epoch training and validation loss histories.
print(f"Training model with {n_samples} samples...")
train_losses, val_losses = vae.fit(X, demo_data, demo_types)

final_train_loss = train_losses[-1]
print(f"Training complete! Final train loss: {final_train_loss:.4f}")
final_val_loss = val_losses[-1]
print(f"Final validation loss: {final_val_loss:.4f}")

# Save the trained model and plot its loss history. Output paths are relative to
# the current working directory.
os.makedirs("models", exist_ok=True)
os.makedirs("results", exist_ok=True)
print("Saving model...")
vae.save('models/vae_model_small.pt')

# Create learning curve visualization
print("Generating learning curve visualization...")
learning_fig = plot_learning_curves(train_losses, val_losses)
learning_fig.savefig('results/learning_curves_small.png')
# Release the figure once it is on disk. If plot_learning_curves creates it via
# pyplot (not visible here), pyplot keeps it alive until closed; plt.close is a
# harmless no-op for figures pyplot does not track.
plt.close(learning_fig)
print("Learning curve saved to results/learning_curves_small.png")

# Generate reconstructed data for every subject via DemoVAE.transform.
print("Generating reconstructions...")
reconstructed = vae.transform(X, demo_data, demo_types)

# Each feature vector is displayed as a square image, so the feature count must be
# a perfect square (input_dim = 100 -> 10 x 10). Derive the side length from the
# data rather than hard-coding it, and fail with a clear message otherwise.
n_features = X.shape[1]
side = int(round(np.sqrt(n_features)))
if side * side != n_features:
    raise ValueError(
        f"Cannot display {n_features} features as a square matrix; "
        "choose input_dim to be a perfect square."
    )

# First subject only: original input, its reconstruction, and one generated
# sample. The reshape is purely a display convenience.
original = X[0].reshape(side, side)
recon = reconstructed[0].reshape(side, side)
first_subject_demo = [d[:1] for d in demo_data]
# Passing an int instead of data presumably asks DemoVAE to sample that many new
# subjects for the given demographics — TODO confirm against DemoVAE.transform.
generated = vae.transform(1, first_subject_demo, demo_types)[0].reshape(side, side)

# Create FC visualization
print("Generating FC matrix visualization...")
fc_fig = plot_fc_matrices(original, recon, generated)
fc_fig.savefig('results/fc_visualization_small.png')
# Release the figure once saved (no-op if pyplot is not tracking it).
plt.close(fc_fig)
print("FC visualization saved to results/fc_visualization_small.png")

print("Test with small sample size completed successfully!")