import copy
import os
import time
import warnings
from contextlib import nullcontext
from pathlib import Path
from typing import Callable, List, Optional, Tuple, Union

import numpy as np
import pandas as pd
from pytorch_lightning.tuner.tuning import Tuner
from rich.progress import Progress, track  # noqa: F401  (`track` kept for backward compatibility)

from pytorch_tabular import TabularModel, models
from pytorch_tabular.config import (
    DataConfig,
    ExperimentConfig,
    ModelConfig,
    OptimizerConfig,
    TrainerConfig,
)
from pytorch_tabular.utils import (
    OOMException,
    OutOfMemoryHandler,
    available_models,
    get_logger,
    int_to_human_readable,
    suppress_lightning_logs,
)

logger = get_logger("pytorch_tabular")

# Each preset is a tuple whose entries are either a config class name or a
# ``(config class name, keyword arguments)`` pair.
# NOTE: the "full" and "high_memory" presets are materialised as tuples. A
# generator expression can be consumed only once and cannot be deep-copied,
# which made these two presets unusable.
MODEL_SWEEP_PRESETS = {
    "lite": (
        ("CategoryEmbeddingModelConfig", {"layers": "256-128-64"}),
        ("GANDALFConfig", {"gflu_stages": 6}),
        (
            "TabNetModelConfig",
            {
                "n_d": 32,
                "n_a": 32,
                "n_steps": 3,
                "gamma": 1.5,
                "n_independent": 1,
                "n_shared": 2,
            },
        ),
    ),
    "standard": (
        ("CategoryEmbeddingModelConfig", {"layers": "256-128-64"}),
        ("CategoryEmbeddingModelConfig", {"layers": "512-128-64"}),
        ("GANDALFConfig", {"gflu_stages": 6}),
        ("GANDALFConfig", {"gflu_stages": 15}),
        (
            "TabNetModelConfig",
            {
                "n_d": 32,
                "n_a": 32,
                "n_steps": 3,
                "gamma": 1.5,
                "n_independent": 1,
                "n_shared": 2,
            },
        ),
        (
            "TabNetModelConfig",
            {
                "n_d": 32,
                "n_a": 32,
                "n_steps": 5,
                "gamma": 1.5,
                "n_independent": 2,
                "n_shared": 3,
            },
        ),
        ("FTTransformerConfig", {"num_heads": 4, "num_attn_blocks": 4}),
    ),
    "full": tuple(m for m in available_models() if m not in ["MDNConfig", "NodeConfig"]),
    "high_memory": tuple(m for m in available_models() if m not in ["MDNConfig"]),
}

# Columns of the progress report that receive a colour gradient when they are
# present and numeric.
_PROGRESS_GRADIENT_COLUMNS = (
    "train_loss",
    "validation_loss",
    "test_loss",
    "time_taken",
    "time_taken_per_epoch",
)


def _validate_args(
    task: str,
    train: pd.DataFrame,
    test: pd.DataFrame,
    metrics: Optional[List[Union[str, Callable]]] = None,
    metrics_params: Optional[List[dict]] = None,
    metrics_prob_input: Optional[List[bool]] = None,
    common_model_args: Optional[dict] = None,
    rank_metric: Optional[Tuple[str, str]] = ("loss", "lower_is_better"),
):
    """Validate the arguments of a model sweep.

    Args:
        task (str): Either 'classification' or 'regression'.
        train (pd.DataFrame): The training data.
        test (pd.DataFrame): The test data.
        metrics (Optional[List[Union[str, Callable]]]): Metrics to track. When given,
            ``metrics_params`` and ``metrics_prob_input`` must be lists of the same length.
        metrics_params (Optional[List[dict]]): One dict of parameters per metric.
        metrics_prob_input (Optional[List[bool]]): One flag per metric.
        common_model_args (Optional[dict]): Arguments shared by all models; every key
            must be a field of ``ModelConfig``.
        rank_metric (Optional[Tuple[str, str]]): ``(metric name, direction)`` where the
            direction is 'lower_is_better' or 'higher_is_better'.

    Raises:
        AssertionError: If any of the checks fails.
    """
    assert task in [
        "classification",
        "regression",
    ], f"task must be one of ['classification', 'regression'], but got {task}"
    assert isinstance(train, pd.DataFrame), f"train must be a pandas DataFrame, but got {type(train)}"
    assert isinstance(test, pd.DataFrame), f"test must be a pandas DataFrame, but got {type(test)}"
    if metrics is not None:
        assert isinstance(metrics, list), f"metrics must be a list of strings or callables, but got {type(metrics)}"
        assert all(
            isinstance(m, str) or callable(m) for m in metrics
        ), f"metrics must be a list of strings or callables, but got {metrics}"
        assert metrics_params is not None, "metric_params cannot be None when metrics is not None"
        assert metrics_prob_input is not None, "metrics_prob_inputs cannot be None when metrics is not None"
        assert isinstance(
            metrics_params, list
        ), f"metric_params must be a list of dicts, but got {type(metrics_params)}"
        assert isinstance(metrics_prob_input, list), (
            "metrics_prob_inputs must be a list of bools, but got" f" {type(metrics_prob_input)}"
        )
        assert len(metrics) == len(metrics_params), (
            "metrics and metric_params must be of the same length, but got"
            f" {len(metrics)} and {len(metrics_params)}"
        )
        assert len(metrics) == len(metrics_prob_input), (
            "metrics and metrics_prob_inputs must be of the same length, but got"
            f" {len(metrics)} and {len(metrics_prob_input)}"
        )
        assert all(
            isinstance(m, dict) for m in metrics_params
        ), f"metric_params must be a list of dicts, but got {metrics_params}"
    if common_model_args is not None:
        # all args should be members of ModelConfig
        assert all(k in ModelConfig.__dataclass_fields__ for k in common_model_args), (
            "common_model_args must be a subset of ModelConfig, but got" f" {common_model_args.keys()}"
        )
    if rank_metric[0] not in ["loss", "accuracy", "mean_squared_error"]:
        # `metrics` may legitimately be None; treat it as "no custom metrics" so the
        # check fails with the intended message instead of a TypeError.
        assert rank_metric[0] in (metrics or []), f"rank_metric must be one of {metrics}, but got {rank_metric}"
    assert rank_metric[1] in [
        "lower_is_better",
        "higher_is_better",
    ], (
        "rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but" f" got {rank_metric[1]}"
    )


def _validate_arg_model_list(model_list, task):
    """Validate the ``model_list`` argument of a model sweep.

    Args:
        model_list: Either the name of a preset in ``MODEL_SWEEP_PRESETS`` or a list
            whose items are config class names (str) or ``ModelConfig`` objects.
        task (str): The task of the sweep; every ``ModelConfig`` in the list must
            have the same ``task``.

    Raises:
        AssertionError: If any of the checks fails.
    """
    assert model_list is not None, "models cannot be None"
    assert isinstance(
        model_list, (str, list)
    ), f"models must be a string or list of strings, but got {type(model_list)}"
    if isinstance(model_list, str):
        assert (
            model_list in MODEL_SWEEP_PRESETS.keys()
        ), f"models must be one of {MODEL_SWEEP_PRESETS.keys()}, but got {model_list}"
    else:  # isinstance(models, list):
        assert all(
            isinstance(m, (str, ModelConfig)) for m in model_list
        ), f"models must be a list of strings or ModelConfigs, but got {model_list}"
        assert all(task == m.task for m in model_list if isinstance(m, ModelConfig)), (
            f"task must be the same as the task in ModelConfig, but got {task} and"
            f" {[m.task for m in model_list if isinstance(m, ModelConfig)]}"
        )


def _resolve_model_configs(model_list, task, common_model_args):
    """Turn a preset name or a mixed list into a list of model config objects."""
    if isinstance(model_list, str):
        model_list = copy.deepcopy(MODEL_SWEEP_PRESETS[model_list])
    resolved = []
    for entry in model_list:
        if isinstance(entry, tuple):
            # Preset entry of the form (config class name, keyword arguments)
            config_name, model_args = entry
            resolved.append(getattr(models, config_name)(task=task, **model_args, **common_model_args))
        elif isinstance(entry, str):
            resolved.append(getattr(models, entry)(task=task, **common_model_args))
        else:
            resolved.append(entry)
    return resolved


def _rename_metric_keys(metric_values, split):
    """Return a copy of ``metric_values`` with 'test' replaced by ``split`` in every key."""
    return {key.replace("test", split): value for key, value in metric_values.items()}


def _write_progress_report(results, checkpoints_path):
    """Best-effort dump of the results gathered so far to ``progress.xlsx``."""
    frame = pd.DataFrame(results)
    # Only colour columns that exist and are numeric: OOM rows store the string
    # "OOM" in the timing columns and may lack the train/validation columns.
    gradient_columns = [
        column
        for column in _PROGRESS_GRADIENT_COLUMNS
        if column in frame.columns and pd.api.types.is_numeric_dtype(frame[column])
    ]
    try:
        report = (
            frame.style.background_gradient(subset=gradient_columns, cmap="RdYlGn_r") if gradient_columns else frame
        )
        report.to_excel(f"{checkpoints_path}/progress.xlsx")
    except (PermissionError, ImportError) as err:
        # The report is a convenience only (the file may be open elsewhere, or an
        # optional dependency may be missing); never abort the sweep because of it.
        logger.warning(f"Could not write the progress report: {err}")


def model_sweep_custom(
    task: str,
    train: pd.DataFrame,
    test: pd.DataFrame,
    data_config: Union[DataConfig, str],
    optimizer_config: Union[OptimizerConfig, str],
    trainer_config: Union[TrainerConfig, str],
    model_list: Union[str, List[Union[ModelConfig, str]]] = "lite",
    metrics: Optional[List[Union[str, Callable]]] = None,
    metrics_params: Optional[List[dict]] = None,
    metrics_prob_input: Optional[List[bool]] = None,
    validation: Optional[pd.DataFrame] = None,
    experiment_config: Optional[Union[ExperimentConfig, str]] = None,
    common_model_args: Optional[dict] = None,
    rank_metric: Optional[Tuple[str, str]] = ("loss", "lower_is_better"),
    return_best_model: bool = True,
    seed: int = 42,
    ignore_oom: bool = True,
    progress_bar: bool = True,
    verbose: bool = True,
    suppress_lightning_logger: bool = True,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: Optional[float] = 4.0,
    **kwargs,
):
    """Compare multiple models on the same dataset.

    Args:
        task (str): The type of prediction task. Either 'classification' or 'regression'

        train (pd.DataFrame): The training data

        test (pd.DataFrame): The test data on which performance is evaluated

        data_config (Union[DataConfig, str]): DataConfig object or path to the yaml file.

        optimizer_config (Union[OptimizerConfig, str]): OptimizerConfig object or path to the yaml file.

        trainer_config (Union[TrainerConfig, str]): TrainerConfig object or path to the yaml file.
            When a config object is passed and ``progress_bar`` is True, its ``progress_bar``
            attribute is set to "none" in place.

        model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare.
            This can be one of the presets defined in ``MODEL_SWEEP_PRESETS`` or a list of
            ``ModelConfig`` objects / config class names. Defaults to "lite".

        metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics
            should be one of the functional metrics implemented in ``torchmetrics``. By default, it is
            accuracy if classification and mean_squared_error for regression

        metrics_prob_input (Optional[bool]): Is a mandatory parameter for classification metrics defined in
            the config. This defines whether the input to the metric function is the probability or the
            class. Length should be same as the number of metrics. Defaults to None.

        metrics_params (Optional[List]): The parameters to be passed to the metrics function. `task` is
            forced to be `multiclass` because the multiclass version can handle binary as well and for
            simplicity we are only using `multiclass`.

        validation (Optional[DataFrame], optional): If provided, will use this dataframe as the validation
            while training. Used in Early Stopping and Logging. If left empty, will use 20% of Train data
            as validation. Validation metrics are added to the results only when this dataframe is
            provided. Defaults to None.

        experiment_config (Optional[Union[ExperimentConfig, str]], optional): ExperimentConfig object or
            path to the yaml file.

        common_model_args (Optional[dict], optional): The model arguments which are common to all models.
            The list of params can be found in ``ModelConfig``. The dict is copied and never modified.
            Defaults to None, which is treated as an empty dict.

        rank_metric (Optional[Tuple[str, str]], optional): The metric to use for ranking the models. The
            first element of the tuple is the metric name and the second element is the direction.
            Defaults to ('loss', "lower_is_better").

        return_best_model (bool, optional): If True, will return the best model. Defaults to True.

        seed (int, optional): The seed for reproducibility. Defaults to 42.

        ignore_oom (bool, optional): If True, will ignore the Out of Memory error and continue with the
            next model. Defaults to True.

        progress_bar (bool, optional): If True, will show a progress bar. Defaults to True.

        verbose (bool, optional): If True, will print the progress. Defaults to True.

        suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger.
            Defaults to True.

        min_lr (Optional[float], optional): minimum learning rate to investigate

        max_lr (Optional[float], optional): maximum learning rate to investigate

        num_training (Optional[int], optional): number of learning rates to test

        mode (Optional[str], optional): search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing after each batch. If set
            to 'exponential', will increase learning rate exponentially.

        early_stop_threshold (Optional[float], optional): threshold for stopping the search. If the loss
            at any point is larger than early_stop_threshold*best_loss then the search is stopped. To
            disable, set to None.

        **kwargs: Additional keyword arguments to be passed to the TabularModel fit.

    Returns:
        results: Training results as a DataFrame sorted by the rank metric. The ``learning_rate``
            column holds the learning-rate-finder suggestion, or None when the finder did not run.

        best_model: If return_best_model is True, return best_model otherwise return None.

    Raises:
        AssertionError: If the arguments fail validation.
        OOMException: If a model runs out of memory and ``ignore_oom`` is False.
        ValueError: If a key of ``common_model_args`` is not an attribute of a model config.
    """
    _validate_args(
        task=task,
        train=train,
        test=test,
        metrics=metrics,
        metrics_params=metrics_params,
        metrics_prob_input=metrics_prob_input,
        common_model_args=common_model_args,
        rank_metric=rank_metric,
    )
    _validate_arg_model_list(model_list, task)
    if suppress_lightning_logger:
        suppress_lightning_logs()
    # `trainer_config` may also be a path to a yaml file, which has no attribute to adjust.
    if progress_bar and not isinstance(trainer_config, str) and trainer_config.progress_bar != "none":
        # Turning off the internal progress bar to avoid conflict with sweep progress bar
        warnings.warn(
            "Training Progress bar is not `none`. Set `progress_bar=none` in"
            " `trainer_config` to remove this warning"
        )
        trainer_config.progress_bar = "none"
    if isinstance(model_list, str) and model_list in ("full", "high_memory"):
        warnings.warn(
            "The full model list is quite large and uses a lot of memory. "
            "Consider using `lite` or define configs yourselves for a faster run"
        )
    # Work on a private copy so that neither the caller's dict nor a shared default
    # is modified. The explicitly passed metric arguments override the common ones.
    common_model_args = dict(common_model_args or {})
    for arg_name, arg_value in (
        ("metrics", metrics),
        ("metrics_params", metrics_params),
        ("metrics_prob_input", metrics_prob_input),
    ):
        if arg_value is not None:
            common_model_args[arg_name] = arg_value
    model_list = _resolve_model_configs(model_list, task, common_model_args)

    def _init_tabular_model(m):
        return TabularModel(
            data_config=data_config,
            model_config=m,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            experiment_config=experiment_config,
            verbose=False,
        )

    # The datamodule is prepared once and shared by every model of the sweep.
    init_tabular_model = _init_tabular_model(model_list[0])
    prep_dl_kwargs, prep_model_kwargs, train_kwargs = init_tabular_model._split_kwargs(kwargs)
    datamodule = init_tabular_model.prepare_dataloader(train=train, validation=validation, seed=seed, **prep_dl_kwargs)
    results = []
    best_model = None
    is_lower_better = rank_metric[1] == "lower_is_better"
    best_score = 1e9 if is_lower_better else -1e9
    rank_key = f"test_{rank_metric[0]}"
    ctx = Progress() if progress_bar else nullcontext()
    with ctx as progress:
        if progress_bar:
            task_p = progress.add_task("Sweeping Models", total=len(model_list))
        for model_config in model_list:
            for key, val in common_model_args.items():
                if not hasattr(model_config, key):
                    config_name = getattr(model_config, "name", type(model_config).__name__)
                    raise ValueError(
                        f"ModelConfig {config_name} does not have an" f" attribute {key} in common_model_args"
                    )
                setattr(model_config, key, val)
            params = model_config.__dict__
            start_time = time.time()
            tabular_model = _init_tabular_model(model_config)
            name = tabular_model.name
            if verbose:
                logger.info(f"Training {name}")
            model = tabular_model.prepare_model(datamodule, **prep_model_kwargs)
            if progress_bar:
                progress.update(task_p, description=f"Training {name}", advance=1)
            # Reset per model: the finder may be disabled, or fail before it yields a suggestion.
            suggested_lr = None
            with OutOfMemoryHandler(handle_oom=True) as handler:
                # Copy from train() method with additional lr_find parameters
                tabular_model._prepare_for_training(model, datamodule, **train_kwargs)
                train_loader = tabular_model.datamodule.train_dataloader()
                val_loader = tabular_model.datamodule.val_dataloader()
                tabular_model.model.train()
                if tabular_model.config.auto_lr_find and (not tabular_model.config.fast_dev_run):
                    if tabular_model.verbose:
                        logger.info("Auto LR Find Started")
                    with OutOfMemoryHandler(handle_oom=False) as oom_handler:
                        lr_finder = Tuner(tabular_model.trainer).lr_find(
                            tabular_model.model,
                            train_dataloaders=train_loader,
                            val_dataloaders=val_loader,
                            min_lr=min_lr,
                            max_lr=max_lr,
                            num_training=num_training,
                            mode=mode,
                            early_stop_threshold=early_stop_threshold,
                        )
                    if oom_handler.oom_triggered:
                        raise OOMException(
                            "OOM detected during LR Find. Try reducing your batch_size or the"
                            " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
                        )
                    if lr_finder is not None:
                        suggested_lr = lr_finder.suggestion()
                    if tabular_model.verbose:
                        logger.info(
                            f"Suggested LR: {suggested_lr}. For plot and detailed"
                            " analysis, use `find_learning_rate` method."
                        )
                    tabular_model.model.reset_weights()
                    # Parameters in models needs to be initialized again after LR find
                    tabular_model.model.data_aware_initialization(tabular_model.datamodule)
                tabular_model.model.train()
                if tabular_model.verbose:
                    logger.info("Training Started")
                with OutOfMemoryHandler(handle_oom=False) as oom_handler:
                    tabular_model.trainer.fit(tabular_model.model, train_loader, val_loader)
                if oom_handler.oom_triggered:
                    raise OOMException(
                        "OOM detected during Training. Try reducing your batch_size or the"
                        " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
                    )
                tabular_model._is_fitted = True
                if tabular_model.verbose:
                    logger.info("Training the model completed")
                if tabular_model.config.load_best:
                    tabular_model.load_best_model()
            res_dict = {
                "model": name,
                "learning_rate": suggested_lr,
                "# Params": int_to_human_readable(tabular_model.num_params),
            }
            if handler.oom_triggered:
                if not ignore_oom:
                    raise OOMException(
                        "Out of memory error occurred during the model sweep. "
                        "Set ignore_oom=True to ignore this error."
                    )
                # Give the failed model the worst possible score so it is ranked last.
                res_dict.update(
                    {
                        rank_key: (np.inf if is_lower_better else -np.inf),
                        "epochs": "OOM",
                        "time_taken": "OOM",
                        "time_taken_per_epoch": "OOM",
                    }
                )
                res_dict["model"] = name + " (OOM)"
            else:
                early_stopping = tabular_model.trainer.early_stopping_callback
                if early_stopping is not None and early_stopping.stopped_epoch != 0:
                    res_dict["epochs"] = early_stopping.stopped_epoch
                else:
                    res_dict["epochs"] = tabular_model.trainer.max_epochs
                # Update results with train metrics
                train_metrics = tabular_model.evaluate(test=train, verbose=False)[0]
                res_dict.update(_rename_metric_keys(train_metrics, "train"))
                # Update results with validation metrics; `validation` is optional,
                # so there is nothing to evaluate when it was not provided.
                if validation is not None:
                    validation_metrics = tabular_model.evaluate(test=validation, verbose=False)[0]
                    res_dict.update(_rename_metric_keys(validation_metrics, "validation"))
                # Update results with test metrics
                res_dict.update(tabular_model.evaluate(test=test, verbose=False)[0])
                res_dict["time_taken"] = time.time() - start_time
                res_dict["time_taken_per_epoch"] = res_dict["time_taken"] / res_dict["epochs"]

            if verbose:
                logger.info(f"Finished Training {name}")
                logger.info("Results:" f" {', '.join(f'{k}: {v}' for k, v in res_dict.items())}")
            res_dict["params"] = params
            checkpoint_callback = tabular_model.trainer.checkpoint_callback
            if checkpoint_callback:
                checkpoint_path = checkpoint_callback.best_model_path
                res_dict["checkpoint"] = checkpoint_path
                # Replace the raw checkpoint file by a saved model in a sibling directory
                # named after the checkpoint. Skip when no checkpoint file was written
                # (the path is then empty or stale), e.g. after an OOM.
                if checkpoint_path and os.path.isfile(checkpoint_path):
                    checkpoint_file = Path(checkpoint_path)
                    save_dir = (checkpoint_file.parent / checkpoint_file.stem).as_posix()
                    tabular_model.save_model(save_dir)
                    os.remove(checkpoint_path)
            results.append(res_dict)
            if tabular_model.config["checkpoints_path"]:
                _write_progress_report(results, tabular_model.config["checkpoints_path"])
            if return_best_model:
                # The shared datamodule is detached before the deep copy and
                # re-attached to the winner once the sweep is over.
                tabular_model.datamodule = None
                score = res_dict[rank_key]
                if best_model is None:
                    is_better = True
                elif is_lower_better:
                    is_better = score < best_score
                else:
                    is_better = score > best_score
                if is_better:
                    best_model = copy.deepcopy(tabular_model)
                    best_score = score
    if verbose:
        logger.info("Model Sweep Finished")
        # `best_model` stays None when `return_best_model` is False.
        if best_model is not None:
            logger.info(f"Best Model: {best_model.name}")
    results = pd.DataFrame(results).sort_values(by=rank_key, ascending=is_lower_better)
    if return_best_model and best_model is not None:
        best_model.datamodule = datamodule
        return results, best_model
    return results, None