File size: 1,544 Bytes
763369a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
from sklearn.base import BaseEstimator
import json

class SimplifiedVAE(nn.Module):
    """Small encoder/decoder network whose decoder is conditioned on extra features.

    Despite the name, the encoder is deterministic: it maps the input straight
    to a single latent vector (there is no mean/log-variance pair and no
    reparameterised sampling), so it behaves like a conditional autoencoder.

    Args:
        input_dim: Size of the last dimension of the input ``x`` and of the
            reconstruction returned by :meth:`decode`.
        latent_dim: Size of the latent code produced by :meth:`encode`.
        demo_dim: Size of the conditioning vector ``demo`` that is concatenated
            to the latent code before decoding.
    """

    def __init__(self, input_dim=100, latent_dim=16, demo_dim=4):
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.demo_dim = demo_dim

        # Encoder: input_dim -> 128 -> latent_dim. No dtype is passed, so the
        # layers use PyTorch's default dtype. The attribute names become the
        # state_dict keys ("enc1.weight", ...); renaming them would break
        # loading of previously saved checkpoints.
        self.enc1 = nn.Linear(input_dim, 128)
        self.enc2 = nn.Linear(128, latent_dim)

        # Decoder: (latent_dim + demo_dim) -> 128 -> input_dim.
        self.dec1 = nn.Linear(latent_dim + demo_dim, 128)
        self.dec2 = nn.Linear(128, input_dim)

    def encode(self, x):
        """Map ``x`` (last dim ``input_dim``) to a latent code (last dim ``latent_dim``)."""
        h = F.relu(self.enc1(x))
        return self.enc2(h)

    def decode(self, z, demo):
        """Reconstruct the input from latent code ``z`` and conditioning ``demo``.

        ``z`` and ``demo`` are concatenated along dim 1, so both are expected to
        be 2-D ``(batch, features)`` tensors with the same batch size.
        Returns a tensor whose last dimension is ``input_dim``.
        """
        z_combined = torch.cat([z, demo], dim=1)
        h = F.relu(self.dec1(z_combined))
        return self.dec2(h)

    def forward(self, x, demo):
        """Encode ``x`` and decode it conditioned on ``demo``.

        Defining ``forward`` lets the module be called directly as
        ``model(x, demo)``. Without it, ``nn.Module.__call__`` raises
        ``NotImplementedError``.
        """
        return self.decode(self.encode(x), demo)

# Model dimensions. No data is generated here; these values only size the layers.
input_dim = 100
demo_dim = 4
latent_dim = 16

# Checkpoint location used by the save/load round-trip below.
MODEL_DIR = "models"
MODEL_PATH = os.path.join(MODEL_DIR, "simple_vae.pt")

# Create model
print("Creating model...")
model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
print("Model created successfully.")

# Save only the parameters (state_dict), not the pickled module object.
os.makedirs(MODEL_DIR, exist_ok=True)
print("Saving model...")
torch.save(model.state_dict(), MODEL_PATH)
print("Model saved.")

# Build a fresh model (new random init) and load the saved parameters into it.
# With the default strict=True, load_state_dict raises if any key is missing
# or unexpected.
print("Loading model...")
new_model = SimplifiedVAE(input_dim, latent_dim, demo_dim)
new_model.load_state_dict(torch.load(MODEL_PATH))
print("Model loaded successfully.")

# Matching keys alone do not prove the round-trip worked, so compare every
# tensor. An explicit raise is used rather than `assert`, which `python -O` strips.
saved_state = model.state_dict()
for param_name, loaded_tensor in new_model.state_dict().items():
    if not torch.equal(loaded_tensor, saved_state[param_name]):
        raise RuntimeError(
            f"Parameter {param_name!r} changed after save/load round-trip"
        )

print("All tests passed!")