# NOTE: stray table-extraction artifacts removed from the file header
# src/train.py - Enhanced training with validation for patient data

# Standard library
import os

# Third-party
import joblib
import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import DataLoader, TensorDataset

# Local
from src.model import TabularVAE
# Hyperparameters
BATCH_SIZE = 32   # Smaller batch size for smaller dataset
LR = 1e-3         # Adam learning rate
EPOCHS = 150      # More epochs for smaller dataset
LATENT_DIM = 8    # Smaller latent dim for smaller dataset
BETA = 1.0        # KL divergence weight in the VAE objective
def vae_loss(recon, x, mu, logvar, beta=1.0):
    """Batch-averaged VAE objective: MSE reconstruction + beta-weighted KL.

    Args:
        recon: Reconstructed batch, same shape as ``x``.
        x: Original input batch.
        mu: Latent means from the encoder.
        logvar: Latent log-variances from the encoder.
        beta: Weight applied to the KL-divergence term.

    Returns:
        Tuple ``(total, reconstruction, divergence)``, each normalized
        by the batch size.
    """
    n = x.size(0)
    # Summed squared error over all elements, averaged per sample.
    reconstruction = F.mse_loss(recon, x, reduction='sum') / n
    # Closed-form KL divergence between N(mu, sigma^2) and N(0, I).
    kl_elements = 1 + logvar - mu.pow(2) - logvar.exp()
    divergence = -0.5 * kl_elements.sum() / n
    total = reconstruction + beta * divergence
    return total, reconstruction, divergence
def train_vae(progress_callback=None):
    """Train a TabularVAE on the preprocessed patient feature table.

    Loads ``data/processed_patient_data.csv`` (running preprocessing first
    if it is missing), splits into train/validation sets, standardizes the
    features, and trains with early stopping and LR-on-plateau scheduling.
    Saves the best and final model weights, the fitted scaler, and the
    feature-name list under ``models/``.

    Args:
        progress_callback: Optional callable invoked once per epoch as
            ``progress_callback(epoch, train_loss, val_loss, best_val_loss)``.

    Returns:
        Tuple ``(model, scaler, feature_names)``: the trained TabularVAE,
        the fitted StandardScaler, and the list of feature column names.
    """
    # Check if preprocessed data exists, if not create it
    if not os.path.exists("data/processed_patient_data.csv"):
        print("Preprocessed data not found. Running data preprocessing...")
        from src.data_preprocessing import preprocess_patient_data
        feature_df, encoders = preprocess_patient_data()
    else:
        print("Loading preprocessed data...")
        feature_df = pd.read_csv("data/processed_patient_data.csv")

    print(f"Dataset shape: {feature_df.shape}")
    print(f"Features: {list(feature_df.columns)}")

    # Impute missing values with column means. numeric_only avoids a crash
    # if a non-numeric column ever slips through preprocessing; behavior is
    # identical when all columns are numeric.
    feature_df = feature_df.fillna(feature_df.mean(numeric_only=True))

    # Split data (fixed seed for reproducibility)
    train_df, val_df = train_test_split(feature_df, test_size=0.2, random_state=42)

    # Fit the scaler on the training split only, to avoid leaking
    # validation statistics into training.
    scaler = StandardScaler()
    train_data = scaler.fit_transform(train_df.values)
    val_data = scaler.transform(val_df.values)

    print(f"Training data shape: {train_data.shape}")
    print(f"Validation data shape: {val_data.shape}")

    # Create data loaders
    train_tensor = torch.tensor(train_data, dtype=torch.float32)
    val_tensor = torch.tensor(val_data, dtype=torch.float32)
    train_loader = DataLoader(TensorDataset(train_tensor), batch_size=BATCH_SIZE, shuffle=True)
    val_loader = DataLoader(TensorDataset(val_tensor), batch_size=BATCH_SIZE, shuffle=False)

    # Initialize model with correct input dimension
    input_dim = train_data.shape[1]
    model = TabularVAE(input_dim=input_dim, latent_dim=LATENT_DIM, hidden_dims=(32, 16))
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=15, factor=0.5)

    # Ensure the output directory exists before any torch.save / joblib.dump
    # (previously a fresh checkout without models/ crashed on first save).
    os.makedirs("models", exist_ok=True)

    best_val_loss = float('inf')
    patience_counter = 0
    early_stopping_patience = 30

    print(f"Model initialized with {input_dim} input features and {LATENT_DIM} latent dimensions")
    print(f"Training for {EPOCHS} epochs...")

    # Training loop
    for epoch in range(EPOCHS):
        # ---- Training phase ----
        model.train()
        train_loss = 0
        train_recon = 0
        train_kld = 0
        for (batch,) in train_loader:
            optimizer.zero_grad()
            recon, mu, logvar = model(batch)
            loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
            loss.backward()
            optimizer.step()
            train_loss += loss.item()
            train_recon += recon_loss.item()
            train_kld += kld_loss.item()

        # ---- Validation phase ----
        model.eval()
        val_loss = 0
        val_recon = 0
        val_kld = 0
        with torch.no_grad():
            for (batch,) in val_loader:
                recon, mu, logvar = model(batch)
                loss, recon_loss, kld_loss = vae_loss(recon, batch, mu, logvar, BETA)
                val_loss += loss.item()
                val_recon += recon_loss.item()
                val_kld += kld_loss.item()

        # Average per-batch metrics. Previously only the total losses were
        # averaged while recon/KLD were reported as raw sums, which made the
        # printed components inconsistent with the printed totals.
        n_train_batches = len(train_loader)
        n_val_batches = len(val_loader)
        train_loss /= n_train_batches
        train_recon /= n_train_batches
        train_kld /= n_train_batches
        val_loss /= n_val_batches

        # Learning rate scheduling on validation loss
        scheduler.step(val_loss)

        # Save best model (by validation loss) and reset patience counter
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            torch.save(model.state_dict(), "models/best_vae_model.pth")
            patience_counter = 0
        else:
            patience_counter += 1

        # Early stopping
        if patience_counter >= early_stopping_patience:
            print(f"Early stopping at epoch {epoch+1}")
            break

        # Print progress every 10 epochs and on the final epoch
        if epoch % 10 == 0 or epoch == EPOCHS - 1:
            print(f"Epoch {epoch+1}/{EPOCHS}")
            print(f"Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
            print(f"Train Recon: {train_recon:.4f}, Train KLD: {train_kld:.4f}")
            print(f"LR: {optimizer.param_groups[0]['lr']:.6f}")

        # Call progress callback if provided
        if progress_callback:
            progress_callback(epoch+1, train_loss, val_loss, best_val_loss)

    # Save final (last-epoch) model and scaler; the best-by-validation
    # weights were already written to models/best_vae_model.pth above.
    torch.save(model.state_dict(), "models/vae_model.pth")
    joblib.dump(scaler, "models/scaler.pkl")

    # Save feature names so the serving API can order its inputs consistently
    feature_names = list(feature_df.columns)
    joblib.dump(feature_names, "models/feature_names.pkl")

    print("Training completed!")
    print(f"Best validation loss: {best_val_loss:.4f}")
    print(f"Model saved with {input_dim} input features")

    return model, scaler, feature_names
if __name__ == "__main__":
    # Script entry point: run training and keep references to the artifacts.
    model, scaler, feature_names = train_vae()