# model.py - Tabular variational autoencoder (VAE) with constructor validation.
# NOTE(review): no training loss is defined in this file; the reconstruction
# and KL terms presumably live in the training code — confirm.
from typing import Iterable, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class TabularVAE(nn.Module):
    """Fully connected VAE for flat (tabular) feature vectors.

    The encoder maps ``input_dim`` features through ``hidden_dims`` to the
    mean and log-variance of a ``latent_dim``-dimensional Gaussian; the
    decoder walks the hidden widths in reverse back to ``input_dim``.

    Args:
        input_dim: Number of input features (size of the last axis of ``x``).
        hidden_dims: Widths of the encoder's hidden layers, input side first.
            The decoder uses them in reverse order. Any iterable is accepted
            (it is materialized once). May be empty, in which case the latent
            heads read the input directly and the output layer reads the
            latent code directly.
        latent_dim: Size of the latent space.
        dropout: Dropout probability applied after every hidden layer
            (default 0.1, the previously hard-coded value).

    Raises:
        ValueError: If ``input_dim``, ``latent_dim`` or any hidden width is
            smaller than 1, or (raised by ``nn.Dropout``) if ``dropout`` is
            outside ``[0, 1]``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Iterable[int] = (64, 32),
        latent_dim: int = 16,
        dropout: float = 0.1,
    ):
        super().__init__()
        # Materialize once: indexing after unpacking would fail for one-shot
        # iterables such as generators.
        hidden_dims = tuple(hidden_dims)
        if input_dim < 1:
            raise ValueError(f"input_dim must be >= 1, got {input_dim}")
        if latent_dim < 1:
            raise ValueError(f"latent_dim must be >= 1, got {latent_dim}")
        if any(h < 1 for h in hidden_dims):
            raise ValueError(f"hidden_dims must all be >= 1, got {hidden_dims}")

        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]
        enc_dims = [input_dim, *hidden_dims]
        self.encoder_layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(enc_dims, enc_dims[1:])
        )
        # Heads read the last encoder width (input_dim when hidden_dims is empty).
        self.fc_mu = nn.Linear(enc_dims[-1], latent_dim)
        self.fc_logvar = nn.Linear(enc_dims[-1], latent_dim)

        # Decoder: latent_dim -> hidden_dims[-1] -> ... -> hidden_dims[0]
        dec_dims = [latent_dim, *reversed(hidden_dims)]
        self.decoder_layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dec_dims, dec_dims[1:])
        )
        # Output reads the last decoder width (latent_dim when hidden_dims is empty).
        self.output_layer = nn.Linear(dec_dims[-1], input_dim)

        # Add dropout for better generalization
        self.dropout = nn.Dropout(dropout)

    def encode(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Map inputs to the parameters of the latent Gaussian.

        Each hidden layer is Linear -> ReLU -> Dropout.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            ``(mu, logvar)``, each with last dimension ``latent_dim``.
        """
        h = x
        for layer in self.encoder_layers:
            h = self.dropout(F.relu(layer(h)))
        return self.fc_mu(h), self.fc_logvar(h)

    def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Sample ``z = mu + eps * exp(0.5 * logvar)`` with ``eps ~ N(0, I)``.

        Sampling happens regardless of train/eval mode; use ``mu`` directly
        when a deterministic code is wanted.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z: torch.Tensor) -> torch.Tensor:
        """Map latent codes back to feature space.

        Hidden layers are Linear -> ReLU -> Dropout; the output layer is a
        plain Linear with no activation.

        Args:
            z: Tensor whose last dimension has size ``latent_dim``.

        Returns:
            Tensor whose last dimension has size ``input_dim``.
        """
        h = z
        for layer in self.decoder_layers:
            h = self.dropout(F.relu(layer(h)))
        return self.output_layer(h)

    def forward(
        self, x: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode, sample a latent code, and decode.

        Returns:
            ``(recon, mu, logvar)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar