| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import os |
| | from argparse import Namespace |
| | from dataclasses import dataclass |
| | from pathlib import Path |
| | from typing import Any, List, Literal, Mapping, Optional, Union |
| |
|
| | import pandas as pd |
| | from lightning_utilities.core.apply_func import apply_to_collection |
| | from omegaconf import DictConfig, ListConfig, OmegaConf |
| | from pytorch_lightning.callbacks import Checkpoint |
| | from pytorch_lightning.loggers import Logger |
| | from pytorch_lightning.utilities.parsing import AttributeDict |
| | from torch import Tensor |
| |
|
| | from nemo.utils import logging |
| |
|
| | try: |
| | from clearml import OutputModel, Task |
| |
|
| | HAVE_CLEARML_LOGGER = True |
| | except (ImportError, ModuleNotFoundError): |
| | HAVE_CLEARML_LOGGER = False |
| |
|
| |
|
@dataclass
class ClearMLParams:
    """Configuration options consumed by ``ClearMLLogger``.

    Every field is optional; the logger substitutes its own fallbacks for
    the project, task and model names when they are left unset.
    """

    # ClearML project name.
    project: Optional[str] = None
    # ClearML task name.
    task: Optional[str] = None
    # Passed to ClearML as the "pytorch" entry of ``auto_connect_frameworks``.
    connect_pytorch: Optional[bool] = False
    # Name given to the ClearML output model.
    model_name: Optional[str] = None
    # Extra tags appended after the logger's built-in "NeMo" tag.
    tags: Optional[List[str]] = None
    # Create a ClearML output model and upload the .nemo file to it.
    log_model: Optional[bool] = False
    # Store the hyper-parameter configuration on the output model.
    log_cfg: Optional[bool] = False
    # Attach the most recent metrics to the output model as metadata.
    log_metrics: Optional[bool] = False
| |
|
| |
|
class ClearMLLogger(Logger):
    """Lightning logger that reports a training run to a ClearML task.

    The logger opens a ClearML task on construction and, when
    ``clearml_cfg.log_model`` is set, a ClearML output model to which the
    ``.nemo`` file, the hyper-parameter configuration and the latest metrics
    are attached.
    """

    @property
    def name(self) -> str:
        """Name of the underlying ClearML task."""
        return self.clearml_task.name

    @property
    def version(self) -> str:
        """Id of the underlying ClearML task."""
        return self.clearml_task.id

    def __init__(
        self, clearml_cfg: DictConfig, log_dir: str, prefix: str, save_best_model: bool, postfix: str = ".nemo"
    ) -> None:
        """Create the ClearML task and, optionally, the output model.

        Args:
            clearml_cfg: Config providing ``project``, ``task``,
                ``connect_pytorch``, ``model_name``, ``tags``, ``log_model``,
                ``log_cfg`` and ``log_metrics``.
            log_dir: Experiment directory; the model file is expected at
                ``<log_dir>/checkpoints/<prefix><postfix>``.
            prefix: Base name of the model file; also used as a fallback for
                the task and model names.
            save_best_model: When true, the model is uploaded only after the
                checkpoint callback reports a new best model path.
            postfix: Extension of the model file.

        Raises:
            ImportError: If the ``clearml`` package could not be imported.
        """
        if not HAVE_CLEARML_LOGGER:
            # Adjacent literals are concatenated verbatim, so each fragment
            # carries its own trailing space.
            raise ImportError(
                "Found create_clearml_logger is True. "
                "But ClearML not found. Please see the README for installation instructions: "
                "https://github.com/allegroai/clearml"
            )

        super().__init__()

        self.clearml_task = None
        self.clearml_model = None
        self.clearml_cfg = clearml_cfg
        self.path_nemo_model = os.path.abspath(
            os.path.expanduser(os.path.join(log_dir, "checkpoints", prefix + postfix))
        )
        self.save_best_model = save_best_model
        self.prefix = prefix
        self.previos_best_model_path = None
        self.last_metrics = None
        # While set, the next ``after_save_checkpoint`` call in best-model mode
        # is skipped; it is set again after every successful upload.
        self.save_blocked = True

        # Environment variables take precedence over the config values.
        self.project_name = os.getenv("CLEARML_PROJECT", clearml_cfg.project if clearml_cfg.project else "NeMo")
        self.task_name = os.getenv("CLEARML_TASK", clearml_cfg.task if clearml_cfg.task else f"Trainer {self.prefix}")

        tags = ["NeMo"]
        if clearml_cfg.tags:
            tags.extend(clearml_cfg.tags)

        self.clearml_task: Task = Task.init(
            project_name=self.project_name,
            task_name=self.task_name,
            auto_connect_frameworks={"pytorch": clearml_cfg.connect_pytorch},
            output_uri=True,
            tags=tags,
        )

        # Model name fallback chain: config value, then prefix, then task name.
        if clearml_cfg.model_name:
            model_name = clearml_cfg.model_name
        elif self.prefix:
            model_name = self.prefix
        else:
            model_name = self.task_name

        if clearml_cfg.log_model:
            self.clearml_model: OutputModel = OutputModel(
                name=model_name, task=self.clearml_task, tags=tags, framework="NeMo"
            )

    def log_hyperparams(self, params, *args, **kwargs) -> None:
        """Store the hyper-parameters as YAML on the output model.

        Does nothing unless an output model exists and ``log_cfg`` is enabled.
        """
        if not (self.clearml_model and self.clearml_cfg.log_cfg):
            return

        if isinstance(params, Namespace):
            params = vars(params)
        elif isinstance(params, AttributeDict):
            params = dict(params)
        # Resolve OmegaConf nodes to plain containers and stringify paths so
        # the result can be dumped as YAML.
        params = apply_to_collection(params, (DictConfig, ListConfig), OmegaConf.to_container, resolve=True)
        params = apply_to_collection(params, Path, str)
        self.clearml_model.update_design(config_text=OmegaConf.to_yaml(params))

    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Remember the latest metrics so they can be attached to the model.

        Each metric is stored as ``{"value": str, "type": str}``. Tensors
        (including subclasses) are converted with ``.item()`` first. Nothing
        is recorded unless an output model exists and ``log_metrics`` is
        enabled. ``step`` is accepted for interface compatibility and unused.
        """
        if not (self.clearml_model and self.clearml_cfg.log_metrics):
            return

        formatted = {}
        for metric_name, raw_value in metrics.items():
            scalar = raw_value.item() if isinstance(raw_value, Tensor) else raw_value
            formatted[metric_name] = {"value": str(scalar), "type": str(type(scalar))}
        self.last_metrics = formatted

    def log_table(
        self,
        key: str,
        columns: Optional[List[str]] = None,
        data: Optional[List[List[Any]]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
    ) -> None:
        """Report a table to ClearML under the title and series ``key``.

        Args:
            key: Used as both title and series of the reported table.
            columns: Column names. Required with ``data``; optional with
                ``dataframe``, where they replace the existing column labels.
            data: Rows of the table. Takes precedence over ``dataframe`` when
                both are given.
            dataframe: Table object whose columns can be relabelled.
            step: Iteration the table is reported at.

        Raises:
            ValueError: If ``data`` is given without ``columns`` or the number
                of column names differs from the width of the first row.
        """
        table: Optional[Union[pd.DataFrame, List[List[Any]]]] = None

        if dataframe is not None:
            table = dataframe
            if columns is not None:
                # Relabel a copy so the caller's dataframe is left untouched.
                table = dataframe.copy()
                table.columns = columns

        if data is not None:
            if columns is None:
                raise ValueError("column names are required when logging a table from `data`")
            if data and len(columns) != len(data[0]):
                raise ValueError("number of column names should match the total number of columns")
            # Build a new list with the header row instead of inserting into
            # the caller's list.
            table = [list(columns), *data]

        if table is not None:
            self.clearml_task.logger.report_table(title=key, series=key, iteration=step, table_plot=table)

    def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
        """Upload the .nemo file after a checkpoint has been saved.

        In best-model mode the upload is skipped while ``save_blocked`` is
        set, while the best checkpoint file does not exist, or while the best
        model path has not changed since the previous upload attempt.
        """
        if not self.clearml_model:
            return

        if self.save_best_model:
            if self.save_blocked:
                self.save_blocked = False
                return
            best_path = checkpoint_callback.best_model_path
            if not os.path.exists(best_path):
                return
            if self.previos_best_model_path == best_path:
                return
            self.previos_best_model_path = best_path

        self._log_model(self.path_nemo_model)

    def finalize(self, status: Literal["success", "failed", "aborted"] = "success") -> None:
        """Mark the ClearML task completed, failed or stopped.

        Any other ``status`` value leaves the task untouched.
        """
        if status == "success":
            self.clearml_task.mark_completed()
        elif status == "failed":
            self.clearml_task.mark_failed()
        elif status == "aborted":
            self.clearml_task.mark_stopped()

    def _log_model(self, save_path: str) -> None:
        """Upload ``save_path`` as the output model's weights.

        Also attaches the last recorded metrics as model metadata when metric
        logging is enabled, and re-arms ``save_blocked``. Logs a warning when
        the file does not exist.
        """
        if not self.clearml_model:
            return

        if not os.path.exists(save_path):
            logging.warning(f"Logging model enabled, but cant find .nemo file! Path: {save_path}")
            return

        self.clearml_model.update_weights(
            weights_filename=save_path,
            upload_uri=self.clearml_task.storage_uri or self.clearml_task._get_default_report_storage_uri(),
            auto_delete_file=False,
            is_package=True,
        )

        if self.clearml_cfg.log_metrics and self.last_metrics:
            self.clearml_model.set_all_metadata(self.last_metrics)

        self.save_blocked = True
| |
|