# AphasiaPred / test_vae.py
# Uploaded by SreekarB ("Upload 10 files", commit 763369a, verified).
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from sklearn.base import BaseEstimator
import json
class SimplifiedVAE(nn.Module):
    """Minimal encoder/decoder pair whose decoder is conditioned on demographics.

    Despite the name, this is a deterministic autoencoder: ``encode`` returns a
    single latent vector. There is no mean/log-variance head and no
    reparameterization sampling. ``decode`` concatenates the latent code with a
    demographic vector before reconstructing the input.

    Args:
        input_dim: Number of input features per sample.
        latent_dim: Size of the latent code produced by ``encode``.
        demo_dim: Size of the demographic vector passed to ``decode``.
        hidden_dim: Width of the single hidden layer in both the encoder and
            the decoder. Defaults to 128, the original fixed width, so
            state-dict keys and shapes are unchanged at the default.
    """

    def __init__(self, input_dim=100, latent_dim=16, demo_dim=4, hidden_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.hidden_dim = hidden_dim
        # Encoder: input -> hidden -> latent. No dtype is passed, so the
        # layers use PyTorch's current default dtype.
        self.enc1 = nn.Linear(input_dim, hidden_dim)
        self.enc2 = nn.Linear(hidden_dim, latent_dim)
        # Decoder: (latent code ++ demographics) -> hidden -> reconstruction.
        self.dec1 = nn.Linear(latent_dim + demo_dim, hidden_dim)
        self.dec2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map ``x`` to a latent code through one ReLU hidden layer.

        Args:
            x: Tensor whose last dimension is ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``latent_dim``.
        """
        hidden = F.relu(self.enc1(x))
        return self.enc2(hidden)

    def decode(self, z, demo):
        """Reconstruct the input from a latent code plus demographic features.

        Args:
            z: 2-D latent codes, expected as (batch, latent_dim).
            demo: 2-D demographic features, expected as (batch, demo_dim).
                They are concatenated with ``z`` along dim 1, so the batch
                sizes must match.

        Returns:
            Tensor of shape (batch, input_dim).
        """
        z_combined = torch.cat([z, demo], dim=1)
        hidden = F.relu(self.dec1(z_combined))
        return self.dec2(hidden)
# Model dimensions for this save/load smoke test (no data is generated here).
input_dim = 100
demo_dim = 4
latent_dim = 16

# Build a freshly (randomly) initialized model.
print("Creating model...")
model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
print("Model created successfully.")

# Persist only the parameters (the state dict), not the pickled module object.
os.makedirs("models", exist_ok=True)
print("Saving model...")
torch.save(model.state_dict(), "models/simple_vae.pt")
print("Model saved.")

# Reload into a second, independently initialized instance. load_state_dict
# is strict by default, so missing/unexpected keys or shape mismatches raise.
print("Loading model...")
new_model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
new_model.load_state_dict(torch.load("models/simple_vae.pt"))
print("Model loaded successfully.")

# Strict loading guards key names and shapes but not values. new_model started
# from different random weights, so compare every tensor to prove the round
# trip actually restored the saved parameters. Use raise (not assert) so the
# check still runs under `python -O`.
saved_state = model.state_dict()
loaded_state = new_model.state_dict()
mismatched = [
    name
    for name, tensor in saved_state.items()
    if not torch.equal(tensor, loaded_state[name])
]
if mismatched:
    raise AssertionError(
        f"Parameters changed after save/load round trip: {mismatched}"
    )
print("All tests passed!")