machinelearning / models /train.py
JersonRuizAlva
Add application file
97a4bf8
# models/train.py
import time
import streamlit as st
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split, GridSearchCV
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.pipeline import Pipeline
from imblearn.over_sampling import SMOTE
from sklearn.utils import resample
from sklearn.metrics import (
mean_squared_error, r2_score,
accuracy_score, classification_report, confusion_matrix
)
import os
import sys
import pickle
import io
import h2o
from flaml import AutoML
from typing import Dict, Any, Optional
# Importaciones
from utils.model_utils import (
ModelTrainer, # Importar la clase
get_model_options,
train_model_pipeline,
process_classification_data,
create_class_distribution_plot
)
from utils.gemini_explainer import initialize_gemini_explainer
from utils.gemini_explainer import generate_model_explanation
from utils.shap_explainer import create_shap_analysis_dashboard
def safe_init_h2o(url=None, **kwargs):
"""
Safely initialize H2O cluster if not already running.
Args:
url (str, optional): H2O cluster URL. Defaults to None (local instance).
**kwargs: Additional arguments to pass to h2o.init()
Returns:
h2o._backend.H2OConnection: The H2O connection object
"""
# Get current H2O instance if exists
current = h2o.connection()
# Check if H2O is already running
if current and current.cluster:
print("H2O is already running at", current.base_url)
return current
# Initialize new H2O instance
print("Starting new H2O instance...")
return h2o.init(url=url, **kwargs)
def convert_h2o_to_pandas(h2o_df):
"""
Convierte un H2OFrame a pandas DataFrame utilizando m煤ltiples hilos.
Args:
h2o_df (h2o.H2OFrame): Frame de H2O a convertir.
Returns:
pd.DataFrame: DataFrame de pandas.
"""
return h2o_df.as_data_frame(use_multi_thread=True)
# Obtener la ruta del directorio ra铆z del proyecto
project_root = os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..'))
sys.path.insert(0, project_root)
def validate_data_preparation(train):
"""
Validar que los datos est茅n preparados correctamente
Args:
train (pd.DataFrame): Datos de entrenamiento
Returns:
bool: Indica si los datos est谩n listos para entrenamiento
"""
if train is None or train.empty:
st.warning("鈿狅笍 No hay datos preparados en la sesi贸n.")
return False
return True
def select_features_and_target(train):
"""
Permitir al usuario seleccionar caracter铆sticas y variable objetivo
Args:
train (pd.DataFrame): Datos de entrenamiento
Returns:
tuple: Variables predictoras (X) y variable objetivo (y)
"""
numeric_cols = train.select_dtypes(include=['int64', 'float64']).columns.tolist()
# Mantener las selecciones en session_state
if 'feature_cols' not in st.session_state:
st.session_state.feature_cols = []
feature_cols = st.multiselect(
"Selecciona las variables predictoras (X):",
numeric_cols,
default=st.session_state.feature_cols
)
st.session_state.feature_cols = feature_cols
# Obtener TODAS las columnas disponibles para target
all_cols = train.columns.tolist()
available_targets = [col for col in all_cols if col not in feature_cols]
if not available_targets:
st.warning("Por favor, deselecciona algunas variables predictoras para poder seleccionar la variable objetivo.")
return None, None
if ('target_col' not in st.session_state or
st.session_state.target_col not in available_targets):
st.session_state.target_col = available_targets[0]
target_col = st.selectbox(
"Selecciona la variable objetivo (y):",
available_targets,
index=available_targets.index(st.session_state.target_col)
)
st.session_state.target_col = target_col
if not (feature_cols and target_col):
st.warning("Por favor selecciona variables predictoras y objetivo.")
return None, None
X = train[feature_cols]
y = train[target_col]
return X, y
def determine_problem_type(y):
"""
Determinar el tipo de problema de machine learning
Args:
y (pd.Series): Variable objetivo
Returns:
str: Tipo de problema ('classification' o 'regression')
"""
is_categorical = y.dtype == 'object' or (y.dtype.name.startswith(('int', 'float')) and y.nunique() <= 10)
problem_type = 'classification' if is_categorical else 'regression'
st.write(f"Tipo de problema identificado: **{problem_type}**")
return problem_type
def handle_data_balancing(X, y, random_state):
"""
Manejar el desbalanceo de clases
Args:
X (pd.DataFrame): Variables predictoras
y (pd.Series): Variable objetivo
random_state (int): Semilla aleatoria
Returns:
tuple: Variables predictoras y objetivo balanceadas
"""
if y.value_counts().min() / y.value_counts().max() < 0.5:
st.write("鈿狅笍 Se detect贸 desbalanceo en las clases")
balance_method = st.selectbox(
"T茅cnica de balanceo:",
["Ninguno", "Submuestreo", "Sobremuestreo", "SMOTE"]
)
if balance_method != "Ninguno":
with st.spinner("Aplicando t茅cnica de balanceo..."):
if balance_method == "Submuestreo":
min_class_size = y.value_counts().min()
X, y = resample(X, y, n_samples=min_class_size*2, stratify=y)
elif balance_method == "Sobremuestreo":
max_class_size = y.value_counts().max()
X, y = resample(X, y, n_samples=max_class_size*2, stratify=y)
else: # SMOTE
smote = SMOTE(random_state=random_state)
X, y = smote.fit_resample(X, y)
st.success("Balanceo completado!")
return X, y
def safe_init_h2o(url=None, **kwargs):
"""
Safely initialize H2O cluster if not already running.
Args:
url (str, optional): H2O cluster URL. Defaults to None (local instance).
**kwargs: Additional arguments to pass to h2o.init()
Returns:
h2o._backend.H2OConnection: The H2O connection object
"""
# Get current H2O instance if exists
current = h2o.connection()
# Check if H2O is already running
if current and current.cluster:
print("H2O is already running at", current.base_url)
return current
# Initialize new H2O instance
print("Starting new H2O instance...")
return h2o.init(url=url, **kwargs)
class AutoMLTrainer:
"""Clase para gestionar el entrenamiento autom谩tico de modelos"""
@staticmethod
def train_h2o_automl(
X_train: pd.DataFrame,
y_train: pd.Series,
X_test: pd.DataFrame,
y_test: pd.Series,
problem_type: str,
time_limit: int = 3600,
max_models: int = 20
) -> Dict[str, Any]:
"""
Entrenar modelos usando H2O AutoML con manejo correcto de tipos de datos
"""
try:
safe_init_h2o()
# Crear un DataFrame combinado con la variable objetivo
train_df = X_train.copy()
test_df = X_test.copy()
# Manejar la variable objetivo seg煤n el tipo de problema
if problem_type == 'classification':
train_df['target'] = y_train.astype(str)
test_df['target'] = y_test.astype(str)
else:
train_df['target'] = y_train.astype(float)
test_df['target'] = y_test.astype(float)
# Convertir a H2OFrame
train = h2o.H2OFrame(train_df)
test = h2o.H2OFrame(test_df)
# Si es clasificaci贸n, convertir expl铆citamente la columna objetivo a factor
if problem_type == 'classification':
train['target'] = train['target'].asfactor()
test['target'] = test['target'].asfactor()
# Especificar columnas
feature_cols = X_train.columns.tolist()
target_col = 'target'
# Configurar AutoML
aml = h2o.automl.H2OAutoML(
max_runtime_secs=time_limit,
max_models=max_models,
seed=42,
sort_metric="AUTO"
)
# Entrenar
start_time = time.time()
aml.train(x=feature_cols, y=target_col, training_frame=train)
training_time = time.time() - start_time
# Obtener el mejor modelo
best_model = aml.leader
# Obtener hiperpar谩metros correctamente
hyperparameters = best_model.params
# Obtener predicciones
preds = best_model.predict(test[feature_cols])
predictions = preds.as_data_frame(use_pandas=True)
if problem_type == 'classification':
predictions = predictions['predict']
# Preparar resultados
results = {
'best_model': best_model,
'training_time': training_time,
'leaderboard': aml.leaderboard.as_data_frame(),
'hyperparameters': hyperparameters,
'predictions': predictions
}
# M茅tricas seg煤n tipo de problema
if problem_type == 'classification':
results.update({
'test_accuracy': accuracy_score(y_test.astype(str), predictions.astype(str)),
'classification_report': classification_report(
y_test.astype(str),
predictions.astype(str),
output_dict=True
)
})
else:
results.update({
'test_rmse': np.sqrt(mean_squared_error(y_test, predictions)),
'test_r2': r2_score(y_test, predictions)
})
return results
except Exception as e:
print(f"Error detallado en H2O AutoML: {str(e)}")
return {'error': str(e)}
@staticmethod
def train_flaml_automl(
X_train: pd.DataFrame,
y_train: pd.Series,
X_test: pd.DataFrame,
y_test: pd.Series,
problem_type: str,
time_limit: int = 3600,
metric: Optional[str] = None
) -> Dict[str, Any]:
"""
Entrenar modelos usando FLAML AutoML
Args:
X_train: Features de entrenamiento
y_train: Target de entrenamiento
X_test: Features de prueba
y_test: Target de prueba
problem_type: Tipo de problema
time_limit: L铆mite de tiempo en segundos
metric: M茅trica de evaluaci贸n
Returns:
Dict con resultados del entrenamiento
"""
try:
# Configurar AutoML
task = 'classification' if problem_type == 'classification' else 'regression'
metric = metric or ('accuracy' if task == 'classification' else 'r2')
automl = AutoML()
# Entrenar
start_time = time.time()
automl.fit(
X_train=X_train,
y_train=y_train,
task=task,
time_budget=time_limit,
metric=metric,
verbose=1
)
training_time = time.time() - start_time
# Predicciones
predictions = automl.predict(X_test)
# Preparar resultados
results = {
'best_model': automl.model,
'best_config': automl.best_config,
'training_time': training_time,
'best_estimator': automl.best_estimator,
'predictions': predictions
}
# M茅tricas espec铆ficas
if problem_type == 'classification':
results.update({
'test_accuracy': accuracy_score(y_test, predictions),
'classification_report': classification_report(y_test, predictions, output_dict=True)
})
else:
results.update({
'test_rmse': np.sqrt(mean_squared_error(y_test, predictions)),
'test_r2': r2_score(y_test, predictions)
})
return results
except Exception as e:
return {'error': str(e)}
def descargar_modelo_h2o(modelo_h2o, nombre_modelo):
"""
Guarda y prepara el modelo H2O para su descarga.
Args:
modelo_h2o: Objeto del modelo H2O.
nombre_modelo (str): Nombre del modelo para el archivo.
Returns:
bytes: Contenido del archivo del modelo.
"""
try:
# Guardar el modelo en una ruta temporal
modelo_path = h2o.save_model(model=modelo_h2o, path="/tmp", force=True)
# Leer el archivo del modelo
with open(modelo_path, "rb") as file:
modelo_data = file.read()
# Opcional: Eliminar el archivo temporal despu茅s de leerlo
os.remove(modelo_path)
return modelo_data
except Exception as e:
st.error(f"Error al preparar el modelo para descarga: {str(e)}")
return None
def show_automl_section(X: pd.DataFrame, y: pd.Series, problem_type: str):
"""Mostrar secci贸n de AutoML"""
st.header("馃 B煤squeda Autom谩tica del Mejor Modelo")
# Par谩metros de AutoML
col1, col2 = st.columns(2)
with col1:
time_limit = st.number_input(
"L铆mite de tiempo (segundos)",
min_value=60,
max_value=7200,
value=3600,
step=300,
key="automl_time_limit"
)
with col2:
framework = st.selectbox(
"Framework AutoML",
["H2O AutoML", "FLAML"],
key="automl_framework"
)
# Inicializar estado para modelos AutoML
if 'automl_models' not in st.session_state:
st.session_state.automl_models = {}
# Bot贸n de entrenamiento
train_button = st.button(
"Entrenar Modelos Autom谩ticamente",
key="train_automl_button",
use_container_width=True
)
if train_button:
try:
# Divisi贸n de datos
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=0.2, random_state=42,
stratify=y if problem_type == 'classification' else None
)
with st.spinner("Entrenando modelos autom谩ticamente..."):
if framework == "H2O AutoML":
results = AutoMLTrainer.train_h2o_automl(
X_train, y_train, X_test, y_test,
problem_type, time_limit
)
else: # FLAML
results = AutoMLTrainer.train_flaml_automl(
X_train, y_train, X_test, y_test,
problem_type, time_limit
)
# Almacenar resultados
st.session_state.automl_models[framework] = results
except Exception as e:
st.error(f"Error en entrenamiento AutoML: {str(e)}")
# Mostrar resultados si existen
if st.session_state.automl_models:
for framework, results in st.session_state.automl_models.items():
st.subheader(f"Resultados de {framework}")
if 'error' in results:
st.error(f"Error: {results['error']}")
continue
# M茅tricas principales
cols = st.columns(3)
with cols[0]:
st.metric(
"Tiempo de Entrenamiento",
f"{results['training_time']:.2f}s"
)
with cols[1]:
if problem_type == 'classification':
st.metric("Accuracy", f"{results['test_accuracy']:.4f}")
else:
st.metric("R虏 Score", f"{results['test_r2']:.4f}")
with cols[2]:
if problem_type == 'classification':
st.metric(
"F1 Score",
f"{results['classification_report']['macro avg']['f1-score']:.4f}"
)
else:
st.metric("RMSE", f"{results['test_rmse']:.4f}")
# Explicaci贸n del modelo
if st.button("Generar Explicaci贸n", key=f"{framework}_explain"):
if 'gemini_api_key' in st.session_state:
with st.spinner("Generando explicaci贸n..."):
explainer = initialize_gemini_explainer()
model_info = {
'name': framework,
'problem_type': problem_type,
'hyperparameters': results.get('hyperparameters', 'N/A'),
'performance_metric': results.get('test_accuracy', results.get('test_r2', 'N/A')),
'training_time': results.get('training_time', 'N/A')
}
explanation = explainer.generate_model_explanation(model_info)
st.markdown(explanation)
else:
st.warning("Configura tu API key de Gemini para generar explicaciones")
# An谩lisis SHAP
if st.button("Mostrar An谩lisis SHAP", key=f"{framework}_shap"):
create_shap_analysis_dashboard(
results['best_model'],
X,
problem_type
)
# Descarga del modelo
if st.button("Descargar Modelo", key=f"{framework}_download"):
modelo_data = descargar_modelo_h2o(results['best_model'], framework)
if modelo_data:
st.download_button(
label=f"Descargar {framework}",
data=modelo_data,
file_name=f"{framework.lower().replace(' ', '_')}_{int(time.time())}.zip",
mime="application/zip",
key=f"{framework}_download_button"
)
def show_train():
"""
Funci贸n principal para mostrar la interfaz de entrenamiento de modelos
"""
st.title("Desarrollo de Modelos")
# Verificar preparaci贸n de datos
if 'prepared_data' not in st.session_state:
st.warning("鈿狅笍 No hay datos preparados en la sesi贸n. Por favor, carga y prepara los datos primero.")
return
if st.session_state.prepared_data is None:
st.warning("鈿狅笍 Los datos preparados est谩n vac铆os. Por favor, verifica la preparaci贸n de datos.")
return
# Inicializar 'trained_models' si no existe
if 'trained_models' not in st.session_state:
st.session_state.trained_models = {}
train = st.session_state.prepared_data
try:
# Seleccionar caracter铆sticas y objetivo
X, y = select_features_and_target(train)
if X is None or y is None:
return
# Verificar valores nulos
if X.isnull().sum().sum() > 0 or y.isnull().sum() > 0:
st.error("Hay valores nulos en los datos. Por favor, vuelve a la p谩gina de preparaci贸n y maneja los valores faltantes.")
return
# Determinar tipo de problema
problem_type = determine_problem_type(y)
# Configuraciones de entrenamiento
col1, col2, col3 = st.columns(3)
with col1:
test_size = st.slider("Tama帽o del conjunto de prueba:", 0.1, 0.5, 0.2)
with col2:
random_state = st.number_input("Random State:", min_value=0, value=42)
with col3:
n_folds = st.number_input("N煤mero de folds para validaci贸n cruzada:", min_value=2, max_value=10, value=5)
st.session_state.n_folds = n_folds
# Preprocesamiento de datos para clasificaci贸n
if problem_type == 'classification':
y_original = y
le = LabelEncoder()
y = pd.Series(le.fit_transform(y))
st.session_state.label_encoder = le
st.write("Mapeo de clases:", dict(enumerate(le.classes_)))
# Visualizar distribuci贸n de clases
fig = create_class_distribution_plot(y_original)
st.plotly_chart(fig)
# Manejar desbalanceo de clases
X, y = handle_data_balancing(X, y, random_state)
show_automl_section(X, y, problem_type)
# Obtener opciones de modelos
model_options = get_model_options(problem_type)
# Gestionar modelos seleccionados
if 'selected_models' not in st.session_state:
st.session_state.selected_models = []
selected_models = st.multiselect(
"Selecciona los modelos a entrenar:",
list(model_options.keys()),
default=st.session_state.selected_models
)
st.session_state.selected_models = selected_models
if not selected_models:
st.warning("Por favor selecciona al menos un modelo para entrenar.")
return
# Configurar re-entrenamiento
if st.button("Reentrenar Modelos"):
st.session_state.retrain_models = True
else:
# Solo establecer a False si no est谩 ya en sesi贸n
if 'retrain_models' not in st.session_state:
st.session_state.retrain_models = False
# Dividir datos
X_train, X_test, y_train, y_test = train_test_split(
X, y, test_size=test_size, random_state=random_state,
stratify=y if problem_type == 'classification' else None
)
# Crear columnas para mostrar resultados de modelos
cols = st.columns(len(selected_models))
# Entrenar y mostrar resultados de cada modelo
for i, model_name in enumerate(selected_models):
with cols[i]:
st.write(f"### {model_name}")
# Verificar si el modelo ya est谩 entrenado y si no se solicita reentrenamiento
if (model_name not in st.session_state.trained_models) or st.session_state.retrain_models:
# Entrenar modelo
trained_model = train_model_pipeline(
X_train=X_train,
y_train=y_train,
model_config=model_options[model_name],
X_test=X_test,
y_test=y_test,
cv=st.session_state.n_folds,
scoring=None,
random_state=random_state, # Pasar random_state
n_jobs=-1, # Para usar todos los n煤cleos disponibles
verbose=1
)
# Almacenar el modelo entrenado en session_state
if 'trained_models' not in st.session_state:
st.session_state.trained_models = {}
st.session_state.trained_models[model_name] = trained_model
else:
# Reutilizar el modelo ya entrenado
trained_model = st.session_state.trained_models[model_name]
# Mostrar resultados del modelo
show_model_results(
model_name,
problem_type,
y_test,
cols[i],
trained_model
)
except Exception as e:
st.error(f"Error inesperado: {str(e)}")
def show_model_results(model_name, problem_type, y_test, col, trained_model):
"""
Mostrar resultados detallados de un modelo entrenado
Args:
model_name (str): Nombre del modelo
problem_type (str): Tipo de problema
y_test (pd.Series): Datos de prueba
col (streamlit.delta_generator.DeltaGenerator): Columna de Streamlit
trained_model (dict): Resultados del entrenamiento
"""
with col:
# Verificar si el modelo est谩 en la sesi贸n de modelos entrenados
if model_name in st.session_state.trained_models:
results = st.session_state.trained_models[model_name]
# Verificar si hubo un error durante el entrenamiento
if 'error' in results:
st.error(results['error'])
return
# Mostrar m茅tricas de rendimiento
if 'training_time' in results:
st.success(f"隆Entrenamiento completado en {results['training_time']:.2f} segundos!")
st.write("Mejores par谩metros:", results.get('best_params', 'N/A'))
# M茅tricas espec铆ficas seg煤n el tipo de problema
if problem_type == 'classification':
st.write("Accuracy:", results.get('test_accuracy', 'N/A'))
st.text("Reporte de clasificaci贸n:")
st.text(pd.DataFrame(results.get('classification_report', {})).transpose().to_string())
else:
st.write("R虏 Score:", results.get('test_r2', 'N/A'))
st.write("RMSE:", results.get('test_rmse', 'N/A'))
# Secci贸n de explicaci贸n de par谩metros con Gemini
st.write("---")
st.write("### Explicaci贸n de Par谩metros")
# Verificar disponibilidad de API key de Gemini
has_api_key = 'gemini_api_key' in st.session_state and st.session_state.gemini_api_key
if not has_api_key:
st.warning("Configure su API key de Gemini en la secci贸n superior izquierda para usar la explicaci贸n autom谩tica de los par谩metros.")
# Inicializar el explainer si no lo has hecho ya
if 'explainer' not in st.session_state:
st.session_state.explainer = initialize_gemini_explainer()
explainer = st.session_state.explainer
# Inicializar explicaciones en el estado de la sesi贸n
if 'model_explanations' not in st.session_state:
st.session_state.model_explanations = {}
# Bot贸n para generar explicaci贸n
explain_button = st.button(
"Explicar Par谩metros",
disabled=not has_api_key,
key=f"explain_{model_name}"
)
# Mostrar explicaci贸n existente si est谩 disponible
if model_name in st.session_state.model_explanations:
st.markdown(st.session_state.model_explanations[model_name])
# Inicializar el explainer solo cuando se necesite
if 'explain_button' in locals() and explain_button and has_api_key:
explainer = initialize_gemini_explainer()
if explainer: # Verificar que el explainer se inicializ贸 correctamente
try:
with st.spinner("Generando explicaci贸n..."):
model_info = {
'name': model_name,
'problem_type': problem_type,
'hyperparameters': results.get('hyperparameters', 'N/A'),
'performance_metric': results.get('test_accuracy', results.get('test_r2', 'N/A')),
'training_time': results.get('training_time', 'N/A')
}
explanation = explainer.generate_model_explanation(model_info)
# Almacenar explicaci贸n
st.session_state.model_explanations[model_name] = explanation
# Mostrar explicaci贸n
st.markdown(explanation)
except Exception as e:
st.error(f"Error al generar la explicaci贸n: {str(e)}")
else:
st.error("No se pudo inicializar el explicador de Gemini")
# Secci贸n de an谩lisis SHAP
st.write("---")
st.write("### An谩lisis SHAP")
if st.button("Mostrar An谩lisis SHAP", key=f"shap_button_{model_name}"):
try:
# Obtener datos preparados
X = st.session_state.prepared_data[st.session_state.feature_cols]
# Crear dashboard de an谩lisis SHAP
create_shap_analysis_dashboard(
results['best_model'], # Usar el mejor modelo
X,
problem_type
)
except Exception as e:
st.error(f"Error en el an谩lisis SHAP: {str(e)}")
# Secci贸n de descarga del modelo
st.write("---")
st.write("### Descarga del modelo")
# Generar nombre de archivo
model_file_key = f"model_file_{model_name}"
if model_file_key not in st.session_state:
st.session_state[model_file_key] = f"{model_name.lower().replace(' ', '_')}_{int(time.time())}.pkl"
# Input para nombre de archivo
model_name_input = st.text_input(
"Nombre del archivo:",
value=st.session_state[model_file_key],
key=f"name_input_{model_name}"
)
# Bot贸n de descarga
model_buffer = io.BytesIO()
pickle.dump(results['best_model'], model_buffer) # Guardar el mejor modelo
model_buffer.seek(0)
download_key = f"download_{model_name}"
st.download_button(
label="Descargar Modelo",
data=model_buffer,
file_name=model_name_input,
mime="application/octet-stream",
key=download_key
)
# # Bot贸n de descarga
# modelo_data = descargar_modelo_h2o(results['best_model'], model_name)
# if modelo_data:
# st.download_button(
# label="Descargar Modelo",
# data=modelo_data,
# file_name=model_name_input,
# mime="application/zip",
# key=f"download_{model_name}"
# )