| | |
| | from ...dist_utils import master_only |
| | from ..hook import HOOKS |
| | from .base import LoggerHook |
| |
|
| |
|
@HOOKS.register_module()
class MlflowLoggerHook(LoggerHook):
    """Class to log metrics and (optionally) a trained model to MLflow.

    It requires `MLflow`_ to be installed.

    Args:
        exp_name (str, optional): Name of the experiment to be used.
            Default None. If not None, set the active experiment.
            If experiment does not exist, an experiment with provided name
            will be created.
        tags (dict of str: str, optional): Tags for the current run.
            Default None. If not None, set tags for the current run.
        log_model (bool, optional): Whether to log an MLflow artifact.
            Default True. If True, log runner.model as an MLflow artifact
            for the current run.
        interval (int): Logging interval (every k iterations). Default 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`. Default True.
        reset_flag (bool): Whether to clear the output buffer after
            logging. Default False.
        by_epoch (bool): Whether EpochBasedRunner is used. Default True.
        params (dict, optional): Parameters for the current run.
            Default None. If not None, they are logged with
            ``mlflow.log_params`` before the run starts. Placed last so
            existing positional callers are unaffected.

    .. _MLflow:
        https://www.mlflow.org/docs/latest/index.html
    """

    def __init__(self,
                 exp_name=None,
                 tags=None,
                 log_model=True,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True,
                 params=None):
        super().__init__(interval, ignore_last, reset_flag, by_epoch)
        # Import here so a missing mlflow install fails when the hook is
        # constructed, not at the first hook call during training.
        self.import_mlflow()
        self.exp_name = exp_name
        self.tags = tags
        self.params = params
        self.log_model = log_model

    def import_mlflow(self):
        """Import mlflow and its pytorch flavor and store them on the hook.

        Sets ``self.mlflow`` and ``self.mlflow_pytorch``.

        Raises:
            ImportError: If ``mlflow`` or ``mlflow.pytorch`` cannot be
                imported.
        """
        try:
            import mlflow
            import mlflow.pytorch as mlflow_pytorch
        except ImportError as err:
            # Chain the original error so its cause (which of the two
            # imports failed, and why) stays visible in the traceback.
            raise ImportError(
                'Please run "pip install mlflow" to install mlflow') from err
        self.mlflow = mlflow
        self.mlflow_pytorch = mlflow_pytorch

    @master_only
    def before_run(self, runner):
        """Select the experiment and record run tags/params, if configured.

        Args:
            runner: The runner driving training; forwarded to
                ``LoggerHook.before_run``.
        """
        super().before_run(runner)
        if self.exp_name is not None:
            self.mlflow.set_experiment(self.exp_name)
        if self.tags is not None:
            self.mlflow.set_tags(self.tags)
        if self.params is not None:
            self.mlflow.log_params(self.params)

    @master_only
    def log(self, runner):
        """Log the runner's loggable tags to MLflow as metrics.

        Metrics are recorded at the step returned by ``self.get_iter``.
        Nothing is logged when ``get_loggable_tags`` returns an empty
        (falsy) result.

        Args:
            runner: The runner whose log buffer is being reported.
        """
        tags = self.get_loggable_tags(runner)
        if tags:
            self.mlflow.log_metrics(tags, step=self.get_iter(runner))

    @master_only
    def after_run(self, runner):
        """Log ``runner.model`` as an MLflow artifact if ``log_model`` is set.

        The model is stored under the artifact path ``'models'``.

        Args:
            runner: The runner holding the trained ``model``.
        """
        if self.log_model:
            self.mlflow_pytorch.log_model(runner.model, 'models')
| |
|