# Healthmodels / src/train.py
# Author: theaniketgiri
# Commit: 902fa1b ("first")
# src/train.py - Enhanced training with validation for patient data
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from src.model import TabularVAE
import joblib
import os
# Hyperparameters (tuned for a small tabular dataset)
BATCH_SIZE = 32 # Smaller batch size for smaller dataset
LR = 1e-3  # Adam learning rate; ReduceLROnPlateau halves it when validation stalls
EPOCHS = 150 # More epochs for smaller dataset (early stopping usually ends sooner)
LATENT_DIM = 8 # Smaller latent dim for smaller dataset
BETA = 1.0 # KL divergence weight (beta-VAE coefficient; 1.0 = plain VAE)
def vae_loss(recon, x, mu, logvar, beta=1.0):
    """Compute the VAE objective, normalized per sample.

    Both terms are summed over all elements and divided by the batch
    size (not the feature count).

    Returns:
        (total, recon_term, kld_term) where total = recon_term + beta * kld_term.
    """
    n = x.size(0)
    # Summed squared reconstruction error, averaged over the batch.
    mse = F.mse_loss(recon, x, reduction='sum') / n
    # Analytic KL(q(z|x) || N(0, I)) for a diagonal Gaussian posterior,
    # written as 0.5 * sum(mu^2 + sigma^2 - log sigma^2 - 1).
    kl = 0.5 * torch.sum(mu.pow(2) + logvar.exp() - logvar - 1) / n
    return mse + beta * kl, mse, kl
def train_vae(progress_callback=None):
    """Train a TabularVAE on the preprocessed patient dataset.

    Loads data/processed_patient_data.csv (running preprocessing first if
    it is missing), fits a StandardScaler on the training split, trains
    with early stopping and LR-plateau scheduling, and saves the model,
    scaler, and feature names under models/.

    Args:
        progress_callback: optional callable invoked after each epoch as
            progress_callback(epoch, train_loss, val_loss, best_val_loss).

    Returns:
        (model, scaler, feature_names): the trained TabularVAE, the fitted
        StandardScaler, and the ordered list of feature column names.
    """
    # Ensure the artifact directory exists before any torch.save/joblib.dump;
    # on a fresh checkout saving into a missing "models/" raises FileNotFoundError.
    os.makedirs("models", exist_ok=True)

    # Check if preprocessed data exists, if not create it
    if not os.path.exists("data/processed_patient_data.csv"):
        print("Preprocessed data not found. Running data preprocessing...")
        from src.data_preprocessing import preprocess_patient_data
        feature_df, encoders = preprocess_patient_data()
    else:
        print("Loading preprocessed data...")
        feature_df = pd.read_csv("data/processed_patient_data.csv")
    print(f"Dataset shape: {feature_df.shape}")
    print(f"Features: {list(feature_df.columns)}")

    # Impute missing values with per-column means.
    # NOTE(review): assumes every column is numeric (preprocessing should
    # guarantee this) — DataFrame.mean() raises on object columns in pandas >= 2.0.
    feature_df = feature_df.fillna(feature_df.mean())

    # Hold out 20% of rows for validation.
    train_df, val_df = train_test_split(feature_df, test_size=0.2, random_state=42)

    # Fit the scaler on the training split only to avoid leakage into validation.
    scaler = StandardScaler()
    train_data = scaler.fit_transform(train_df.values)
    val_data = scaler.transform(val_df.values)
    print(f"Training data shape: {train_data.shape}")
    print(f"Validation data shape: {val_data.shape}")

    # Create data loaders
    train_tensor = torch.tensor(train_data, dtype=torch.float32)
    val_tensor = torch.tensor(val_data, dtype=torch.float32)
    train_loader = DataLoader(TensorDataset(train_tensor), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_tensor), batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model with correct input dimension
    input_dim = train_data.shape[1]
    model = TabularVAE(input_dim=input_dim, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    # Halve the LR after 15 epochs without validation improvement.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=15, factor=0.5)

    best_val_loss = float('inf')
    patience_counter = 0
    early_stopping_patience = 30

    print(f"Model initialized with {input_dim} input features and {LATENT_DIM} latent dimensions")
    print(f"Training for {EPOCHS} epochs...")

    # Training loop
    for epoch in range(EPOCHS):
        # --- Training pass ---
        model.train()
        train_loss = 0
        train_recon = 0
        train_kld = 0
        for (batch,) in train_loader:
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_recon += recon_loss.item()
            train_kld += kld_loss.item()

        # --- Validation pass (no gradients) ---
        model.eval()
        val_loss = 0
        val_recon = 0
        val_kld = 0
        with torch.no_grad():
            for (batch,) in val_loader:
                recon, mu, logvar = model(batch)
                loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
                val_loss += loss.item()
                val_recon += recon_loss.item()
                val_kld += kld_loss.item()

        # Average ALL accumulators per batch. The original averaged only the
        # totals, so the printed recon/KLD figures were epoch sums and not
        # comparable to the loss values next to them.
        train_loss /= len(train_loader)
        train_recon /= len(train_loader)
        train_kld /= len(train_loader)
        val_loss /= len(val_loader)
        val_recon /= len(val_loader)
        val_kld /= len(val_loader)

        # Learning rate scheduling
        scheduler.step(val_loss)

        # Checkpoint the best model seen so far.
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "models/best_vae_model.pth")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Print progress
        if epoch % 10 == 0 or epoch == EPOCHS - 1:
            print(f"Epoch {epoch+1}/{EPOCHS}")
            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"Train Recon: {train_recon:.4f}, Train KLD: {train_kld:.4f}")
            print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Call progress callback if provided
        if progress_callback:
            progress_callback(epoch+1, train_loss, val_loss, best_val_loss)

    # Save final model and scaler (the best checkpoint is kept separately
    # as models/best_vae_model.pth).
    torch.save(model.state_dict(), "models/vae_model.pth")
    joblib.dump(scaler, "models/scaler.pkl")

    # Save feature names for API
    feature_names = list(feature_df.columns)
    joblib.dump(feature_names, "models/feature_names.pkl")

    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved with {input_dim} input features")
    return model, scaler, feature_names
if __name__ == "__main__":
    # Script entry point: run a full training pass and leave artifacts in models/.
    model, scaler, feature_names = train_vae()