# Aphasia treatment outcome prediction: Random Forest models trained on
# VAE latent representations combined with patient demographics.
| import numpy as np | |
| from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier | |
| from sklearn.model_selection import cross_val_score, KFold | |
| import pandas as pd | |
| from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score | |
| import matplotlib.pyplot as plt | |
| import os | |
| import joblib | |
| import logging | |
# Module-wide logging configuration: timestamp and severity on every record.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# Module-level logger, named after this module per the stdlib convention.
logger = logging.getLogger(__name__)
class AphasiaTreatmentPredictor:
    """Random-forest predictor of aphasia treatment outcomes.

    Combines latent representations produced by a VAE with patient
    demographics and fits either a ``RandomForestRegressor`` or a
    ``RandomForestClassifier`` depending on the outcome variable type.
    """

    def __init__(self, prediction_type="regression", n_estimators=100, max_depth=None, random_state=42):
        """
        Initialize the Treatment Predictor with Random Forest

        Args:
            prediction_type (str): "classification" or "regression" depending on outcome variable type
            n_estimators (int): Number of trees in the forest
            max_depth (int): Maximum depth of trees (None for unlimited)
            random_state (int): Random seed for reproducibility
        """
        self.prediction_type = prediction_type
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        self.feature_importance = None  # pd.DataFrame, populated after fit/cross_validate
        self.feature_names = None       # list[str], populated after prepare_features
        self.model = self._build_model()

    def _build_model(self, verbose=0):
        """Construct the underlying forest.

        Single construction point so that __init__ and the per-fold models in
        cross_validate stay consistent with each other.

        Args:
            verbose (int): sklearn verbosity level for the forest.

        Returns:
            RandomForestClassifier or RandomForestRegressor
        """
        if self.prediction_type == "classification":
            model_cls = RandomForestClassifier
        else:  # regression
            model_cls = RandomForestRegressor
        return model_cls(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            random_state=self.random_state,
            verbose=verbose,
        )

    def prepare_features(self, latents, demographics):
        """
        Combine latent features with demographics

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict, pd.DataFrame, or None): Demographic information.
                If None, only the latent features are used.

        Returns:
            tuple: Combined features array and feature names
        """
        latent_names = [f'latent_{i}' for i in range(latents.shape[1])]
        # FIX: previously a None value crashed on demographics.copy();
        # fall back to latents alone when no demographics are supplied.
        if demographics is None:
            return latents, latent_names

        n_samples = latents.shape[0]
        if isinstance(demographics, dict):
            # For dictionary input, ensure all arrays are same length as latents.
            aligned_demos = {}
            for key, values in demographics.items():
                if len(values) == n_samples:
                    aligned_demos[key] = values
                    continue
                print(f"WARNING: Demographics '{key}' length ({len(values)}) doesn't match latents ({n_samples})")
                if len(values) > n_samples:
                    # Truncate to the number of latent samples.
                    aligned_demos[key] = values[:n_samples]
                    print(f" Truncated '{key}' to {n_samples} samples")
                elif len(values) > 0:
                    # Pad: mean for numerical values, mode for categorical ones.
                    if isinstance(values[0], (int, float, np.number)):
                        filler = np.mean(values)
                    else:
                        from collections import Counter
                        filler = Counter(values).most_common(1)[0][0]
                    padding = [filler] * (n_samples - len(values))
                    aligned_demos[key] = list(values) + padding
                    print(f" Padded '{key}' with {filler} to {n_samples} samples")
                else:
                    # Empty array, fill with zeros.
                    aligned_demos[key] = [0] * n_samples
                    print(f" Filled empty '{key}' with zeros to {n_samples} samples")
            demo_df = pd.DataFrame(aligned_demos)
        else:
            demo_df = demographics.copy()
            # Ensure DataFrame has same number of rows as latents.
            if len(demo_df) != n_samples:
                print(f"WARNING: Demographics DataFrame size ({len(demo_df)}) doesn't match latents ({n_samples})")
                if len(demo_df) > n_samples:
                    demo_df = demo_df.iloc[:n_samples]  # Truncate
                    print(f" Truncated demographics to {n_samples} samples")
                else:
                    # Cannot easily pad a DataFrame; keep its columns but use zeros.
                    print(f" ERROR: Cannot pad demographics DataFrame - using latents only")
                    demo_df = pd.DataFrame(0, index=range(n_samples), columns=demo_df.columns)

        # One-hot encode categorical (object-dtype) columns.
        cat_columns = demo_df.select_dtypes(include=['object']).columns.tolist()
        if cat_columns:
            demo_df = pd.get_dummies(demo_df, columns=cat_columns)

        feature_names = latent_names + demo_df.columns.tolist()
        try:
            features = np.hstack([latents, demo_df.values])
        except ValueError as e:
            # Shapes still disagree after alignment; degrade gracefully.
            print(f"ERROR combining features: {e}")
            print(f"Latents shape: {latents.shape}, Demographics shape: {demo_df.values.shape}")
            print("Falling back to using only latent features")
            features = latents
            feature_names = latent_names
        return features, feature_names

    def fit(self, latents, demographics, treatment_outcomes):
        """
        Fit the random forest model

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict or pd.DataFrame): Demographic information
            treatment_outcomes (np.ndarray): Treatment outcome values to predict

        Returns:
            self: Trained model instance
        """
        X, feature_names = self.prepare_features(latents, demographics)
        self.feature_names = feature_names
        logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
        print(f"Random Forest: Building {self.n_estimators} trees...")
        # verbose=1 makes sklearn report per-tree training progress.
        self.model.verbose = 1
        self.model.fit(X, treatment_outcomes)
        self.feature_importance = pd.DataFrame({
            'feature': feature_names,
            'importance': self.model.feature_importances_
        }).sort_values('importance', ascending=False)
        print(f"Random Forest: Training complete. Top features: {', '.join(self.feature_importance['feature'].head(3).tolist())}")
        return self

    def predict(self, latents, demographics):
        """
        Predict treatment outcomes for new patients

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict or pd.DataFrame): Demographic information

        Returns:
            tuple: Predictions and prediction uncertainty (std deviation for
                regression; 1 - max class probability for classification)
        """
        X, _ = self.prepare_features(latents, demographics)
        predictions = self.model.predict(X)
        if self.prediction_type == "regression":
            # Spread of the individual trees' predictions as an uncertainty proxy.
            tree_predictions = np.array([tree.predict(X) for tree in self.model.estimators_])
            prediction_std = np.std(tree_predictions, axis=0)
        else:  # classification
            # Low max-probability == low confidence.
            proba = self.model.predict_proba(X)
            prediction_std = 1 - np.max(proba, axis=1)
        return predictions, prediction_std

    def predict_proba(self, latents, demographics):
        """
        Get probability estimates for classification

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict or pd.DataFrame): Demographic information

        Returns:
            np.ndarray: Probability estimates for each class

        Raises:
            ValueError: If the predictor was built for regression.
        """
        if self.prediction_type != "classification":
            raise ValueError("Probability prediction only available for classification")
        X, _ = self.prepare_features(latents, demographics)
        return self.model.predict_proba(X)

    def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
        """
        Perform cross-validation

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict or pd.DataFrame): Demographic information
            treatment_outcomes (np.ndarray): Treatment outcome values to predict
            n_splits (int): Number of folds for cross-validation

        Returns:
            dict: Cross-validation results (mean_metrics, fold_metrics,
                predictions, prediction_stds, feature_importance)
        """
        X, feature_names = self.prepare_features(latents, demographics)
        self.feature_names = feature_names
        # Accept lists as well as arrays for the outcomes.
        treatment_outcomes = np.asarray(treatment_outcomes)
        sample_count = len(treatment_outcomes)

        # Adjust n_splits if we have too few samples (need >= 2 per fold).
        if sample_count < n_splits * 2:
            adjusted_n_splits = max(2, sample_count // 2)
            logger.warning(f"Too few samples ({sample_count}) for {n_splits} folds. Adjusting to {adjusted_n_splits} folds.")
            print(f"Random Forest: Starting {adjusted_n_splits}-fold cross-validation with {sample_count} samples")
            n_splits = adjusted_n_splits
        else:
            logger.info(f"Running {n_splits}-fold cross-validation on {sample_count} samples")
            print(f"Random Forest: Starting {n_splits}-fold cross-validation with {sample_count} samples")

        # Leave-One-Out for very small datasets, shuffled KFold otherwise.
        if sample_count <= 5:
            from sklearn.model_selection import LeaveOneOut
            logger.warning(f"Using Leave-One-Out CV for small dataset with {sample_count} samples")
            print(f"Random Forest: Using Leave-One-Out cross-validation due to small sample size ({sample_count})")
            kf = LeaveOneOut()
            total_folds = sample_count  # FIX: LOO runs one fold per sample, not n_splits
        else:
            kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
            total_folds = n_splits
        cv_iterator = kf.split(X)

        # FIX: use float storage for regression predictions and all stds.
        # np.zeros_like(treatment_outcomes) inherited the outcome dtype, so
        # integer outcomes silently truncated float values to 0 and string
        # outcomes coerced the stds to strings.
        if self.prediction_type == "regression":
            predictions = np.zeros(sample_count, dtype=float)
        else:
            predictions = np.empty_like(treatment_outcomes)
        prediction_stds = np.zeros(sample_count, dtype=float)
        fold_metrics = []

        for fold, (train_idx, test_idx) in enumerate(cv_iterator):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = treatment_outcomes[train_idx], treatment_outcomes[test_idx]
            print(f"Random Forest: Training fold {fold+1}/{total_folds} - {len(X_train)} training samples, {len(X_test)} test samples")
            # Fresh model per fold so folds do not leak state into each other.
            fold_model = self._build_model(verbose=1)
            fold_model.fit(X_train, y_train)
            pred = fold_model.predict(X_test)
            predictions[test_idx] = pred

            if self.prediction_type == "regression":
                rmse = np.sqrt(mean_squared_error(y_test, pred))
                # R-squared requires at least 2 samples and variance in the target.
                if len(y_test) >= 2 and np.var(y_test) > 1e-10:
                    r2 = r2_score(y_test, pred)
                else:
                    r2 = np.nan
                    logger.warning(f"Fold {fold+1}: R² not calculated (insufficient samples or variance)")
                    print(f"Random Forest: Fold {fold+1} - R² not calculated (insufficient samples or variance)")
                metrics = {
                    "r2": r2,
                    "rmse": rmse,
                    "mse": rmse ** 2,
                }
                if len(y_test) >= 2 and np.var(y_test) > 1e-10:
                    from sklearn.metrics import explained_variance_score
                    try:
                        metrics["explained_variance"] = explained_variance_score(y_test, pred)
                    except Exception:
                        # Skip if it can't be calculated.
                        pass
                # Tree-level spread as per-sample uncertainty.
                tree_predictions = np.array([tree.predict(X_test) for tree in fold_model.estimators_])
                prediction_stds[test_idx] = np.std(tree_predictions, axis=0)
            else:  # classification
                metrics = {
                    "accuracy": accuracy_score(y_test, pred),
                    "precision": precision_score(y_test, pred, average='weighted', zero_division=0),
                    "recall": recall_score(y_test, pred, average='weighted', zero_division=0),
                    "f1": f1_score(y_test, pred, average='weighted', zero_division=0),
                }
                # 1 - max class probability as a per-sample confidence measure.
                proba = fold_model.predict_proba(X_test)
                prediction_stds[test_idx] = 1 - np.max(proba, axis=1)

            fold_metrics.append(metrics)
            logger.info(f"Fold {fold+1} metrics: {metrics}")
            # User-friendly per-fold summary.
            if self.prediction_type == "regression":
                r2_val = metrics.get('r2', np.nan)
                rmse_val = metrics.get('rmse', np.nan)
                r2_text = f"R² = {r2_val:.4f}" if not np.isnan(r2_val) else "R² = N/A"
                print(f"Random Forest: Fold {fold+1} results - {r2_text}, RMSE = {rmse_val:.4f}")
            else:
                print(f"Random Forest: Fold {fold+1} results - Accuracy = {metrics.get('accuracy', 0):.4f}, F1 = {metrics.get('f1', 0):.4f}")

        # Average metrics, ignoring NaN fold values.
        avg_metrics = {}
        for key in fold_metrics[0].keys():
            values = [fold[key] for fold in fold_metrics
                      if key in fold and not (isinstance(fold[key], float) and np.isnan(fold[key]))]
            avg_metrics[key] = np.mean(values) if values else np.nan
        logger.info(f"Average CV metrics: {avg_metrics}")

        if self.prediction_type == "regression":
            r2_avg = avg_metrics.get('r2', np.nan)
            rmse_avg = avg_metrics.get('rmse', np.nan)
            r2_text = f"R² = {r2_avg:.4f}" if not np.isnan(r2_avg) else "R² = N/A"
            print(f"Random Forest: Cross-validation complete - Average {r2_text}, RMSE = {rmse_avg:.4f}")
        else:
            print(f"Random Forest: Cross-validation complete - Average Accuracy = {avg_metrics.get('accuracy', 0):.4f}, F1 = {avg_metrics.get('f1', 0):.4f}")

        # Train final model on all data.
        print(f"Random Forest: Training final model on all {len(X)} samples...")
        self.model.verbose = 1
        self.model.fit(X, treatment_outcomes)
        self.feature_importance = pd.DataFrame({
            'feature': feature_names,
            'importance': self.model.feature_importances_
        }).sort_values('importance', ascending=False)

        return {
            "mean_metrics": avg_metrics,
            "fold_metrics": fold_metrics,
            "predictions": predictions,
            "prediction_stds": prediction_stds,
            "feature_importance": self.feature_importance
        }

    def get_feature_importance(self):
        """
        Get feature importance from the trained model

        Returns:
            pd.DataFrame: Feature importance values

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if self.feature_importance is None:
            raise ValueError("Model must be trained first")
        return self.feature_importance

    def plot_feature_importance(self, top_n=10):
        """
        Plot feature importance

        Args:
            top_n (int): Number of top features to show

        Returns:
            matplotlib.figure.Figure: Feature importance plot

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if self.feature_importance is None:
            raise ValueError("Model must be trained first")
        top_features = self.feature_importance.head(top_n)
        plt.figure(figsize=(10, 6))
        plt.barh(range(len(top_features)), top_features['importance'], align='center')
        plt.yticks(range(len(top_features)), top_features['feature'])
        plt.xlabel('Importance')
        plt.ylabel('Features')
        plt.title('Feature Importance in Treatment Outcome Prediction')
        plt.tight_layout()
        return plt.gcf()

    def save_model(self, path="results/treatment_predictor.joblib"):
        """
        Save the trained model to disk

        Args:
            path (str): Path to save the model
        """
        os.makedirs(os.path.dirname(path), exist_ok=True)
        # Persist the model together with the metadata needed to rebuild the wrapper.
        joblib.dump({
            'model': self.model,
            'feature_names': self.feature_names,
            'feature_importance': self.feature_importance,
            'prediction_type': self.prediction_type,
            'n_estimators': self.n_estimators,
            'max_depth': self.max_depth,
            'random_state': self.random_state
        }, path)
        logger.info(f"Model saved to {path}")

    @classmethod  # FIX: method took `cls` but lacked the decorator, so it was uncallable as intended
    def load_model(cls, path="results/treatment_predictor.joblib"):
        """
        Load a trained model from disk

        Args:
            path (str): Path to load the model from

        Returns:
            AphasiaTreatmentPredictor: Loaded model instance
        """
        data = joblib.load(path)
        predictor = cls(
            prediction_type=data['prediction_type'],
            n_estimators=data['n_estimators'],
            max_depth=data['max_depth'],
            random_state=data['random_state']
        )
        # Restore the fitted estimator and its metadata.
        predictor.model = data['model']
        predictor.feature_names = data['feature_names']
        predictor.feature_importance = data['feature_importance']
        logger.info(f"Model loaded from {path}")
        return predictor
def train_predictor_from_latents(latents, outcomes, demographics=None, prediction_type="regression", cv=5, **kwargs):
    """
    Train a treatment outcome predictor from VAE latent representations

    Args:
        latents (np.ndarray): Latent representations from VAE
        outcomes (np.ndarray): Treatment outcome values
        demographics (dict or pd.DataFrame, optional): Demographic information to include as features
        prediction_type (str): "classification" or "regression"
        cv (int): Number of folds for cross-validation
        **kwargs: Additional parameters for the AphasiaTreatmentPredictor

    Returns:
        dict: Training results and trained model
    """
    logger.info(f"Training {prediction_type} model for treatment prediction")
    # FIX: prepare_features crashes on None demographics (AttributeError on
    # None.copy()); substitute an empty, correctly-indexed DataFrame so that
    # only the latent features are used.
    if demographics is None:
        demographics = pd.DataFrame(index=range(np.asarray(latents).shape[0]))
    predictor = AphasiaTreatmentPredictor(prediction_type=prediction_type, **kwargs)
    cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)
    # Persist the final model trained on all data inside cross_validate.
    predictor.save_model()
    return {
        "predictor": predictor,
        "cv_results": cv_results,
        "feature_importance": predictor.get_feature_importance()
    }