# Healthmodels / src/model.py
# Author: theaniketgiri
# Initial commit (902fa1b)
# Enhanced version with key improvements
# model.py - Add validation and better loss
import torch
import torch.nn as nn
import torch.nn.functional as F
class TabularVAE(nn.Module):
    """Variational autoencoder for tabular (vector) data.

    The encoder maps an input vector to the parameters (mu, logvar) of a
    diagonal-Gaussian latent distribution; a latent sample is drawn with
    the reparameterization trick and decoded back to the input space.

    Args:
        input_dim: Number of input features. Must be positive.
        hidden_dims: Sizes of the encoder hidden layers; the decoder
            mirrors them in reverse order. Must be non-empty.
        latent_dim: Dimensionality of the latent vector.
        dropout: Dropout probability applied after each hidden layer
            (default 0.1, matching the previous hard-coded value).

    Raises:
        ValueError: If ``input_dim`` is not positive or ``hidden_dims``
            is empty.
    """

    def __init__(self, input_dim: int, hidden_dims=(64, 32), latent_dim=16,
                 dropout: float = 0.1):
        super().__init__()
        # Fail fast with a clear message instead of an obscure IndexError
        # from hidden_dims[-1] below.
        if input_dim <= 0:
            raise ValueError(f"input_dim must be positive, got {input_dim}")
        if not hidden_dims:
            raise ValueError("hidden_dims must contain at least one layer size")
        self.input_dim = input_dim
        self.latent_dim = latent_dim

        # Encoder: input_dim -> hidden_dims[0] -> ... -> hidden_dims[-1]
        dims = [input_dim, *hidden_dims]
        self.encoder_layers = nn.ModuleList([
            nn.Linear(dims[i], dims[i + 1]) for i in range(len(dims) - 1)
        ])
        # Two parallel heads produce the Gaussian parameters.
        self.fc_mu = nn.Linear(hidden_dims[-1], latent_dim)
        self.fc_logvar = nn.Linear(hidden_dims[-1], latent_dim)

        # Decoder mirrors the encoder: latent_dim -> reversed(hidden_dims)
        dims_rev = [latent_dim, *reversed(hidden_dims)]
        self.decoder_layers = nn.ModuleList([
            nn.Linear(dims_rev[i], dims_rev[i + 1]) for i in range(len(dims_rev) - 1)
        ])
        self.output_layer = nn.Linear(hidden_dims[0], input_dim)

        # Dropout for regularization; a no-op in eval() mode.
        self.dropout = nn.Dropout(dropout)

    def encode(self, x):
        """Map input ``x`` to the latent Gaussian parameters (mu, logvar)."""
        h = x
        for layer in self.encoder_layers:
            h = F.relu(layer(h))
            h = self.dropout(h)
        mu = self.fc_mu(h)
        logvar = self.fc_logvar(h)
        return mu, logvar

    def reparameterize(self, mu, logvar):
        """Sample z ~ N(mu, sigma^2) as z = mu + sigma * eps (eps ~ N(0, I)).

        Keeps the sampling step differentiable w.r.t. mu and logvar.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent sample ``z`` back to a reconstruction in input space."""
        h = z
        for layer in self.decoder_layers:
            h = F.relu(layer(h))
            h = self.dropout(h)
        # Linear output: reconstruction loss (e.g. MSE) is applied by the caller.
        return self.output_layer(h)

    def forward(self, x):
        """Full VAE pass; returns (reconstruction, mu, logvar)."""
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        recon = self.decode(z)
        return recon, mu, logvar