| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import os |
| from argparse import Namespace |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Any, List, Literal, Mapping, Optional, Union |
|
|
| import pandas as pd |
| from lightning.pytorch.callbacks import Checkpoint |
| from lightning.pytorch.loggers import Logger |
| from lightning.pytorch.utilities.parsing import AttributeDict |
| from lightning_utilities.core.apply_func import apply_to_collection |
| from omegaconf import DictConfig, ListConfig, OmegaConf |
| from torch import Tensor |
|
|
| from nemo.utils import logging |
|
|
| try: |
| from clearml import OutputModel, Task |
|
|
| HAVE_CLEARML_LOGGER = True |
| except (ImportError, ModuleNotFoundError): |
| HAVE_CLEARML_LOGGER = False |
|
|
|
|
@dataclass
class ClearMLParams:
    """Configuration values consumed by ``ClearMLLogger``.

    Every field is optional; the logger substitutes its own fallbacks for
    unset project, task and model names.
    """

    # Project name; the logger falls back to "NeMo" when this is unset.
    project: Optional[str] = None
    # Task name; the logger falls back to "Trainer <prefix>" when this is unset.
    task: Optional[str] = None
    # Passed as the "pytorch" entry of ``auto_connect_frameworks`` in ``Task.init``.
    connect_pytorch: Optional[bool] = False
    # Name for the output model; the logger falls back to the prefix, then the task name.
    model_name: Optional[str] = None
    # Additional tags appended after the default "NeMo" tag.
    tags: Optional[List[str]] = None
    # When truthy, the logger creates an ``OutputModel`` and uploads the .nemo file.
    log_model: Optional[bool] = False
    # When truthy (and a model is logged), hyperparameters are stored as YAML on the model.
    log_cfg: Optional[bool] = False
    # When truthy (and a model is logged), the latest metrics are attached as model metadata.
    log_metrics: Optional[bool] = False
|
|
|
|
class ClearMLLogger(Logger):
    """Lightning ``Logger`` that reports a training run to ClearML.

    On construction a ClearML ``Task`` is initialised and, when
    ``clearml_cfg.log_model`` is truthy, an ``OutputModel`` is attached to it.
    Hyperparameters, metrics and the exported ``.nemo`` file are only pushed
    to that model, so those hooks are no-ops when model logging is disabled.
    """

    @property
    def name(self) -> str:
        """Name of the underlying ClearML task."""
        return self.clearml_task.name

    @property
    def version(self) -> str:
        """Id of the underlying ClearML task."""
        return self.clearml_task.id

    def __init__(
        self, clearml_cfg: DictConfig, log_dir: str, prefix: str, save_best_model: bool, postfix: str = ".nemo"
    ) -> None:
        """Initialise the ClearML task and, optionally, its output model.

        Args:
            clearml_cfg: Config exposing ``project``, ``task``, ``connect_pytorch``,
                ``model_name``, ``tags``, ``log_model``, ``log_cfg`` and ``log_metrics``.
            log_dir: Experiment directory; the model file is expected at
                ``<log_dir>/checkpoints/<prefix><postfix>``.
            prefix: Model file stem, also used as fallback task/model name.
            save_best_model: When truthy, uploads are gated on the checkpoint
                callback's ``best_model_path`` (see ``after_save_checkpoint``).
            postfix: Extension of the exported model file.

        Raises:
            ImportError: If the ``clearml`` package could not be imported.
        """
        if not HAVE_CLEARML_LOGGER:
            raise ImportError(
                "Found create_clearml_logger is True. "
                "But ClearML not found. Please see the README for installation instructions: "
                "https://github.com/allegroai/clearml"
            )

        super().__init__()

        self.clearml_task = None
        self.clearml_model = None
        self.clearml_cfg = clearml_cfg
        self.path_nemo_model = os.path.abspath(
            os.path.expanduser(os.path.join(log_dir, "checkpoints", prefix + postfix))
        )
        self.save_best_model = save_best_model
        self.prefix = prefix
        # NOTE: attribute name (including its spelling) is kept for backward compatibility.
        self.previos_best_model_path = None
        self.last_metrics = None
        self.save_blocked = True

        # Environment variables take precedence over the config, which takes
        # precedence over the built-in defaults.
        self.project_name = os.getenv("CLEARML_PROJECT", clearml_cfg.project or "NeMo")
        self.task_name = os.getenv("CLEARML_TASK", clearml_cfg.task or f"Trainer {self.prefix}")

        tags = ["NeMo"]
        if clearml_cfg.tags:
            tags.extend(clearml_cfg.tags)

        self.clearml_task: Task = Task.init(
            project_name=self.project_name,
            task_name=self.task_name,
            auto_connect_frameworks={"pytorch": clearml_cfg.connect_pytorch},
            output_uri=True,
            tags=tags,
        )

        # Model name fallback chain: explicit config -> prefix -> task name.
        model_name = clearml_cfg.model_name or self.prefix or self.task_name

        if clearml_cfg.log_model:
            self.clearml_model: OutputModel = OutputModel(
                name=model_name, task=self.clearml_task, tags=tags, framework="NeMo"
            )

    def log_hyperparams(self, params, *args, **kwargs) -> None:
        """Store the hyperparameters as YAML in the output model's design.

        Does nothing unless an output model exists and ``clearml_cfg.log_cfg``
        is truthy. Extra positional/keyword arguments are accepted and ignored.
        """
        if not (self.clearml_model and self.clearml_cfg.log_cfg):
            return

        if isinstance(params, Namespace):
            params = vars(params)
        elif isinstance(params, AttributeDict):
            params = dict(params)
        # Convert nested OmegaConf nodes and Path objects to plain values before dumping.
        params = apply_to_collection(params, (DictConfig, ListConfig), OmegaConf.to_container, resolve=True)
        params = apply_to_collection(params, Path, str)
        self.clearml_model.update_design(config_text=OmegaConf.to_yaml(params))

    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
        """Remember the latest metrics so they can be attached to the next model upload.

        Each metric is stored as ``{"value": str(value), "type": str(type(value))}``;
        tensors are first unwrapped with ``.item()``. ``step`` is not used.
        Does nothing unless an output model exists and ``clearml_cfg.log_metrics``
        is truthy.
        """
        if not (self.clearml_model and self.clearml_cfg.log_metrics):
            return

        recorded = {}
        for metric_name, value in metrics.items():
            # isinstance also covers Tensor subclasses, which an exact type comparison would miss.
            if isinstance(value, Tensor):
                value = value.item()
            recorded[metric_name] = {"value": str(value), "type": str(type(value))}
        self.last_metrics = recorded

    def log_table(
        self,
        key: str,
        columns: Optional[List[str]] = None,
        data: Optional[List[List[Any]]] = None,
        dataframe: Any = None,
        step: Optional[int] = None,
    ) -> None:
        """Report a table through the ClearML task logger.

        The table comes either from ``dataframe`` (optionally relabelled with
        ``columns``) or from ``data`` rows with ``columns`` as the header row.
        When both are given, ``data`` wins. Neither argument is modified.

        Args:
            key: Used as both the title and the series of the report.
            columns: Column names; required when ``data`` is given.
            data: Table rows.
            dataframe: Table object exposing ``copy`` and a settable ``columns``
                attribute (presumably a ``pandas.DataFrame`` - confirm with callers).
            step: Passed as the report iteration.

        Raises:
            ValueError: If ``data`` is given without ``columns``, or the number of
                column names differs from the length of the first row.
        """
        table: Optional[Union[pd.DataFrame, List[List[Any]]]] = None

        if dataframe is not None:
            table = dataframe
            if columns is not None:
                # Relabel a shallow copy so the caller's frame keeps its own column labels.
                table = dataframe.copy(deep=False)
                table.columns = columns

        if data is not None:
            if columns is None:
                raise ValueError("column names are required when logging a table from raw data")
            if len(data) > 0 and len(columns) != len(data[0]):
                raise ValueError("number of column names should match the total number of columns")
            # Build a new list instead of inserting the header into the caller's data.
            table = [list(columns), *data]

        if table is not None:
            self.clearml_task.logger.report_table(title=key, series=key, iteration=step, table_plot=table)

    def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
        """Upload the exported model after a checkpoint has been saved.

        Without an output model this is a no-op. With ``save_best_model`` set,
        the upload is skipped when ``save_blocked`` is set (the flag is cleared
        so the following call may proceed), when the callback's
        ``best_model_path`` does not exist on disk, or when it is unchanged
        since the previous upload.
        """
        if not self.clearml_model:
            return None

        if self.save_best_model:
            if self.save_blocked:
                self.save_blocked = False
                return None

            best_path = checkpoint_callback.best_model_path
            if not os.path.exists(best_path):
                return None
            if self.previos_best_model_path == best_path:
                return None
            self.previos_best_model_path = best_path

        self._log_model(self.path_nemo_model)
        return None

    def finalize(self, status: Literal["success", "failed", "aborted"] = "success") -> None:
        """Mark the ClearML task completed, failed or stopped according to ``status``.

        Any other status value leaves the task untouched.
        """
        if status == "success":
            self.clearml_task.mark_completed()
        elif status == "failed":
            self.clearml_task.mark_failed()
        elif status == "aborted":
            self.clearml_task.mark_stopped()

    def _log_model(self, save_path: str) -> None:
        """Upload the model file at ``save_path`` and attach the latest metrics.

        Logs a warning instead of uploading when the file does not exist.
        After a successful upload ``save_blocked`` is set again.
        """
        if not self.clearml_model:
            return

        if not os.path.exists(save_path):
            logging.warning((f"Logging model enabled, but cant find .nemo file!" f" Path: {save_path}"))
            return

        self.clearml_model.update_weights(
            weights_filename=save_path,
            # Falls back to a private ClearML helper when the task has no storage URI.
            upload_uri=self.clearml_task.storage_uri or self.clearml_task._get_default_report_storage_uri(),
            auto_delete_file=False,
            is_package=True,
        )

        if self.clearml_cfg.log_metrics and self.last_metrics:
            self.clearml_model.set_all_metadata(self.last_metrics)

        self.save_blocked = True
|
|