# src/train.py - Enhanced training with validation for patient data
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from src.model import TabularVAE
import joblib
import os

# Hyperparameters
BATCH_SIZE = 32   # Smaller batch size for smaller dataset
LR = 1e-3
EPOCHS = 150      # More epochs for smaller dataset
LATENT_DIM = 8    # Smaller latent dim for smaller dataset
BETA = 1.0        # KL divergence weight


def vae_loss(recon, x, mu, logvar, beta=1.0):
    """Beta-VAE objective, normalized per sample in the batch.

    Args:
        recon: Reconstruction of ``x`` (same shape as ``x``).
        x: Input batch of shape (batch_size, n_features).
        mu: Encoder latent means.
        logvar: Encoder latent log-variances.
        beta: Weight on the KL term (1.0 = vanilla VAE).

    Returns:
        Tuple ``(total, recon_loss, kld)`` of scalar tensors, each averaged
        over the batch.
    """
    batch_size = x.size(0)
    # Reconstruction term: summed squared error, divided by batch size so the
    # loss scale does not depend on the batch size.
    recon_loss = F.mse_loss(recon, x, reduction='sum') / batch_size
    # Analytic KL divergence between q(z|x) = N(mu, sigma^2) and N(0, I).
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) / batch_size
    return recon_loss + beta * kld, recon_loss, kld


def train_vae(progress_callback=None):
    """Train a TabularVAE on the preprocessed patient dataset.

    Loads (or creates) the preprocessed CSV, splits it into train/validation
    sets, standardizes features (fit on train only), and trains with
    LR-on-plateau scheduling, best-model checkpointing and early stopping.

    Args:
        progress_callback: Optional callable invoked each epoch with
            ``(epoch, train_loss, val_loss, best_val_loss)``.

    Returns:
        ``(model, scaler, feature_names)`` — the trained model (last-epoch
        weights; the best checkpoint is at ``models/best_vae_model.pth``),
        the fitted ``StandardScaler``, and the list of feature columns.
    """
    # Check if preprocessed data exists, if not create it
    if not os.path.exists("data/processed_patient_data.csv"):
        # FIX: this message was a string literal broken across physical lines
        # (a syntax error); reconstructed as a single literal.
        print("Preprocessed data not found. Running data preprocessing...")
        from src.data_preprocessing import preprocess_patient_data
        feature_df, _encoders = preprocess_patient_data()  # encoders unused here
    else:
        print("Loading preprocessed data...")
        feature_df = pd.read_csv("data/processed_patient_data.csv")

    print(f"Dataset shape: {feature_df.shape}")
    print(f"Features: {list(feature_df.columns)}")

    # Impute missing values with per-column means (assumes all columns are
    # numeric after preprocessing — TODO confirm against preprocess_patient_data).
    feature_df = feature_df.fillna(feature_df.mean())

    # Split data
    train_df, val_df = train_test_split(feature_df, test_size=0.2, random_state=42)

    # Scale data — fit on the training split only to avoid leaking
    # validation statistics into the scaler.
    scaler = StandardScaler()
    train_data = scaler.fit_transform(train_df.values)
    val_data = scaler.transform(val_df.values)

    print(f"Training data shape: {train_data.shape}")
    print(f"Validation data shape: {val_data.shape}")

    # Create data loaders
    train_tensor = torch.tensor(train_data, dtype=torch.float32)
    val_tensor = torch.tensor(val_data, dtype=torch.float32)
    train_loader = DataLoader(TensorDataset(train_tensor), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_tensor), batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model with correct input dimension
    input_dim = train_data.shape[1]
    model = TabularVAE(input_dim=input_dim, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=15, factor=0.5)

    best_val_loss = float('inf')
    patience_counter = 0
    early_stopping_patience = 30

    # FIX: torch.save/joblib.dump below write into models/ — make sure the
    # directory exists so the first checkpoint save does not fail.
    os.makedirs("models", exist_ok=True)

    print(f"Model initialized with {input_dim} input features and {LATENT_DIM} latent dimensions")
    print(f"Training for {EPOCHS} epochs...")

    # Training loop
    for epoch in range(EPOCHS):
        # Training
        model.train()
        train_loss = 0
        train_recon = 0
        train_kld = 0
        for (batch,) in train_loader:
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_recon += recon_loss.item()
            train_kld += kld_loss.item()

        # Validation
        model.eval()
        val_loss = 0
        val_recon = 0
        val_kld = 0
        with torch.no_grad():
            for (batch,) in val_loader:
                recon, mu, logvar = model(batch)
                loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
                val_loss += loss.item()
                val_recon += recon_loss.item()
                val_kld += kld_loss.item()

        # Calculate per-batch averages.
        # FIX: recon/KLD components were previously left as sums over batches
        # while the losses were averaged, making the printed numbers
        # inconsistent with each other.
        train_loss /= len(train_loader)
        train_recon /= len(train_loader)
        train_kld /= len(train_loader)
        val_loss /= len(val_loader)
        val_recon /= len(val_loader)
        val_kld /= len(val_loader)

        # Learning rate scheduling (reduce LR when val loss plateaus)
        scheduler.step(val_loss)

        # Save best model (checkpoint on validation improvement)
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "models/best_vae_model.pth")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Print progress
        if epoch % 10 == 0 or epoch == EPOCHS - 1:
            print(f"Epoch {epoch+1}/{EPOCHS}")
            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"Train Recon: {train_recon:.4f}, Train KLD: {train_kld:.4f}")
            print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Call progress callback if provided
        if progress_callback:
            progress_callback(epoch+1, train_loss, val_loss, best_val_loss)

    # Save final model and scaler (the best checkpoint was saved above)
    torch.save(model.state_dict(), "models/vae_model.pth")
    joblib.dump(scaler, "models/scaler.pkl")

    # Save feature names for API
    feature_names = list(feature_df.columns)
    joblib.dump(feature_names, "models/feature_names.pkl")

    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved with {input_dim} input features")

    return model, scaler, feature_names


if __name__ == "__main__":
    model, scaler, feature_names = train_vae()