# EpInflammAge — src/pt/model_sweep.py
# Custom model sweep adapted from pytorch_tabular's tabular_model_sweep.
import copy
import time
import warnings
from contextlib import nullcontext
from typing import Callable, List, Optional, Tuple, Union
import numpy as np
import pandas as pd
from rich.progress import Progress, track
from pytorch_tabular import TabularModel, models
from pytorch_tabular.config import (
DataConfig,
ExperimentConfig,
ModelConfig,
OptimizerConfig,
TrainerConfig,
)
from pytorch_tabular.utils import (
OOMException,
OutOfMemoryHandler,
available_models,
get_logger,
int_to_human_readable,
suppress_lightning_logs,
)
from pytorch_lightning.tuner.tuning import Tuner
from pathlib import Path
import os
logger = get_logger("pytorch_tabular")

# Named model-sweep presets.
# "lite"/"standard" hold (model-config class name, constructor kwargs) pairs;
# "full"/"high_memory" hold bare model-config class names.
# NOTE: the large presets are materialized as tuples. The previous generator
# expressions were single-use and `copy.deepcopy` (used by model_sweep_custom)
# cannot copy a generator, so those presets could never be expanded.
MODEL_SWEEP_PRESETS = {
    "lite": (
        ("CategoryEmbeddingModelConfig", {"layers": "256-128-64"}),
        ("GANDALFConfig", {"gflu_stages": 6}),
        (
            "TabNetModelConfig",
            {
                "n_d": 32,
                "n_a": 32,
                "n_steps": 3,
                "gamma": 1.5,
                "n_independent": 1,
                "n_shared": 2,
            },
        ),
    ),
    "standard": (
        ("CategoryEmbeddingModelConfig", {"layers": "256-128-64"}),
        ("CategoryEmbeddingModelConfig", {"layers": "512-128-64"}),
        ("GANDALFConfig", {"gflu_stages": 6}),
        ("GANDALFConfig", {"gflu_stages": 15}),
        (
            "TabNetModelConfig",
            {
                "n_d": 32,
                "n_a": 32,
                "n_steps": 3,
                "gamma": 1.5,
                "n_independent": 1,
                "n_shared": 2,
            },
        ),
        (
            "TabNetModelConfig",
            {
                "n_d": 32,
                "n_a": 32,
                "n_steps": 5,
                "gamma": 1.5,
                "n_independent": 2,
                "n_shared": 3,
            },
        ),
        ("FTTransformerConfig", {"num_heads": 4, "num_attn_blocks": 4}),
    ),
    # MDN needs special handling and Node is very memory hungry, hence excluded.
    "full": tuple(m for m in available_models() if m not in ["MDNConfig", "NodeConfig"]),
    "high_memory": tuple(m for m in available_models() if m not in ["MDNConfig"]),
}
def _validate_args(
task: str,
train: pd.DataFrame,
test: pd.DataFrame,
metrics: Optional[List[Union[str, Callable]]] = None,
metrics_params: Optional[List[dict]] = None,
metrics_prob_input: Optional[List[bool]] = None,
common_model_args: Optional[dict] = {},
rank_metric: Optional[str] = "loss",
):
assert task in [
"classification",
"regression",
], f"task must be one of ['classification', 'regression'], but got {task}"
assert isinstance(train, pd.DataFrame), f"train must be a pandas DataFrame, but got {type(train)}"
assert isinstance(test, pd.DataFrame), f"test must be a pandas DataFrame, but got {type(test)}"
if metrics is not None:
assert isinstance(metrics, list), f"metrics must be a list of strings or callables, but got {type(metrics)}"
assert all(
isinstance(m, (str, Callable)) for m in metrics
), f"metrics must be a list of strings or callables, but got {metrics}"
assert metrics_params is not None, "metric_params cannot be None when metrics is not None"
assert metrics_prob_input is not None, "metrics_prob_inputs cannot be None when metrics is not None"
assert isinstance(
metrics_params, list
), f"metric_params must be a list of dicts, but got {type(metrics_params)}"
assert isinstance(metrics_prob_input, list), (
"metrics_prob_inputs must be a list of bools, but got" f" {type(metrics_prob_input)}"
)
assert len(metrics) == len(metrics_params), (
"metrics and metric_params must be of the same length, but got" f" {len(metrics)} and {len(metrics_params)}"
)
assert len(metrics) == len(metrics_prob_input), (
"metrics and metrics_prob_inputs must be of the same length, but got"
f" {len(metrics)} and {len(metrics_prob_input)}"
)
assert all(
isinstance(m, dict) for m in metrics_params
), f"metric_params must be a list of dicts, but got {metrics_params}"
if common_model_args is not None:
# all args should be members of ModelConfig
assert all(k in ModelConfig.__dataclass_fields__.keys() for k in common_model_args.keys()), (
"common_model_args must be a subset of ModelConfig, but got" f" {common_model_args.keys()}"
)
if rank_metric[0] not in ["loss", "accuracy", "mean_squared_error"]:
assert rank_metric[0] in metrics, f"rank_metric must be one of {metrics}, but got {rank_metric}"
assert rank_metric[1] in [
"lower_is_better",
"higher_is_better",
], (
"rank_metric[1] must be one of ['lower_is_better', 'higher_is_better'], but" f" got {rank_metric[1]}"
)
def _validate_arg_model_list(model_list, task):
    """Check that ``model_list`` is a known preset name or a list of model specs for ``task``."""
    assert model_list is not None, "models cannot be None"
    assert isinstance(
        model_list, (str, list)
    ), f"models must be a string or list of strings, but got {type(model_list)}"
    if isinstance(model_list, str):
        # A string must name one of the predefined sweep presets.
        assert (
            model_list in MODEL_SWEEP_PRESETS.keys()
        ), f"models must be one of {MODEL_SWEEP_PRESETS.keys()}, but got {model_list}"
        return
    # Otherwise every entry must be a config class name or a ModelConfig instance.
    entries_ok = all(isinstance(entry, (str, ModelConfig)) for entry in model_list)
    assert entries_ok, f"models must be a list of strings or ModelConfigs, but got {model_list}"
    # Any ModelConfig entries must target the same task as the sweep.
    config_tasks = [entry.task for entry in model_list if isinstance(entry, ModelConfig)]
    assert all(config_task == task for config_task in config_tasks), (
        f"task must be the same as the task in ModelConfig, but got {task} and"
        f" {config_tasks}"
    )
def model_sweep_custom(
    task: str,
    train: pd.DataFrame,
    test: pd.DataFrame,
    data_config: Union[DataConfig, str],
    optimizer_config: Union[OptimizerConfig, str],
    trainer_config: Union[TrainerConfig, str],
    model_list: Union[str, List[Union[ModelConfig, str]]] = "lite",
    metrics: Optional[List[Union[str, Callable]]] = None,
    metrics_params: Optional[List[dict]] = None,
    metrics_prob_input: Optional[List[bool]] = None,
    validation: Optional[pd.DataFrame] = None,
    experiment_config: Optional[Union[ExperimentConfig, str]] = None,
    common_model_args: Optional[dict] = None,
    rank_metric: Optional[Tuple[str, str]] = ("loss", "lower_is_better"),
    return_best_model: bool = True,
    seed: int = 42,
    ignore_oom: bool = True,
    progress_bar: bool = True,
    verbose: bool = True,
    suppress_lightning_logger: bool = True,
    min_lr: float = 1e-8,
    max_lr: float = 1,
    num_training: int = 100,
    mode: str = "exponential",
    early_stop_threshold: Optional[float] = 4.0,
    **kwargs,
):
    """Compare multiple models on the same dataset.

    Args:
        task (str): The type of prediction task. Either 'classification' or 'regression'
        train (pd.DataFrame): The training data
        test (pd.DataFrame): The test data on which performance is evaluated
        data_config (Union[DataConfig, str]): DataConfig object or path to the yaml file.
        optimizer_config (Union[OptimizerConfig, str]): OptimizerConfig object or path to the yaml file.
        trainer_config (Union[TrainerConfig, str]): TrainerConfig object or path to the yaml file.
        model_list (Union[str, List[Union[ModelConfig, str]]], optional): The list of models to compare.
            This can be one of the presets defined in ``MODEL_SWEEP_PRESETS``
            or a list of ``ModelConfig`` objects. Defaults to "lite".
        metrics (Optional[List[str]]): the list of metrics you need to track during training. The metrics
            should be one of the functional metrics implemented in ``torchmetrics``. By default, it is
            accuracy if classification and mean_squared_error for regression
        metrics_params (Optional[List]): The parameters to be passed to the metrics function. `task` is forced to
            be `multiclass` because the multiclass version can handle binary as well and for simplicity we are
            only using `multiclass`.
        metrics_prob_input (Optional[bool]): Is a mandatory parameter for classification metrics defined in
            the config. This defines whether the input to the metric function is the probability or the class.
            Length should be same as the number of metrics. Defaults to None.
        validation (Optional[DataFrame], optional):
            If provided, will use this dataframe as the validation while training.
            Used in Early Stopping and Logging. If left empty, will use 20% of Train data as validation.
            Defaults to None.
        experiment_config (Optional[Union[ExperimentConfig, str]], optional): ExperimentConfig object or path to
            the yaml file.
        common_model_args (Optional[dict], optional): The model arguments which are common to all models. The list
            of params can be found in ``ModelConfig``. If not provided, will use defaults. Defaults to None.
        rank_metric (Optional[Tuple[str, str]], optional): The metric to use for ranking the models. The first
            element of the tuple is the metric name and the second element is the direction.
            Defaults to ('loss', "lower_is_better").
        return_best_model (bool, optional): If True, will return the best model. Defaults to True.
        seed (int, optional): The seed for reproducibility. Defaults to 42.
        ignore_oom (bool, optional): If True, will ignore the Out of Memory error and continue with the next model.
        progress_bar (bool, optional): If True, will show a progress bar. Defaults to True.
        verbose (bool, optional): If True, will print the progress. Defaults to True.
        suppress_lightning_logger (bool, optional): If True, will suppress the lightning logger. Defaults to True.
        min_lr (Optional[float], optional): minimum learning rate to investigate in the LR finder
        max_lr (Optional[float], optional): maximum learning rate to investigate in the LR finder
        num_training (Optional[int], optional): number of learning rates to test
        mode (Optional[str], optional): search strategy, either 'linear' or 'exponential'. If set to
            'linear' the learning rate will be searched by linearly increasing after each batch. If set
            to 'exponential', will increase learning rate exponentially.
        early_stop_threshold (Optional[float], optional): threshold for stopping the search. If the
            loss at any point is larger than early_stop_threshold*best_loss then the search is stopped.
            To disable, set to None.
        **kwargs: Additional keyword arguments to be passed to the TabularModel fit.

    Returns:
        results: DataFrame of per-model training results, sorted by the rank metric.
        best_model: If return_best_model is True, the best model; otherwise None.
    """
    # Work on a private copy: the function mutates this dict below, and mutating
    # the caller's dict (or a shared `{}` default) leaks state across calls.
    common_model_args = dict(common_model_args) if common_model_args else {}
    _validate_args(
        task=task,
        train=train,
        test=test,
        metrics=metrics,
        metrics_params=metrics_params,
        metrics_prob_input=metrics_prob_input,
        common_model_args=common_model_args,
        rank_metric=rank_metric,
    )
    _validate_arg_model_list(model_list, task)
    if suppress_lightning_logger:
        suppress_lightning_logs()
    if progress_bar:
        if trainer_config.progress_bar != "none":
            # Turning off the internal progress bar to avoid conflict with sweep progress bar
            warnings.warn(
                "Training Progress bar is not `none`. Set `progress_bar=none` in"
                " `trainer_config` to remove this warning"
            )
        trainer_config.progress_bar = "none"
    if isinstance(model_list, str) and model_list in ("full", "high_memory"):
        warnings.warn(
            "The full model list is quite large and uses a lot of memory. "
            "Consider using `lite` or define configs yourselves for a faster run"
        )
    # Explicit metric arguments override any matching keys in common_model_args.
    for arg_name, arg_value in (
        ("metrics", metrics),
        ("metrics_params", metrics_params),
        ("metrics_prob_input", metrics_prob_input),
    ):
        if arg_value is not None:
            common_model_args[arg_name] = arg_value
    if isinstance(model_list, str):
        # list(...) materializes the preset before deepcopy so lazy presets
        # (iterables/generators) are supported as well as tuples.
        model_list = copy.deepcopy(list(MODEL_SWEEP_PRESETS[model_list]))
    # Normalize every entry to a ModelConfig instance:
    # (name, kwargs) tuples and bare class names are instantiated here.
    model_list = [
        (
            getattr(models, model_config[0])(task=task, **model_config[1], **common_model_args)
            if isinstance(model_config, tuple)
            else (
                getattr(models, model_config)(task=task, **common_model_args)
                if isinstance(model_config, str)
                else model_config
            )
        )
        for model_config in model_list
    ]

    def _init_tabular_model(m):
        # One TabularModel per model config; data/optimizer/trainer configs are shared.
        return TabularModel(
            data_config=data_config,
            model_config=m,
            optimizer_config=optimizer_config,
            trainer_config=trainer_config,
            experiment_config=experiment_config,
            verbose=False,
        )

    init_tabular_model = _init_tabular_model(model_list[0])
    prep_dl_kwargs, prep_model_kwargs, train_kwargs = init_tabular_model._split_kwargs(kwargs)
    # The datamodule is prepared once and reused by every model in the sweep.
    datamodule = init_tabular_model.prepare_dataloader(train=train, validation=validation, seed=seed, **prep_dl_kwargs)
    results = []
    best_model = None
    is_lower_better = rank_metric[1] == "lower_is_better"
    best_score = 1e9 if is_lower_better else -1e9
    ctx = Progress() if progress_bar else nullcontext()
    with ctx as progress:
        if progress_bar:
            task_p = progress.add_task("Sweeping Models", total=len(model_list))
        for model_config in model_list:
            if isinstance(model_config, str):
                model_config = getattr(models, model_config)(task=task, **common_model_args)
            else:
                for key, val in common_model_args.items():
                    if hasattr(model_config, key):
                        setattr(model_config, key, val)
                    else:
                        raise ValueError(
                            f"ModelConfig {model_config.name} does not have an" f" attribute {key} in common_model_args"
                        )
            params = model_config.__dict__
            start_time = time.time()
            tabular_model = _init_tabular_model(model_config)
            name = tabular_model.name
            if verbose:
                logger.info(f"Training {name}")
            model = tabular_model.prepare_model(datamodule, **prep_model_kwargs)
            if progress_bar:
                progress.update(task_p, description=f"Training {name}", advance=1)
            # Only set when auto_lr_find actually runs for THIS model; guards against a
            # NameError (or a stale suggestion from the previous iteration) below.
            lr_finder = None
            with OutOfMemoryHandler(handle_oom=True) as handler:
                # Copied from TabularModel.train() with additional lr_find parameters.
                handle_oom = False
                tabular_model._prepare_for_training(model, datamodule, **train_kwargs)
                train_loader, val_loader = (
                    tabular_model.datamodule.train_dataloader(),
                    tabular_model.datamodule.val_dataloader(),
                )
                tabular_model.model.train()
                if tabular_model.config.auto_lr_find and (not tabular_model.config.fast_dev_run):
                    if tabular_model.verbose:
                        logger.info("Auto LR Find Started")
                    with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
                        lr_finder = Tuner(tabular_model.trainer).lr_find(
                            tabular_model.model,
                            train_dataloaders=train_loader,
                            val_dataloaders=val_loader,
                            min_lr=min_lr,
                            max_lr=max_lr,
                            num_training=num_training,
                            mode=mode,
                            early_stop_threshold=early_stop_threshold,
                        )
                    if oom_handler.oom_triggered:
                        raise OOMException(
                            "OOM detected during LR Find. Try reducing your batch_size or the"
                            " model parameters." + "\n" + "Original Error: " + oom_handler.oom_msg
                        )
                    if tabular_model.verbose:
                        logger.info(
                            f"Suggested LR: {lr_finder.suggestion()}. For plot and detailed"
                            " analysis, use `find_learning_rate` method."
                        )
                    tabular_model.model.reset_weights()
                    # Parameters in models need to be initialized again after LR find
                    tabular_model.model.data_aware_initialization(tabular_model.datamodule)
                    tabular_model.model.train()
                if tabular_model.verbose:
                    logger.info("Training Started")
                with OutOfMemoryHandler(handle_oom=handle_oom) as oom_handler:
                    tabular_model.trainer.fit(tabular_model.model, train_loader, val_loader)
                if oom_handler.oom_triggered:
                    raise OOMException(
                        "OOM detected during Training. Try reducing your batch_size or the"
                        " model parameters."
                        "\n" + "Original Error: " + oom_handler.oom_msg
                    )
                tabular_model._is_fitted = True
                if tabular_model.verbose:
                    logger.info("Training the model completed")
                if tabular_model.config.load_best:
                    tabular_model.load_best_model()
            res_dict = {
                "model": name,
                # None when auto_lr_find was disabled, so no LR search ran for this model.
                "learning_rate": lr_finder.suggestion() if lr_finder is not None else None,
                "# Params": int_to_human_readable(tabular_model.num_params),
            }
            if handler.oom_triggered:
                if not ignore_oom:
                    raise OOMException(
                        "Out of memory error occurred during cross validation. "
                        "Set ignore_oom=True to ignore this error."
                    )
                else:
                    # Rank OOM models last (worst possible score) and mark them clearly.
                    res_dict.update(
                        {
                            f"test_{rank_metric[0]}": (np.inf if is_lower_better else -np.inf),
                            "epochs": "OOM",
                            "time_taken": "OOM",
                            "time_taken_per_epoch": "OOM",
                        }
                    )
                    res_dict["model"] = name + " (OOM)"
            else:
                if (
                    tabular_model.trainer.early_stopping_callback is not None
                    and tabular_model.trainer.early_stopping_callback.stopped_epoch != 0
                ):
                    res_dict["epochs"] = tabular_model.trainer.early_stopping_callback.stopped_epoch
                else:
                    res_dict["epochs"] = tabular_model.trainer.max_epochs
                # Train metrics, renamed test_* -> train_* to keep result columns distinct.
                train_metrics = tabular_model.evaluate(test=train, verbose=False)[0]
                for m_name in list(train_metrics.keys()):
                    train_metrics[m_name.replace("test", "train")] = train_metrics.pop(m_name)
                res_dict.update(train_metrics)
                # Validation metrics, renamed test_* -> validation_*.
                # NOTE(review): when `validation` is None this evaluates on the model's
                # internal validation split — confirm the labels match intent.
                validation_metrics = tabular_model.evaluate(test=validation, verbose=False)[0]
                for m_name in list(validation_metrics.keys()):
                    validation_metrics[m_name.replace("test", "validation")] = validation_metrics.pop(m_name)
                res_dict.update(validation_metrics)
                # Test metrics keep their original test_* names.
                res_dict.update(tabular_model.evaluate(test=test, verbose=False)[0])
                res_dict["time_taken"] = time.time() - start_time
                res_dict["time_taken_per_epoch"] = res_dict["time_taken"] / res_dict["epochs"]
            if verbose:
                logger.info(f"Finished Training {name}")
                logger.info("Results:" f" {', '.join([f'{k}: {v}' for k, v in res_dict.items()])}")
            res_dict["params"] = params
            if tabular_model.trainer.checkpoint_callback:
                # Convert the raw lightning checkpoint into a saved TabularModel directory
                # next to it, then drop the .ckpt to save disk space.
                res_dict["checkpoint"] = tabular_model.trainer.checkpoint_callback.best_model_path
                ckpt_path = Path(res_dict["checkpoint"])
                save_dir = str(ckpt_path.parent).replace("\\", "/") + "/" + ckpt_path.stem
                tabular_model.save_model(save_dir)
                os.remove(res_dict["checkpoint"])
            results.append(res_dict)
            if tabular_model.config["checkpoints_path"]:
                # Persist a styled progress sheet after every model; best-effort —
                # skipped when the file is locked (e.g. open in Excel).
                try:
                    pd.DataFrame(results).style.background_gradient(
                        subset=[
                            "train_loss",
                            "validation_loss",
                            "test_loss",
                            "time_taken",
                            "time_taken_per_epoch",
                        ],
                        cmap="RdYlGn_r",
                    ).to_excel(f"{tabular_model.config['checkpoints_path']}/progress.xlsx")
                except PermissionError:
                    pass
            if return_best_model:
                # Drop the shared datamodule before deepcopy to keep the copy cheap;
                # it is restored on the winner before returning.
                tabular_model.datamodule = None
                score = res_dict[f"test_{rank_metric[0]}"]
                improved = (score < best_score) if is_lower_better else (score > best_score)
                if best_model is None or improved:
                    best_model = copy.deepcopy(tabular_model)
                    best_score = score
    if verbose:
        logger.info("Model Sweep Finished")
        if best_model is not None:
            logger.info(f"Best Model: {best_model.name}")
    results = pd.DataFrame(results).sort_values(by=f"test_{rank_metric[0]}", ascending=is_lower_better)
    if return_best_model and best_model is not None:
        best_model.datamodule = datamodule
        return results, best_model
    return results, None