Spaces:
Sleeping
Sleeping
File size: 1,797 Bytes
902fa1b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
# Enhanced version with key improvements
# model.py - Add validation and better loss
import torch
import torch.nn as nn
import torch.nn.functional as F
class TabularVAE(nn.Module):
    """Variational autoencoder for flat (tabular) feature vectors.

    Encodes an input batch into a diagonal-Gaussian posterior over a
    latent space, samples it via the reparameterization trick, and
    decodes the sample back to input space through a mirrored MLP.
    """

    def __init__(self, input_dim: int, hidden_dims=(64, 32), latent_dim=16):
        """Build the encoder/decoder stacks.

        Args:
            input_dim: number of input features per row.
            hidden_dims: widths of the encoder hidden layers; the
                decoder uses the same widths in reverse order.
            latent_dim: dimensionality of the latent code z.
        """
        super().__init__()
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder MLP: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1].
        enc_sizes = (input_dim, *hidden_dims)
        self.encoder_layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(enc_sizes, enc_sizes[1:])
        )
        # Heads emitting the Gaussian posterior parameters.
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_logvar = nn.Linear(hidden_dims[-1], latent_dim)

        # Decoder mirrors the encoder: latent -> reversed hidden widths.
        dec_sizes = (latent_dim, *reversed(hidden_dims))
        self.decoder_layers = nn.ModuleList(
            nn.Linear(n_in, n_out) for n_in, n_out in zip(dec_sizes, dec_sizes[1:])
        )
        self.output_layer = nn.Linear(hidden_dims[0], input_dim)
        # Shared dropout applied after every hidden activation (regularization).
        self.dropout = nn.Dropout(0.1)

    def encode(self, x):
        """Map a batch x to posterior parameters (mu, logvar)."""
        out = x
        for fc in self.encoder_layers:
            out = self.dropout(F.relu(fc(out)))
        return self.fc_mu(out), self.fc_logvar(out)

    def reparameterize(self, mu, logvar):
        """Draw z ~ N(mu, sigma^2) differentiably: z = mu + sigma * eps."""
        sigma = torch.exp(0.5 * logvar)
        return mu + torch.randn_like(sigma) * sigma

    def decode(self, z):
        """Reconstruct an input-space batch from latent codes z."""
        out = z
        for fc in self.decoder_layers:
            out = self.dropout(F.relu(fc(out)))
        return self.output_layer(out)

    def forward(self, x):
        """Full VAE pass; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        return self.decode(self.reparameterize(mu, logvar)), mu, logvar