Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import numpy as np | |
| import os | |
| from sklearn.base import BaseEstimator | |
| import json | |
class SimplifiedVAE(nn.Module):
    """Small encoder/decoder network whose decoder is conditioned on ``demo``.

    Despite the name, this is a deterministic autoencoder: ``encode`` returns a
    single latent vector (no mean/log-variance pair and no reparameterized
    sampling). ``decode`` concatenates the latent code with a conditioning
    vector ``demo`` before reconstructing the input.

    Args:
        input_dim: Size of the last dimension of the input ``x`` (and of the
            reconstruction returned by ``decode``).
        latent_dim: Size of the latent code produced by ``encode``.
        demo_dim: Size of the conditioning vector passed to ``decode``.
        hidden_dim: Width of the single hidden layer in both the encoder and
            the decoder. Defaults to 128.
    """

    def __init__(self, input_dim=100, latent_dim=16, demo_dim=4,
                 hidden_dim=128):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim
        self.hidden_dim = hidden_dim
        # Encoder: input -> hidden -> latent. Layers use PyTorch's default
        # dtype. The attribute names enc1/enc2/dec1/dec2 become the
        # state_dict keys, so they must stay stable for saved checkpoints.
        self.enc1 = nn.Linear(input_dim, hidden_dim)
        self.enc2 = nn.Linear(hidden_dim, latent_dim)
        # Decoder: (latent code concatenated with demo) -> hidden -> input.
        self.dec1 = nn.Linear(latent_dim + demo_dim, hidden_dim)
        self.dec2 = nn.Linear(hidden_dim, input_dim)

    def encode(self, x):
        """Map ``x`` (last dim ``input_dim``) to a latent code (last dim ``latent_dim``)."""
        h = F.relu(self.enc1(x))
        return self.enc2(h)

    def decode(self, z, demo):
        """Reconstruct the input from latent code ``z`` conditioned on ``demo``.

        ``z`` and ``demo`` are concatenated along dim 1 and fed to a Linear
        layer acting on the last dim, so both are expected to be 2-D
        ``(batch, features)`` tensors with the same batch size.
        """
        z_combined = torch.cat([z, demo], dim=1)
        h = F.relu(self.dec1(z_combined))
        return self.dec2(h)

    def forward(self, x, demo):
        """Encode ``x`` and decode it conditioned on ``demo``; returns the reconstruction."""
        return self.decode(self.encode(x), demo)
# Model dimensions (this script generates no training data; it only checks
# that a state_dict survives a save/load round trip).
input_dim = 100
demo_dim = 4
latent_dim = 16
checkpoint_dir = "models"
checkpoint_path = os.path.join(checkpoint_dir, "simple_vae.pt")

# Build a freshly initialised model.
print("Creating model...")
model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
print("Model created successfully.")

# Persist only the parameters (state_dict), not the pickled module object.
os.makedirs(checkpoint_dir, exist_ok=True)
print("Saving model...")
torch.save(model.state_dict(), checkpoint_path)
print("Model saved.")

# Rebuild the same architecture and restore the saved parameters into it.
print("Loading model...")
new_model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
# weights_only=True limits unpickling to tensors and plain containers, which
# is all a state_dict needs, and avoids executing arbitrary pickled code.
new_model.load_state_dict(torch.load(checkpoint_path, weights_only=True))
print("Model loaded successfully.")

# Verify the round trip rather than only checking that loading didn't raise:
# every parameter must match the original exactly. load_state_dict(strict=True)
# already guarantees both dicts have the same keys.
original_state = model.state_dict()
loaded_state = new_model.state_dict()
mismatched = [
    name for name, tensor in original_state.items()
    if not torch.equal(tensor, loaded_state[name])
]
if mismatched:
    raise RuntimeError(f"Parameters differ after reload: {mismatched}")
print("All tests passed!")