import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from sklearn.base import BaseEstimator
import json


class SimplifiedVAE(nn.Module):
    """Small encoder/decoder network whose decoder is conditioned on extra features.

    Despite the name, the model as written is not variational: ``encode``
    returns a single deterministic latent vector (no mean / log-variance pair
    and no reparameterization sampling). ``decode`` concatenates that latent
    vector with a conditioning vector ``demo`` before mapping back to input
    space.

    Args:
        input_dim: Width of the input features and of the decoder output.
        latent_dim: Width of the latent code produced by ``encode``.
        demo_dim: Width of the conditioning vector passed to ``decode``.
    """

    def __init__(self, input_dim=100, latent_dim=16, demo_dim=4):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim

        # Encoder: input_dim -> 128 -> latent_dim. No dtype is passed, so the
        # layers use PyTorch's default parameter dtype.
        self.enc1 = nn.Linear(input_dim, 128)
        self.enc2 = nn.Linear(128, latent_dim)

        # Decoder: (latent_dim + demo_dim) -> 128 -> input_dim.
        self.dec1 = nn.Linear(latent_dim + demo_dim, 128)
        self.dec2 = nn.Linear(128, input_dim)

    def encode(self, x):
        """Map ``x`` (last dimension ``input_dim``) to a latent code.

        Returns:
            Tensor whose last dimension is ``latent_dim``.
        """
        h = F.relu(self.enc1(x))
        return self.enc2(h)

    def decode(self, z, demo):
        """Reconstruct input-space features from latent ``z`` and conditioning ``demo``.

        ``z`` and ``demo`` are concatenated along dim 1 and fed to a layer with
        ``latent_dim + demo_dim`` input features, so 2-D ``(batch, features)``
        tensors with matching batch size and dtype are assumed.

        Returns:
            Tensor whose last dimension is ``input_dim``.
        """
        z_combined = torch.cat([z, demo], dim=1)
        h = F.relu(self.dec1(z_combined))
        return self.dec2(h)


# ---------------------------------------------------------------------------
# Smoke test: build a model, save its state dict, reload it into a fresh
# instance, and verify that the reloaded parameters match the saved ones.
# ---------------------------------------------------------------------------

# Model dimensions. No data is generated; only the architecture is exercised.
input_dim = 100
demo_dim = 4
latent_dim = 16

# Single definition of the checkpoint location, shared by save and load.
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "simple_vae.pt")

print("Creating model...")
model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
print("Model created successfully.")

os.makedirs(MODEL_DIR, exist_ok=True)
print("Saving model...")
torch.save(model.state_dict(), MODEL_PATH)
print("Model saved.")

# Load into a fresh instance. load_state_dict is strict by default, so
# missing/unexpected keys or shape mismatches raise here.
print("Loading model...")
new_model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
new_model.load_state_dict(torch.load(MODEL_PATH))
print("Model loaded successfully.")

# The fresh instance starts from its own random initialization, so identical
# tensors here show that the checkpoint values were actually restored. Raise
# instead of assert so the check is not stripped under `python -O`.
saved_state = model.state_dict()
loaded_state = new_model.state_dict()
mismatched = [
    name
    for name, tensor in saved_state.items()
    if not torch.equal(tensor, loaded_state[name])
]
if mismatched:
    raise RuntimeError(f"Reloaded parameters differ from saved ones: {mismatched}")

print("All tests passed!")