| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import pathlib |
| | from abc import ABC, abstractmethod |
| | from enum import Enum |
| | from typing import Dict, Any, Callable, Optional |
| |
|
| | import dllogger |
| | import torch.distributed as dist |
| | import wandb |
| | from dllogger import Verbosity |
| |
|
| | from se3_transformer.runtime.utils import rank_zero_only |
| |
|
| |
|
class Logger(ABC):
    """Abstract base class for experiment loggers.

    Subclasses implement :meth:`log_hyperparams` and :meth:`log_metrics`.
    Both abstract methods are decorated with ``rank_zero_only``. Presumably
    this restricts execution to the rank-0 process; confirm against its
    definition in ``se3_transformer.runtime.utils``.
    """

    @rank_zero_only
    @abstractmethod
    def log_hyperparams(self, params):
        """Record a mapping of hyperparameter names to values."""

    @rank_zero_only
    @abstractmethod
    def log_metrics(self, metrics, step=None):
        """Record a mapping of metric names to values, optionally tagged with *step*."""

    @staticmethod
    def _sanitize_params(params):
        """Return a new dict with every value of *params* converted to a loggable form.

        Conversion rules, applied per value:

        - Callables are invoked with no arguments and the result is used.
          If that result is itself callable, the original callable's
          ``__name__`` is used instead. If invocation (or the name lookup)
          raises, the callable's ``__name__`` is used, or ``None`` when it
          has no ``__name__``.
        - ``pathlib.Path`` and ``Enum`` values are converted with ``str()``.
        - Any other value is passed through unchanged.

        NOTE(review): because callables are invoked, a class passed as a
        value is instantiated with no arguments here. Confirm this side
        effect is intended for every hyperparameter callers log.
        """
        def _sanitize(val):
            if callable(val):
                try:
                    result = val()
                    # A callable that yields another callable is logged by name.
                    if callable(result):
                        return val.__name__
                    return result
                except Exception:
                    # Deliberate best-effort fallback: never let a value break logging.
                    return getattr(val, "__name__", None)
            if isinstance(val, (pathlib.Path, Enum)):
                return str(val)
            return val

        return {key: _sanitize(val) for key, val in params.items()}
| |
|
| |
|
class LoggerCollection(Logger):
    """Logger that forwards every call to each of a group of child loggers."""

    def __init__(self, loggers):
        super().__init__()
        # Stored as given (not copied). Presumably an iterable of Logger
        # instances; verify against callers.
        self.loggers = loggers

    def __getitem__(self, index):
        """Return the child logger(s) at *index* (an int or a slice)."""
        # Materialize once so any iterable of loggers supports indexing.
        return list(self.loggers)[index]

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Forward *metrics* and *step* to every child logger, in order."""
        for logger in self.loggers:
            logger.log_metrics(metrics, step)

    @rank_zero_only
    def log_hyperparams(self, params):
        """Forward *params* to every child logger, in order."""
        for logger in self.loggers:
            logger.log_hyperparams(params)
| |
|
| |
|
class DLLogger(Logger):
    """Logger that writes JSON-stream records through ``dllogger``.

    The output directory is created and ``dllogger`` is initialized only when
    ``torch.distributed`` is not initialized or this process is rank 0.
    """

    def __init__(self, save_dir: pathlib.Path, filename: str):
        super().__init__()
        is_main_process = (not dist.is_initialized()) or dist.get_rank() == 0
        if is_main_process:
            save_dir.mkdir(parents=True, exist_ok=True)
            json_path = str(save_dir / filename)
            dllogger.init(
                backends=[dllogger.JSONStreamBackend(Verbosity.DEFAULT, json_path)]
            )

    @rank_zero_only
    def log_hyperparams(self, params):
        """Log the sanitized *params* under the ``"PARAMETER"`` step."""
        dllogger.log(step="PARAMETER", data=self._sanitize_params(params))

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Log *metrics* at *step*. An empty tuple is used when *step* is None."""
        dllogger.log(step=() if step is None else step, data=metrics)
| |
|
| |
|
class WandbLogger(Logger):
    """Logger that sends hyperparameters and metrics to a Weights & Biases run.

    The run (``self.experiment``) is created only when ``torch.distributed``
    is not initialized or this process is rank 0. On other ranks the
    attribute is never set.
    """

    def __init__(
        self,
        name: str,
        save_dir: pathlib.Path,
        id: Optional[str] = None,
        project: Optional[str] = None,
    ):
        super().__init__()
        # Non-zero ranks in an initialized process group create no run.
        if dist.is_initialized() and dist.get_rank() != 0:
            return
        save_dir.mkdir(parents=True, exist_ok=True)
        self.experiment = wandb.init(
            name=name,
            project=project,
            id=id,
            dir=str(save_dir),
            resume='allow',
            anonymous='must',
        )

    @rank_zero_only
    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        """Merge the sanitized *params* into the run config, allowing overwrites."""
        sanitized = self._sanitize_params(params)
        self.experiment.config.update(sanitized, allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log *metrics*. When *step* is given, it is added under the ``'epoch'`` key."""
        payload = metrics if step is None else {**metrics, 'epoch': step}
        self.experiment.log(payload)
| |
|