File size: 1,797 Bytes
902fa1b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
# Enhanced version with key improvements

# model.py - Add validation and better loss
import torch
import torch.nn as nn
import torch.nn.functional as F

class TabularVAE(nn.Module):
    """Variational autoencoder built from fully connected layers.

    The encoder maps ``input_dim`` features through ReLU hidden layers of
    widths ``hidden_dims`` to the mean and log-variance of a diagonal
    Gaussian over a ``latent_dim``-dimensional latent space. The decoder
    mirrors the hidden layers in reverse order and ends in a linear layer
    with no activation, so reconstructions are unbounded real values.

    Args:
        input_dim: Number of input (and reconstructed) features.
        hidden_dims: Widths of the encoder hidden layers, first to last; the
            decoder uses them reversed. May be empty, in which case encoder
            and decoder are each a single linear map.
        latent_dim: Size of the latent code.
        dropout: Dropout probability applied after every hidden layer of
            both encoder and decoder (keyword-only, default 0.1).

    Raises:
        ValueError: If ``input_dim``, ``latent_dim`` or any hidden width is
            not positive (``nn.Dropout`` also raises it for an invalid
            ``dropout`` probability).
    """

    def __init__(self, input_dim: int, hidden_dims=(64, 32), latent_dim: int = 16,
                 *, dropout: float = 0.1):
        super().__init__()
        # Materialize once: the widths are walked twice below (forward and
        # reversed), which would silently break for a one-shot iterator.
        hidden_dims = tuple(hidden_dims)
        if input_dim <= 0:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        if latent_dim <= 0:
            raise ValueError(f"latent_dim must be positive, got {latent_dim}")
        if any(width <= 0 for width in hidden_dims):
            raise ValueError(f"hidden_dims must all be positive, got {hidden_dims}")

        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.hidden_dims = hidden_dims

        # Encoder: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1].
        enc_dims = [input_dim, *hidden_dims]
        self.encoder_layers = nn.ModuleList([
            nn.Linear(d_in, d_out) for d_in, d_out in zip(enc_dims[:-1], enc_dims[1:])
        ])
        # enc_dims[-1] is the last hidden width, or input_dim when there are
        # no hidden layers.
        self.fc_mu = nn.Linear(enc_dims[-1], latent_dim)
        self.fc_logvar = nn.Linear(enc_dims[-1], latent_dim)

        # Decoder mirrors the encoder: latent_dim -> hidden_dims[-1] -> ... -> hidden_dims[0].
        dec_dims = [latent_dim, *reversed(hidden_dims)]
        self.decoder_layers = nn.ModuleList([
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dec_dims[:-1], dec_dims[1:])
        ])
        # dec_dims[-1] is hidden_dims[0], or latent_dim when there are no
        # hidden layers.
        self.output_layer = nn.Linear(dec_dims[-1], input_dim)

        # Shared dropout module, applied after every hidden activation.
        self.dropout = nn.Dropout(dropout)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for ``x``.

        ``x``'s last dimension must equal ``input_dim``; both outputs have
        last dimension ``latent_dim``.
        """
        h = x
        for layer in self.encoder_layers:
            h = F.relu(layer(h))
            h = self.dropout(h)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Writing the sample this way keeps it differentiable w.r.t. ``mu`` and
        ``logvar``. Note that it samples in both train and eval mode.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` back to ``input_dim`` raw (unactivated) outputs."""
        h = z
        for layer in self.decoder_layers:
            h = F.relu(layer(h))
            h = self.dropout(h)
        return self.output_layer(h)

    def forward(self, x):
        """Encode, sample a latent code, and decode.

        Returns:
            Tuple ``(recon, mu, logvar)``: the reconstruction and the
            posterior parameters needed for the KL term of the VAE loss.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar