# AphasiaPred / rcf_prediction.py
# Uploaded by SreekarB ("Upload 13 files", commit 37a1b01).
import numpy as np
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier
from sklearn.model_selection import cross_val_score, KFold
import pandas as pd
from sklearn.metrics import mean_squared_error, r2_score, accuracy_score, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
import os
import joblib
import logging
# Configure the root logger at INFO with a "time - level - message" layout.
# basicConfig is a no-op when the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
class AphasiaTreatmentPredictor:
    """
    Random Forest model that predicts aphasia treatment outcomes.

    Features are VAE latent vectors concatenated with demographic columns
    (non-numeric columns are one-hot encoded). Wraps a scikit-learn
    RandomForestRegressor or RandomForestClassifier and adds cross-validation,
    per-sample uncertainty estimates, feature-importance reporting and
    joblib persistence.
    """

    # Values accepted for ``prediction_type``.
    _PREDICTION_TYPES = ("classification", "regression")

    def __init__(self, prediction_type="regression", n_estimators=100, max_depth=None, random_state=42):
        """
        Initialize the Treatment Predictor with Random Forest

        Args:
            prediction_type (str): "classification" or "regression" depending on outcome variable type
            n_estimators (int): Number of trees in the forest
            max_depth (int): Maximum depth of trees (None for unlimited)
            random_state (int): Random seed for reproducibility

        Raises:
            ValueError: If ``prediction_type`` is not "classification" or "regression".
                Other values would otherwise silently train a regressor while
                ``predict`` (which branches on the exact string) treats it as a classifier.
        """
        if prediction_type not in self._PREDICTION_TYPES:
            raise ValueError(
                f"prediction_type must be 'classification' or 'regression', got {prediction_type!r}"
            )
        self.prediction_type = prediction_type
        self.n_estimators = n_estimators
        self.max_depth = max_depth
        self.random_state = random_state
        # DataFrame with 'feature'/'importance' columns, populated after training.
        self.feature_importance = None
        # Column names of the training feature matrix, used to align prediction inputs.
        self.feature_names = None
        self.model = self._make_estimator()

    def _make_estimator(self, verbose=0):
        """Build an unfitted forest matching this predictor's type and hyper-parameters."""
        if self.prediction_type == "classification":
            forest_cls = RandomForestClassifier
        else:
            forest_cls = RandomForestRegressor
        return forest_cls(
            n_estimators=self.n_estimators,
            max_depth=self.max_depth,
            random_state=self.random_state,
            verbose=verbose,
        )

    @staticmethod
    def _fit_length(key, values, n_samples):
        """
        Truncate or pad one demographic column (a list) to exactly ``n_samples`` entries.

        Padding uses the column mean when the first value is numeric, the most
        common value otherwise, and zeros for an empty column.
        """
        if len(values) == n_samples:
            return values
        if len(values) > n_samples:
            print(f"  Truncated '{key}' to {n_samples} samples")
            return values[:n_samples]
        if not values:
            print(f"  Filled empty '{key}' with zeros to {n_samples} samples")
            return [0] * n_samples
        if isinstance(values[0], (int, float, np.number)):
            filler = np.mean(values)
        else:
            from collections import Counter
            filler = Counter(values).most_common(1)[0][0]
        print(f"  Padded '{key}' with {filler} to {n_samples} samples")
        return values + [filler] * (n_samples - len(values))

    def _demographics_frame(self, demographics, n_samples):
        """
        Return ``demographics`` as a DataFrame with exactly ``n_samples`` rows.

        ``None`` yields an ``n_samples`` x 0 frame (latent features only).
        Mismatched lengths are truncated or padded via ``_fit_length``.
        """
        if demographics is None:
            return pd.DataFrame(index=range(n_samples))

        if isinstance(demographics, dict):
            columns = {}
            for key, values in demographics.items():
                if len(values) != n_samples:
                    print(f"WARNING: Demographics '{key}' length ({len(values)}) doesn't match latents ({n_samples})")
                # list() drops any Series index so the positional index below cannot misalign rows
                columns[key] = self._fit_length(key, list(values), n_samples)
            return pd.DataFrame(columns, index=range(n_samples))

        demo_df = demographics.copy()
        if len(demo_df) == n_samples:
            return demo_df
        print(f"WARNING: Demographics DataFrame size ({len(demo_df)}) doesn't match latents ({n_samples})")
        if len(demo_df) > n_samples:
            print(f"  Truncated demographics to {n_samples} samples")
            return demo_df.iloc[:n_samples]
        # Pad column-wise with the same mean/mode strategy used for dict input
        # (previously every column was replaced by zeros, destroying categorical data).
        print(f"  Padding demographics DataFrame to {n_samples} samples")
        padded = {col: self._fit_length(col, demo_df[col].tolist(), n_samples) for col in demo_df.columns}
        return pd.DataFrame(padded, index=range(n_samples))

    def prepare_features(self, latents, demographics):
        """
        Combine latent features with demographics

        Args:
            latents (np.ndarray): Latent representations from VAE, one row per sample
            demographics (dict, pd.DataFrame or None): Demographic information; ``None``
                means only latent features are used. Columns whose length differs from the
                number of latent rows are truncated or padded (mean/most common value).

        Returns:
            tuple: ``(features, feature_names)``. ``features`` is a float array whose first
            columns are the latents (named ``latent_i``) followed by the demographic columns;
            non-numeric demographic columns are one-hot encoded (dummy columns appended
            after the numeric ones). If the demographics cannot be combined, only the
            latents are returned.
        """
        latents = np.asarray(latents, dtype=float)
        n_samples = latents.shape[0]
        latent_names = [f'latent_{i}' for i in range(latents.shape[1])]

        demo_df = self._demographics_frame(demographics, n_samples)

        # Anything pandas does not treat as numeric (object, string, category, ...) is
        # one-hot encoded; bool columns count as numeric and are kept as 0/1.
        cat_columns = [col for col in demo_df.columns if not pd.api.types.is_numeric_dtype(demo_df[col])]
        if cat_columns:
            demo_df = pd.get_dummies(demo_df, columns=cat_columns)

        try:
            features = np.hstack([latents, demo_df.to_numpy(dtype=float)])
        except (ValueError, TypeError) as e:
            print(f"ERROR combining features: {e}")
            print(f"Latents shape: {latents.shape}, Demographics shape: {demo_df.shape}")
            print("Falling back to using only latent features")
            return latents, latent_names
        return features, latent_names + demo_df.columns.tolist()

    def _align_to_training(self, X, feature_names):
        """
        Reindex a prediction-time feature matrix to the training column layout.

        One-hot encoding a prediction batch only creates dummy columns for the
        categories present in that batch (e.g. a single patient), so the raw matrix
        can differ in width/order from the training matrix. Missing columns are
        filled with 0 and columns never seen in training are dropped.

        Raises:
            ValueError: If a latent dimension used during training is absent.
        """
        if self.feature_names is None or list(feature_names) == list(self.feature_names):
            return X
        present = set(feature_names)
        trained = set(self.feature_names)
        missing = [name for name in self.feature_names if name not in present]
        missing_latents = [name for name in missing if str(name).startswith('latent_')]
        if missing_latents:
            raise ValueError(f"Latent features used in training are missing at prediction time: {missing_latents}")
        if missing:
            print(f"WARNING: Filling features missing at prediction time with 0: {missing}")
        unseen = [name for name in feature_names if name not in trained]
        if unseen:
            print(f"WARNING: Ignoring features not seen during training: {unseen}")
        frame = pd.DataFrame(X, columns=feature_names)
        return frame.reindex(columns=self.feature_names, fill_value=0).to_numpy(dtype=float)

    def _prediction_uncertainty(self, model, X):
        """
        Per-sample uncertainty for ``model`` on ``X``.

        Regression: standard deviation of the individual trees' predictions.
        Classification: ``1 - max class probability`` (0 = fully confident).
        """
        if self.prediction_type == "regression":
            per_tree = np.array([tree.predict(X) for tree in model.estimators_])
            return np.std(per_tree, axis=0)
        proba = model.predict_proba(X)
        return 1 - np.max(proba, axis=1)

    def _update_feature_importance(self, feature_names):
        """Store the fitted model's feature importances, sorted in descending order."""
        self.feature_importance = pd.DataFrame({
            'feature': feature_names,
            'importance': self.model.feature_importances_,
        }).sort_values('importance', ascending=False)

    def fit(self, latents, demographics, treatment_outcomes):
        """
        Fit the random forest model

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict, pd.DataFrame or None): Demographic information
            treatment_outcomes (np.ndarray): Treatment outcome values to predict

        Returns:
            self: Trained model instance
        """
        X, feature_names = self.prepare_features(latents, demographics)
        self.feature_names = feature_names
        logger.info(f"Training {self.prediction_type} model with {X.shape[0]} samples and {X.shape[1]} features")
        print(f"Random Forest: Building {self.n_estimators} trees...")
        # verbose=1 reports tree-building progress; set 2 for per-tree detail
        self.model.verbose = 1
        self.model.fit(X, treatment_outcomes)
        self._update_feature_importance(feature_names)
        print(f"Random Forest: Training complete. Top features: {', '.join(self.feature_importance['feature'].head(3).tolist())}")
        return self

    def predict(self, latents, demographics):
        """
        Predict treatment outcomes for new patients

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict, pd.DataFrame or None): Demographic information

        Returns:
            tuple: Predictions and an uncertainty per sample (tree std deviation for
            regression, ``1 - max class probability`` for classification)
        """
        X, feature_names = self.prepare_features(latents, demographics)
        X = self._align_to_training(X, feature_names)
        predictions = self.model.predict(X)
        return predictions, self._prediction_uncertainty(self.model, X)

    def predict_proba(self, latents, demographics):
        """
        Get probability estimates for classification

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict, pd.DataFrame or None): Demographic information

        Returns:
            np.ndarray: Probability estimates for each class

        Raises:
            ValueError: If the predictor is a regression model.
        """
        if self.prediction_type != "classification":
            raise ValueError("Probability prediction only available for classification")
        X, feature_names = self.prepare_features(latents, demographics)
        return self.model.predict_proba(self._align_to_training(X, feature_names))

    def _score_fold(self, fold, y_test, pred):
        """Compute the metric dict for one cross-validation fold."""
        if self.prediction_type != "regression":
            return {
                "accuracy": accuracy_score(y_test, pred),
                "precision": precision_score(y_test, pred, average='weighted', zero_division=0),
                "recall": recall_score(y_test, pred, average='weighted', zero_division=0),
                "f1": f1_score(y_test, pred, average='weighted', zero_division=0),
            }

        rmse = np.sqrt(mean_squared_error(y_test, pred))
        # R-squared requires at least 2 samples and some variance in the target
        has_variance = len(y_test) >= 2 and np.var(y_test) > 1e-10
        if has_variance:
            r2 = r2_score(y_test, pred)
        else:
            r2 = np.nan
            logger.warning(f"Fold {fold+1}: R² not calculated (insufficient samples or variance)")
            print(f"Random Forest: Fold {fold+1} - R² not calculated (insufficient samples or variance)")
        metrics = {"r2": r2, "rmse": rmse, "mse": rmse**2}
        if has_variance:
            from sklearn.metrics import explained_variance_score
            try:
                metrics["explained_variance"] = explained_variance_score(y_test, pred)
            except ValueError as exc:
                logger.warning(f"Fold {fold+1}: explained variance not calculated ({exc})")
        return metrics

    def cross_validate(self, latents, demographics, treatment_outcomes, n_splits=5):
        """
        Perform cross-validation, then refit the main model on all samples

        Uses shuffled KFold, shrinking ``n_splits`` when there are fewer than two
        samples per fold, and Leave-One-Out when there are 5 samples or fewer.

        Args:
            latents (np.ndarray): Latent representations from VAE
            demographics (dict, pd.DataFrame or None): Demographic information
            treatment_outcomes (array-like): Treatment outcome values to predict
            n_splits (int): Number of folds for cross-validation

        Returns:
            dict: ``mean_metrics``, ``fold_metrics``, out-of-fold ``predictions``,
            out-of-fold ``prediction_stds`` and ``feature_importance``

        Raises:
            ValueError: If there are fewer than 2 samples or the number of outcomes
                does not match the number of feature rows.
        """
        X, feature_names = self.prepare_features(latents, demographics)
        self.feature_names = feature_names
        # Positional array: lists work, and Series with custom indexes cannot misalign folds
        y = np.asarray(treatment_outcomes)
        sample_count = len(y)
        if sample_count < 2:
            raise ValueError(f"Cross-validation needs at least 2 samples, got {sample_count}")
        if len(X) != sample_count:
            raise ValueError(f"Got {len(X)} feature rows but {sample_count} treatment outcomes")

        if sample_count <= 5:
            from sklearn.model_selection import LeaveOneOut
            logger.warning(f"Using Leave-One-Out CV for small dataset with {sample_count} samples")
            print(f"Random Forest: Using Leave-One-Out cross-validation due to small sample size ({sample_count})")
            splitter = LeaveOneOut()
            n_splits = sample_count  # LOO produces one fold per sample
        else:
            if sample_count < n_splits * 2:  # Need at least 2 samples per fold
                adjusted_n_splits = max(2, sample_count // 2)
                logger.warning(f"Too few samples ({sample_count}) for {n_splits} folds. Adjusting to {adjusted_n_splits} folds.")
                print(f"Random Forest: Starting {adjusted_n_splits}-fold cross-validation with {sample_count} samples")
                n_splits = adjusted_n_splits
            else:
                logger.info(f"Running {n_splits}-fold cross-validation on {sample_count} samples")
                print(f"Random Forest: Starting {n_splits}-fold cross-validation with {sample_count} samples")
            # Plain (non-stratified) shuffled KFold
            splitter = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)

        is_regression = self.prediction_type == "regression"
        # Dedicated float buffers: zeros_like(y) would truncate regression predictions for
        # integer targets and turn classification confidences into ints/strings.
        predictions = np.zeros(y.shape, dtype=float) if is_regression else np.zeros_like(y)
        prediction_stds = np.zeros(y.shape if is_regression else sample_count, dtype=float)
        fold_metrics = []

        for fold, (train_idx, test_idx) in enumerate(splitter.split(X)):
            X_train, X_test = X[train_idx], X[test_idx]
            y_train, y_test = y[train_idx], y[test_idx]
            print(f"Random Forest: Training fold {fold+1}/{n_splits} - {len(X_train)} training samples, {len(X_test)} test samples")

            fold_model = self._make_estimator(verbose=1)
            fold_model.fit(X_train, y_train)
            pred = fold_model.predict(X_test)
            predictions[test_idx] = pred

            metrics = self._score_fold(fold, y_test, pred)
            prediction_stds[test_idx] = self._prediction_uncertainty(fold_model, X_test)
            fold_metrics.append(metrics)
            logger.info(f"Fold {fold+1} metrics: {metrics}")

            if is_regression:
                r2_val = metrics.get('r2', np.nan)
                rmse_val = metrics.get('rmse', np.nan)
                r2_text = f"R² = {r2_val:.4f}" if not np.isnan(r2_val) else "R² = N/A"
                print(f"Random Forest: Fold {fold+1} results - {r2_text}, RMSE = {rmse_val:.4f}")
            else:
                acc_val = metrics.get('accuracy', 0)
                f1_val = metrics.get('f1', 0)
                print(f"Random Forest: Fold {fold+1} results - Accuracy = {acc_val:.4f}, F1 = {f1_val:.4f}")

        # Average every metric reported by any fold (e.g. explained_variance may be
        # missing from some folds), ignoring NaN entries.
        metric_keys = list(dict.fromkeys(key for metrics in fold_metrics for key in metrics))
        avg_metrics = {}
        for key in metric_keys:
            values = [
                metrics[key] for metrics in fold_metrics
                if key in metrics and not (isinstance(metrics[key], float) and np.isnan(metrics[key]))
            ]
            avg_metrics[key] = np.mean(values) if values else np.nan
        logger.info(f"Average CV metrics: {avg_metrics}")

        if is_regression:
            r2_avg = avg_metrics.get('r2', np.nan)
            rmse_avg = avg_metrics.get('rmse', np.nan)
            r2_text = f"R² = {r2_avg:.4f}" if not np.isnan(r2_avg) else "R² = N/A"
            print(f"Random Forest: Cross-validation complete - Average {r2_text}, RMSE = {rmse_avg:.4f}")
        else:
            acc_avg = avg_metrics.get('accuracy', 0)
            f1_avg = avg_metrics.get('f1', 0)
            print(f"Random Forest: Cross-validation complete - Average Accuracy = {acc_avg:.4f}, F1 = {f1_avg:.4f}")

        # Train final model on all data
        print(f"Random Forest: Training final model on all {len(X)} samples...")
        self.model.verbose = 1
        self.model.fit(X, y)
        self._update_feature_importance(feature_names)

        return {
            "mean_metrics": avg_metrics,
            "fold_metrics": fold_metrics,
            "predictions": predictions,
            "prediction_stds": prediction_stds,
            "feature_importance": self.feature_importance,
        }

    def get_feature_importance(self):
        """
        Get feature importance from the trained model

        Returns:
            pd.DataFrame: Feature importance values, sorted descending

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if self.feature_importance is None:
            raise ValueError("Model must be trained first")
        return self.feature_importance

    def plot_feature_importance(self, top_n=10):
        """
        Plot feature importance as a horizontal bar chart (most important on top)

        Args:
            top_n (int): Number of top features to show

        Returns:
            matplotlib.figure.Figure: Feature importance plot

        Raises:
            ValueError: If the model has not been trained yet.
        """
        if self.feature_importance is None:
            raise ValueError("Model must be trained first")
        top_features = self.feature_importance.head(top_n)
        positions = list(range(len(top_features)))

        fig, ax = plt.subplots(figsize=(10, 6))
        ax.barh(positions, top_features['importance'], align='center')
        ax.set_yticks(positions)
        ax.set_yticklabels(top_features['feature'])
        # Rows are sorted descending; invert so the top-ranked feature is drawn first
        ax.invert_yaxis()
        ax.set_xlabel('Importance')
        ax.set_ylabel('Features')
        ax.set_title('Feature Importance in Treatment Outcome Prediction')
        fig.tight_layout()
        return fig

    def save_model(self, path="results/treatment_predictor.joblib"):
        """
        Save the trained model and its metadata to disk with joblib

        Args:
            path (str): Path to save the model
        """
        # dirname is '' for a bare filename, and os.makedirs('') would raise
        directory = os.path.dirname(path)
        if directory:
            os.makedirs(directory, exist_ok=True)
        joblib.dump({
            'model': self.model,
            'feature_names': self.feature_names,
            'feature_importance': self.feature_importance,
            'prediction_type': self.prediction_type,
            'n_estimators': self.n_estimators,
            'max_depth': self.max_depth,
            'random_state': self.random_state
        }, path)
        logger.info(f"Model saved to {path}")

    @classmethod
    def load_model(cls, path="results/treatment_predictor.joblib"):
        """
        Load a trained model saved by ``save_model``

        NOTE: joblib.load unpickles the file, which can execute arbitrary code;
        only load files from trusted sources.

        Args:
            path (str): Path to load the model from

        Returns:
            AphasiaTreatmentPredictor: Loaded model instance
        """
        data = joblib.load(path)
        predictor = cls(
            prediction_type=data['prediction_type'],
            n_estimators=data['n_estimators'],
            max_depth=data['max_depth'],
            random_state=data['random_state']
        )
        # Replace the freshly constructed (unfitted) forest with the saved one
        predictor.model = data['model']
        predictor.feature_names = data['feature_names']
        predictor.feature_importance = data['feature_importance']
        logger.info(f"Model loaded from {path}")
        return predictor
def train_predictor_from_latents(latents, outcomes, demographics=None, prediction_type="regression", cv=5,
                                 save_path="results/treatment_predictor.joblib", **kwargs):
    """
    Train a treatment outcome predictor from VAE latent representations

    Runs cross-validation (which also refits the predictor on all samples) and
    saves the trained predictor.

    Args:
        latents (np.ndarray): Latent representations from VAE, one row per sample
        outcomes (np.ndarray): Treatment outcome values
        demographics (dict or pd.DataFrame, optional): Demographic information to include
            as features; ``None`` trains on the latent features only
        prediction_type (str): "classification" or "regression"
        cv (int): Number of folds for cross-validation
        save_path (str or None): Where to save the trained predictor; ``None`` skips saving
        **kwargs: Additional parameters for the AphasiaTreatmentPredictor

    Returns:
        dict: ``predictor``, ``cv_results`` and ``feature_importance``
    """
    logger.info(f"Training {prediction_type} model for treatment prediction")
    if demographics is None:
        # An n-row, zero-column frame makes feature preparation use the latents only
        # (passing None straight through fails on `.copy()`).
        demographics = pd.DataFrame(index=range(len(latents)))

    predictor = AphasiaTreatmentPredictor(prediction_type=prediction_type, **kwargs)
    cv_results = predictor.cross_validate(latents, demographics, outcomes, n_splits=cv)

    if save_path is not None:
        predictor.save_model(save_path)

    return {
        "predictor": predictor,
        "cv_results": cv_results,
        "feature_importance": predictor.get_feature_importance()
    }