Spaces:
Paused
Paused
| # models/train.py | |
| import time | |
| import streamlit as st | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.model_selection import train_test_split, GridSearchCV | |
| from sklearn.preprocessing import StandardScaler, LabelEncoder | |
| from sklearn.pipeline import Pipeline | |
| from imblearn.over_sampling import SMOTE | |
| from sklearn.utils import resample | |
| from sklearn.metrics import ( | |
| mean_squared_error, r2_score, | |
| accuracy_score, classification_report, confusion_matrix | |
| ) | |
| import os | |
| import sys | |
| import pickle | |
| import io | |
| import h2o | |
| from flaml import AutoML | |
| from typing import Dict, Any, Optional | |
| # Importaciones | |
| from utils.model_utils import ( | |
| ModelTrainer, # Importar la clase | |
| get_model_options, | |
| train_model_pipeline, | |
| process_classification_data, | |
| create_class_distribution_plot | |
| ) | |
| from utils.gemini_explainer import initialize_gemini_explainer | |
| from utils.gemini_explainer import generate_model_explanation | |
| from utils.shap_explainer import create_shap_analysis_dashboard | |
def safe_init_h2o(url=None, **kwargs):
    """
    Attach to an H2O cluster, starting one only when none is attached yet.

    Args:
        url (str, optional): H2O cluster URL, forwarded to ``h2o.init``.
            Defaults to None.
        **kwargs: Additional arguments forwarded to ``h2o.init``.

    Returns:
        The object returned by ``h2o.connection()`` for the active session,
        regardless of whether the cluster was already attached or had to be
        initialised by this call.
    """
    current = h2o.connection()

    # Reuse the existing session when a cluster is already attached.
    if current and current.cluster:
        print("H2O is already running at", current.base_url)
        return current

    print("Starting new H2O instance...")
    h2o.init(url=url, **kwargs)
    # NOTE(review): h2o.init() is understood to configure the global
    # connection rather than return it, so the connection is fetched
    # explicitly to keep both branches returning the same kind of object.
    return h2o.connection()
def convert_h2o_to_pandas(h2o_df):
    """
    Convert an H2OFrame into a pandas DataFrame, asking for the
    multi-threaded conversion.

    Args:
        h2o_df (h2o.H2OFrame): H2O frame to convert.

    Returns:
        pd.DataFrame: Whatever ``h2o_df.as_data_frame`` produces.
    """
    conversion_options = {"use_multi_thread": True}
    return h2o_df.as_data_frame(**conversion_options)
# Resolve the project root (two directories above this file) and give it
# top priority on the import path.
project_root = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir, os.pardir)
)
sys.path.insert(0, project_root)
def validate_data_preparation(train):
    """
    Check that a training dataset is available.

    Args:
        train (pd.DataFrame): Training data; may be None.

    Returns:
        bool: True when ``train`` is a non-empty DataFrame. Otherwise a
        Streamlit warning is shown and False is returned.
    """
    has_rows = train is not None and not train.empty
    if has_rows:
        return True

    st.warning("鈿狅笍 No hay datos preparados en la sesi贸n.")
    return False
def select_features_and_target(train):
    """
    Let the user pick the predictor columns and the target column.

    The choices are remembered in ``st.session_state.feature_cols`` and
    ``st.session_state.target_col`` so they survive Streamlit reruns.

    Args:
        train (pd.DataFrame): Training data.

    Returns:
        tuple: ``(X, y)`` where ``X`` is ``train`` restricted to the chosen
        predictors and ``y`` is the chosen target column, or ``(None, None)``
        when the selection is incomplete.
    """
    # Only int64/float64 columns are offered as predictors.
    numeric_cols = train.select_dtypes(include=['int64', 'float64']).columns.tolist()

    if 'feature_cols' not in st.session_state:
        st.session_state.feature_cols = []

    # A remembered selection may refer to columns that no longer exist (or are
    # no longer numeric) in the current data; st.multiselect rejects defaults
    # that are not among its options, so drop those first.
    valid_defaults = [
        col for col in st.session_state.feature_cols if col in numeric_cols
    ]
    feature_cols = st.multiselect(
        "Selecciona las variables predictoras (X):",
        numeric_cols,
        default=valid_defaults
    )
    st.session_state.feature_cols = feature_cols

    # Any column that is not already a predictor may be the target.
    available_targets = [col for col in train.columns.tolist() if col not in feature_cols]
    if not available_targets:
        st.warning("Por favor, deselecciona algunas variables predictoras para poder seleccionar la variable objetivo.")
        return None, None

    # Fall back to the first candidate when nothing valid is remembered.
    if ('target_col' not in st.session_state or
            st.session_state.target_col not in available_targets):
        st.session_state.target_col = available_targets[0]

    target_col = st.selectbox(
        "Selecciona la variable objetivo (y):",
        available_targets,
        index=available_targets.index(st.session_state.target_col)
    )
    st.session_state.target_col = target_col

    # Compare the target against None so a falsy but valid column name
    # (e.g. the integer 0) is not mistaken for "nothing selected".
    if not feature_cols or target_col is None:
        st.warning("Por favor selecciona variables predictoras y objetivo.")
        return None, None

    return train[feature_cols], train[target_col]
def determine_problem_type(y):
    """
    Decide whether the target describes a classification or a regression task.

    The target is treated as categorical when its dtype is label-like
    (object, pandas string, category or bool) or when it is numeric with at
    most 10 distinct values. Everything else is treated as regression. The
    decision is also written to the Streamlit page.

    Args:
        y (pd.Series): Target variable.

    Returns:
        str: ``'classification'`` or ``'regression'``.
    """
    dtype = y.dtype
    dtype_checks = pd.api.types

    # Dtypes that can only hold labels, never a continuous quantity.
    is_label_dtype = (
        dtype_checks.is_object_dtype(dtype)
        or dtype_checks.is_string_dtype(dtype)
        or isinstance(dtype, pd.CategoricalDtype)
        or dtype_checks.is_bool_dtype(dtype)
    )
    # Numeric columns (signed/unsigned/nullable ints and floats) with few
    # distinct values are assumed to encode classes.
    is_low_cardinality_number = (
        dtype_checks.is_numeric_dtype(dtype) and y.nunique() <= 10
    )

    is_categorical = is_label_dtype or is_low_cardinality_number
    problem_type = 'classification' if is_categorical else 'regression'
    st.write(f"Tipo de problema identificado: **{problem_type}**")
    return problem_type
def _resample_to_class_size(X, y, target_size, random_state):
    """Return (X, y) with every class resampled to exactly ``target_size`` rows."""
    labels = np.asarray(y)
    rng = np.random.RandomState(random_state)

    chosen = []
    for label in pd.unique(labels):
        positions = np.flatnonzero(labels == label)
        if len(positions) >= target_size:
            # Shrinking a class: draw distinct rows.
            picked = rng.choice(positions, size=target_size, replace=False)
        else:
            # Growing a class: keep every original row and top up with
            # randomly repeated ones.
            extra = rng.choice(
                positions, size=target_size - len(positions), replace=True
            )
            picked = np.concatenate([positions, extra])
        chosen.append(picked)

    # Shuffle so the classes are not laid out in contiguous runs.
    order = rng.permutation(np.concatenate(chosen))
    return X.iloc[order], y.iloc[order]


def handle_data_balancing(X, y, random_state):
    """
    Offer class-balancing techniques when the target is imbalanced.

    Imbalance is flagged when the rarest class has fewer than half as many
    rows as the most frequent one. The user then chooses between no
    balancing, random under-sampling (every class reduced to the size of the
    rarest), random over-sampling (every class grown to the size of the most
    frequent) and SMOTE.

    Args:
        X (pd.DataFrame): Predictor variables.
        y (pd.Series): Target variable.
        random_state (int): Seed used by all three balancing techniques so
            repeated runs produce the same sample.

    Returns:
        tuple: ``(X, y)`` balanced as requested, or the inputs unchanged when
        no imbalance is detected or the user selects no balancing.
    """
    class_counts = y.value_counts()
    imbalance_ratio = class_counts.min() / class_counts.max()
    # The comparison is also False for NaN (empty target), in which case the
    # data is returned untouched.
    if not (imbalance_ratio < 0.5):
        return X, y

    st.write("鈿狅笍 Se detect贸 desbalanceo en las clases")
    balance_method = st.selectbox(
        "T茅cnica de balanceo:",
        ["Ninguno", "Submuestreo", "Sobremuestreo", "SMOTE"]
    )
    if balance_method == "Ninguno":
        return X, y

    with st.spinner("Aplicando t茅cnica de balanceo..."):
        if balance_method == "Submuestreo":
            X, y = _resample_to_class_size(
                X, y, int(class_counts.min()), random_state
            )
        elif balance_method == "Sobremuestreo":
            X, y = _resample_to_class_size(
                X, y, int(class_counts.max()), random_state
            )
        else:  # SMOTE
            smote = SMOTE(random_state=random_state)
            X, y = smote.fit_resample(X, y)
    st.success("Balanceo completado!")
    return X, y
# NOTE: this repeats the definition of safe_init_h2o found earlier in this
# file; being the later one, it is the definition in effect after import.
def safe_init_h2o(url=None, **kwargs):
    """
    Attach to an H2O cluster, starting one only when none is attached yet.

    Args:
        url (str, optional): H2O cluster URL, forwarded to ``h2o.init``.
            Defaults to None.
        **kwargs: Additional arguments forwarded to ``h2o.init``.

    Returns:
        The object returned by ``h2o.connection()`` for the active session,
        regardless of whether the cluster was already attached or had to be
        initialised by this call.
    """
    current = h2o.connection()

    # Reuse the existing session when a cluster is already attached.
    if current and current.cluster:
        print("H2O is already running at", current.base_url)
        return current

    print("Starting new H2O instance...")
    h2o.init(url=url, **kwargs)
    # NOTE(review): h2o.init() is understood to configure the global
    # connection rather than return it, so the connection is fetched
    # explicitly to keep both branches returning the same kind of object.
    return h2o.connection()
class AutoMLTrainer:
    """Groups the AutoML training entry points (H2O AutoML and FLAML).

    Both routines are stateless; they are declared as static methods so they
    can be called on the class or on an instance.
    """

    @staticmethod
    def train_h2o_automl(
        X_train: pd.DataFrame,
        y_train: pd.Series,
        X_test: pd.DataFrame,
        y_test: pd.Series,
        problem_type: str,
        time_limit: int = 3600,
        max_models: int = 20
    ) -> Dict[str, Any]:
        """
        Train models with H2O AutoML and evaluate the leader on the test set.

        Args:
            X_train: Training features.
            y_train: Training target.
            X_test: Test features.
            y_test: Test target.
            problem_type: ``'classification'`` or anything else for regression.
            time_limit: Maximum AutoML run time in seconds.
            max_models: Maximum number of models AutoML may build.

        Returns:
            Dict with ``best_model``, ``training_time``, ``leaderboard``,
            ``hyperparameters`` and ``predictions`` (a pandas Series), plus
            ``test_accuracy``/``classification_report`` for classification or
            ``test_rmse``/``test_r2`` for regression. On failure the dict
            holds a single ``error`` key with the exception message.
        """
        try:
            safe_init_h2o()

            is_classification = problem_type == 'classification'
            feature_cols = X_train.columns.tolist()

            # Pick a target column name that cannot overwrite a feature.
            target_col = 'target'
            while target_col in feature_cols:
                target_col += '_'

            # H2O needs string labels for classification and floats for
            # regression.
            target_dtype = str if is_classification else float
            train_target = pd.Series(y_train).astype(target_dtype)
            test_target = pd.Series(y_test).astype(target_dtype)

            # Attach the target by position: the target may carry a different
            # index than the features (e.g. after label encoding), and an
            # index-aligned assignment would then scramble or blank the labels.
            train_df = X_train.copy()
            test_df = X_test.copy()
            train_df[target_col] = train_target.to_numpy()
            test_df[target_col] = test_target.to_numpy()

            train = h2o.H2OFrame(train_df)
            test = h2o.H2OFrame(test_df)

            # Classification requires the target to be an explicit factor.
            if is_classification:
                train[target_col] = train[target_col].asfactor()
                test[target_col] = test[target_col].asfactor()

            aml = h2o.automl.H2OAutoML(
                max_runtime_secs=time_limit,
                max_models=max_models,
                seed=42,
                sort_metric="AUTO"
            )

            start_time = time.time()
            aml.train(x=feature_cols, y=target_col, training_frame=train)
            training_time = time.time() - start_time

            best_model = aml.leader
            hyperparameters = best_model.params

            # Keep only the 'predict' column for both problem types so the
            # stored predictions always have the same shape.
            preds = best_model.predict(test[feature_cols])
            predictions = preds.as_data_frame(use_pandas=True)['predict']

            results = {
                'best_model': best_model,
                'training_time': training_time,
                'leaderboard': aml.leaderboard.as_data_frame(),
                'hyperparameters': hyperparameters,
                'predictions': predictions
            }

            if is_classification:
                predicted_labels = predictions.astype(str)
                results.update({
                    'test_accuracy': accuracy_score(test_target, predicted_labels),
                    'classification_report': classification_report(
                        test_target,
                        predicted_labels,
                        output_dict=True
                    )
                })
            else:
                results.update({
                    'test_rmse': np.sqrt(mean_squared_error(y_test, predictions)),
                    'test_r2': r2_score(y_test, predictions)
                })
            return results
        except Exception as e:
            # Failures are reported through the result dict so the caller can
            # show them in the UI.
            print(f"Error detallado en H2O AutoML: {str(e)}")
            return {'error': str(e)}

    @staticmethod
    def train_flaml_automl(
        X_train: pd.DataFrame,
        y_train: pd.Series,
        X_test: pd.DataFrame,
        y_test: pd.Series,
        problem_type: str,
        time_limit: int = 3600,
        metric: Optional[str] = None
    ) -> Dict[str, Any]:
        """
        Train models with FLAML AutoML and evaluate the best one on the test set.

        Args:
            X_train: Training features.
            y_train: Training target.
            X_test: Test features.
            y_test: Test target.
            problem_type: ``'classification'`` or anything else for regression.
            time_limit: Time budget in seconds.
            metric: Optimisation metric; defaults to ``'accuracy'`` for
                classification and ``'r2'`` for regression.

        Returns:
            Dict with ``best_model``, ``best_config``, ``hyperparameters``
            (same value as ``best_config``, under the key the H2O results
            use), ``training_time``, ``best_estimator`` and ``predictions``,
            plus ``test_accuracy``/``classification_report`` for
            classification or ``test_rmse``/``test_r2`` for regression. On
            failure the dict holds a single ``error`` key.
        """
        try:
            is_classification = problem_type == 'classification'
            task = 'classification' if is_classification else 'regression'
            metric = metric or ('accuracy' if is_classification else 'r2')

            automl = AutoML()

            start_time = time.time()
            automl.fit(
                X_train=X_train,
                y_train=y_train,
                task=task,
                time_budget=time_limit,
                metric=metric,
                verbose=1
            )
            training_time = time.time() - start_time

            predictions = automl.predict(X_test)

            results = {
                'best_model': automl.model,
                'best_config': automl.best_config,
                'hyperparameters': automl.best_config,
                'training_time': training_time,
                'best_estimator': automl.best_estimator,
                'predictions': predictions
            }

            if is_classification:
                results.update({
                    'test_accuracy': accuracy_score(y_test, predictions),
                    'classification_report': classification_report(
                        y_test, predictions, output_dict=True
                    )
                })
            else:
                results.update({
                    'test_rmse': np.sqrt(mean_squared_error(y_test, predictions)),
                    'test_r2': r2_score(y_test, predictions)
                })
            return results
        except Exception as e:
            # Failures are reported through the result dict so the caller can
            # show them in the UI.
            return {'error': str(e)}
def descargar_modelo_h2o(modelo_h2o, nombre_modelo):
    """
    Save an H2O model to a temporary location and return its bytes.

    Args:
        modelo_h2o: H2O model object.
        nombre_modelo (str): Model name (currently not used to name the file).

    Returns:
        bytes: Content of the saved model file, or None when saving or
        reading fails (an error is shown in the Streamlit page).
    """
    import tempfile

    try:
        # A private temporary directory is removed automatically, even when
        # reading the file fails, and avoids a hard-coded "/tmp".
        # NOTE(review): assumes the H2O cluster writes to this machine's
        # filesystem (local cluster) - confirm for remote clusters.
        with tempfile.TemporaryDirectory() as tmp_dir:
            modelo_path = h2o.save_model(model=modelo_h2o, path=tmp_dir, force=True)
            with open(modelo_path, "rb") as file:
                return file.read()
    except Exception as e:
        st.error(f"Error al preparar el modelo para descarga: {str(e)}")
        return None
def show_automl_section(X: pd.DataFrame, y: pd.Series, problem_type: str):
    """
    Render the AutoML section: configuration, training and stored results.

    Results are kept in ``st.session_state.automl_models`` keyed by framework
    name, together with the problem type they were trained for, so they can
    still be displayed correctly after the user changes the target.

    Args:
        X: Predictor variables.
        y: Target variable.
        problem_type: ``'classification'`` or ``'regression'`` for the data
            currently selected.
    """
    st.header("馃 B煤squeda Autom谩tica del Mejor Modelo")

    # AutoML parameters
    col1, col2 = st.columns(2)
    with col1:
        time_limit = st.number_input(
            "L铆mite de tiempo (segundos)",
            min_value=60,
            max_value=7200,
            value=3600,
            step=300,
            key="automl_time_limit"
        )
    with col2:
        framework = st.selectbox(
            "Framework AutoML",
            ["H2O AutoML", "FLAML"],
            key="automl_framework"
        )

    if 'automl_models' not in st.session_state:
        st.session_state.automl_models = {}

    train_button = st.button(
        "Entrenar Modelos Autom谩ticamente",
        key="train_automl_button",
        use_container_width=True
    )

    if train_button:
        try:
            # Fixed 80/20 split, stratified for classification.
            X_train, X_test, y_train, y_test = train_test_split(
                X, y, test_size=0.2, random_state=42,
                stratify=y if problem_type == 'classification' else None
            )
            with st.spinner("Entrenando modelos autom谩ticamente..."):
                if framework == "H2O AutoML":
                    results = AutoMLTrainer.train_h2o_automl(
                        X_train, y_train, X_test, y_test,
                        problem_type, time_limit
                    )
                else:  # FLAML
                    results = AutoMLTrainer.train_flaml_automl(
                        X_train, y_train, X_test, y_test,
                        problem_type, time_limit
                    )
            # Remember which kind of problem these metrics belong to.
            results['problem_type'] = problem_type
            st.session_state.automl_models[framework] = results
        except Exception as e:
            st.error(f"Error en entrenamiento AutoML: {str(e)}")

    # Show every stored result, not only the one just trained.
    for trained_framework, results in st.session_state.automl_models.items():
        st.subheader(f"Resultados de {trained_framework}")
        if 'error' in results:
            st.error(f"Error: {results['error']}")
            continue

        # Use the problem type the model was trained for; for results stored
        # without it, infer the type from the metrics that are present.
        inferred_type = 'classification' if 'test_accuracy' in results else 'regression'
        result_type = results.get('problem_type', inferred_type)
        is_classification = result_type == 'classification'

        # Headline metrics
        cols = st.columns(3)
        with cols[0]:
            st.metric(
                "Tiempo de Entrenamiento",
                f"{results['training_time']:.2f}s"
            )
        with cols[1]:
            if is_classification:
                st.metric("Accuracy", f"{results['test_accuracy']:.4f}")
            else:
                st.metric("R虏 Score", f"{results['test_r2']:.4f}")
        with cols[2]:
            if is_classification:
                st.metric(
                    "F1 Score",
                    f"{results['classification_report']['macro avg']['f1-score']:.4f}"
                )
            else:
                st.metric("RMSE", f"{results['test_rmse']:.4f}")

        # Model explanation
        if st.button("Generar Explicaci贸n", key=f"{trained_framework}_explain"):
            if st.session_state.get('gemini_api_key'):
                with st.spinner("Generando explicaci贸n..."):
                    explainer = initialize_gemini_explainer()
                    if explainer:
                        model_info = {
                            'name': trained_framework,
                            'problem_type': result_type,
                            'hyperparameters': results.get('hyperparameters', 'N/A'),
                            'performance_metric': results.get('test_accuracy', results.get('test_r2', 'N/A')),
                            'training_time': results.get('training_time', 'N/A')
                        }
                        explanation = explainer.generate_model_explanation(model_info)
                        st.markdown(explanation)
                    else:
                        st.error("No se pudo inicializar el explicador de Gemini")
            else:
                st.warning("Configura tu API key de Gemini para generar explicaciones")

        # SHAP analysis
        if st.button("Mostrar An谩lisis SHAP", key=f"{trained_framework}_shap"):
            create_shap_analysis_dashboard(
                results['best_model'],
                X,
                result_type
            )

        # Model download
        if st.button("Descargar Modelo", key=f"{trained_framework}_download"):
            file_stem = f"{trained_framework.lower().replace(' ', '_')}_{int(time.time())}"
            if trained_framework == "H2O AutoML":
                modelo_data = descargar_modelo_h2o(results['best_model'], trained_framework)
                file_name = f"{file_stem}.zip"
                mime = "application/zip"
            else:
                # Only H2O models can go through h2o.save_model; other
                # frameworks' models are serialised with pickle.
                try:
                    modelo_data = pickle.dumps(results['best_model'])
                except Exception as e:
                    st.error(f"Error al preparar el modelo para descarga: {str(e)}")
                    modelo_data = None
                file_name = f"{file_stem}.pkl"
                mime = "application/octet-stream"

            if modelo_data:
                st.download_button(
                    label=f"Descargar {trained_framework}",
                    data=modelo_data,
                    file_name=file_name,
                    mime=mime,
                    key=f"{trained_framework}_download_button"
                )
def show_train():
    """
    Render the model-development page.

    Reads the prepared dataset from ``st.session_state.prepared_data``, lets
    the user choose features, target and training settings, shows the AutoML
    section, and trains (or reuses) each selected model, displaying one
    column of results per model. Trained models are cached in
    ``st.session_state.trained_models`` by model name and are retrained only
    on the run in which the retrain button is pressed.
    """
    st.title("Desarrollo de Modelos")

    # Prepared data must exist before anything can be trained.
    if 'prepared_data' not in st.session_state:
        st.warning("鈿狅笍 No hay datos preparados en la sesi贸n. Por favor, carga y prepara los datos primero.")
        return
    if st.session_state.prepared_data is None:
        st.warning("鈿狅笍 Los datos preparados est谩n vac铆os. Por favor, verifica la preparaci贸n de datos.")
        return

    if 'trained_models' not in st.session_state:
        st.session_state.trained_models = {}

    train = st.session_state.prepared_data

    try:
        X, y = select_features_and_target(train)
        if X is None or y is None:
            return

        # Missing values must be handled on the preparation page.
        if X.isnull().sum().sum() > 0 or y.isnull().sum() > 0:
            st.error("Hay valores nulos en los datos. Por favor, vuelve a la p谩gina de preparaci贸n y maneja los valores faltantes.")
            return

        problem_type = determine_problem_type(y)

        # Training settings
        col1, col2, col3 = st.columns(3)
        with col1:
            test_size = st.slider("Tama帽o del conjunto de prueba:", 0.1, 0.5, 0.2)
        with col2:
            random_state = st.number_input("Random State:", min_value=0, value=42)
        with col3:
            n_folds = st.number_input("N煤mero de folds para validaci贸n cruzada:", min_value=2, max_value=10, value=5)
        st.session_state.n_folds = n_folds

        # Classification-only preprocessing
        if problem_type == 'classification':
            y_original = y
            le = LabelEncoder()
            # Keep the original index so the encoded target stays aligned with
            # X for any downstream code that aligns on index labels.
            y = pd.Series(le.fit_transform(y), index=y_original.index)
            st.session_state.label_encoder = le
            st.write("Mapeo de clases:", dict(enumerate(le.classes_)))

            fig = create_class_distribution_plot(y_original)
            st.plotly_chart(fig)

            X, y = handle_data_balancing(X, y, random_state)

        show_automl_section(X, y, problem_type)

        model_options = get_model_options(problem_type)

        if 'selected_models' not in st.session_state:
            st.session_state.selected_models = []
        selected_models = st.multiselect(
            "Selecciona los modelos a entrenar:",
            list(model_options.keys()),
            default=st.session_state.selected_models
        )
        st.session_state.selected_models = selected_models

        if not selected_models:
            st.warning("Por favor selecciona al menos un modelo para entrenar.")
            return

        # The button is True only on the run in which it was clicked. Storing
        # that value (instead of latching True) stops every later rerun from
        # retraining all models again.
        retrain_requested = st.button("Reentrenar Modelos")
        st.session_state.retrain_models = retrain_requested

        X_train, X_test, y_train, y_test = train_test_split(
            X, y, test_size=test_size, random_state=random_state,
            stratify=y if problem_type == 'classification' else None
        )

        # One results column per selected model
        cols = st.columns(len(selected_models))
        for model_column, model_name in zip(cols, selected_models):
            with model_column:
                st.write(f"### {model_name}")

                needs_training = (
                    retrain_requested
                    or model_name not in st.session_state.trained_models
                )
                if needs_training:
                    trained_model = train_model_pipeline(
                        X_train=X_train,
                        y_train=y_train,
                        model_config=model_options[model_name],
                        X_test=X_test,
                        y_test=y_test,
                        cv=st.session_state.n_folds,
                        scoring=None,
                        random_state=random_state,
                        n_jobs=-1,  # use every available core
                        verbose=1
                    )
                    st.session_state.trained_models[model_name] = trained_model
                else:
                    # Reuse the cached result for this model name.
                    trained_model = st.session_state.trained_models[model_name]

                show_model_results(
                    model_name,
                    problem_type,
                    y_test,
                    model_column,
                    trained_model
                )
    except Exception as e:
        st.error(f"Error inesperado: {str(e)}")
def show_model_results(model_name, problem_type, y_test, col, trained_model):
    """
    Render the detailed results of one trained model inside ``col``.

    Shows the training summary and metrics, an optional Gemini explanation of
    the parameters, a SHAP analysis button and a download button for the
    pickled model.

    Args:
        model_name (str): Model name; also used to build widget keys.
        problem_type (str): ``'classification'`` or ``'regression'``.
        y_test (pd.Series): Test target (currently not used here).
        col (streamlit.delta_generator.DeltaGenerator): Streamlit column to
            render into.
        trained_model (dict): Training results; used when the model is not
            stored in ``st.session_state.trained_models``.
    """
    with col:
        # Prefer the stored results; fall back to the ones passed in so the
        # function never continues without results to show.
        stored_models = st.session_state.get('trained_models', {})
        results = stored_models.get(model_name, trained_model)
        if not results:
            return

        if 'error' in results:
            st.error(results['error'])
            return

        # Performance summary
        if 'training_time' in results:
            st.success(f"隆Entrenamiento completado en {results['training_time']:.2f} segundos!")
        st.write("Mejores par谩metros:", results.get('best_params', 'N/A'))

        if problem_type == 'classification':
            st.write("Accuracy:", results.get('test_accuracy', 'N/A'))
            st.text("Reporte de clasificaci贸n:")
            st.text(pd.DataFrame(results.get('classification_report', {})).transpose().to_string())
        else:
            st.write("R虏 Score:", results.get('test_r2', 'N/A'))
            st.write("RMSE:", results.get('test_rmse', 'N/A'))

        # Parameter explanation with Gemini
        st.write("---")
        st.write("### Explicaci贸n de Par谩metros")

        has_api_key = 'gemini_api_key' in st.session_state and st.session_state.gemini_api_key
        if not has_api_key:
            st.warning("Configure su API key de Gemini en la secci贸n superior izquierda para usar la explicaci贸n autom谩tica de los par谩metros.")

        # Kept as in the original flow: an explainer is cached in the session
        # the first time this section is rendered.
        if 'explainer' not in st.session_state:
            st.session_state.explainer = initialize_gemini_explainer()

        if 'model_explanations' not in st.session_state:
            st.session_state.model_explanations = {}

        explain_button = st.button(
            "Explicar Par谩metros",
            disabled=not has_api_key,
            key=f"explain_{model_name}"
        )

        # Generate first and display afterwards, so a fresh explanation
        # replaces the stored one instead of being rendered next to it.
        if explain_button and has_api_key:
            explainer = initialize_gemini_explainer()
            if explainer:
                try:
                    with st.spinner("Generando explicaci贸n..."):
                        model_info = {
                            'name': model_name,
                            'problem_type': problem_type,
                            # These results expose their parameters under
                            # 'best_params' (shown above); fall back to it.
                            'hyperparameters': results.get(
                                'hyperparameters', results.get('best_params', 'N/A')
                            ),
                            'performance_metric': results.get('test_accuracy', results.get('test_r2', 'N/A')),
                            'training_time': results.get('training_time', 'N/A')
                        }
                        explanation = explainer.generate_model_explanation(model_info)
                        st.session_state.model_explanations[model_name] = explanation
                except Exception as e:
                    st.error(f"Error al generar la explicaci贸n: {str(e)}")
            else:
                st.error("No se pudo inicializar el explicador de Gemini")

        if model_name in st.session_state.model_explanations:
            st.markdown(st.session_state.model_explanations[model_name])

        # SHAP analysis
        st.write("---")
        st.write("### An谩lisis SHAP")
        if st.button("Mostrar An谩lisis SHAP", key=f"shap_button_{model_name}"):
            try:
                X = st.session_state.prepared_data[st.session_state.feature_cols]
                create_shap_analysis_dashboard(
                    results['best_model'],
                    X,
                    problem_type
                )
            except Exception as e:
                st.error(f"Error en el an谩lisis SHAP: {str(e)}")

        # Model download
        st.write("---")
        st.write("### Descarga del modelo")

        # The default file name is generated once per model and then kept, so
        # it does not change on every rerun.
        model_file_key = f"model_file_{model_name}"
        if model_file_key not in st.session_state:
            st.session_state[model_file_key] = f"{model_name.lower().replace(' ', '_')}_{int(time.time())}.pkl"

        model_name_input = st.text_input(
            "Nombre del archivo:",
            value=st.session_state[model_file_key],
            key=f"name_input_{model_name}"
        )

        # A model that is missing or cannot be pickled should only disable the
        # download, not abort the rest of the page.
        try:
            model_bytes = pickle.dumps(results['best_model'])
        except Exception as e:
            st.error(f"Error al preparar el modelo para descarga: {str(e)}")
            return

        st.download_button(
            label="Descargar Modelo",
            data=model_bytes,
            # Fall back to the generated name when the input is left empty.
            file_name=model_name_input or st.session_state[model_file_key],
            mime="application/octet-stream",
            key=f"download_{model_name}"
        )
| # # Bot贸n de descarga | |
| # modelo_data = descargar_modelo_h2o(results['best_model'], model_name) | |
| # if modelo_data: | |
| # st.download_button( | |
| # label="Descargar Modelo", | |
| # data=modelo_data, | |
| # file_name=model_name_input, | |
| # mime="application/zip", | |
| # key=f"download_{model_name}" | |
| # ) |