| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import pathlib |
| from abc import ABC, abstractmethod |
| from enum import Enum |
| from typing import Dict, Any, Callable, Optional |
|
|
| import dllogger |
| import torch.distributed as dist |
| import wandb |
| from dllogger import Verbosity |
|
|
| from se3_transformer.runtime.utils import rank_zero_only |
|
|
|
|
class Logger(ABC):
    """Abstract base class for the logging backends defined in this module.

    Subclasses must implement ``log_hyperparams`` and ``log_metrics``. Both
    are wrapped with ``rank_zero_only`` (imported from
    ``se3_transformer.runtime.utils``); presumably that decorator restricts
    the call to the rank-0 process -- confirm against its definition.
    """

    @rank_zero_only
    @abstractmethod
    def log_hyperparams(self, params):
        """Record the hyperparameters ``params``; implemented by subclasses."""

    @rank_zero_only
    @abstractmethod
    def log_metrics(self, metrics, step=None):
        """Record ``metrics``, optionally tagged with ``step``; implemented by subclasses."""

    @staticmethod
    def _sanitize_params(params):
        """Return a new dict with each value of ``params`` made loggable.

        Conversion rules, applied per value:

        * Callables are invoked with no arguments. If the result is itself
          callable, the original callable's ``__name__`` is used; otherwise
          the result is used. If anything in that process raises, the
          callable's ``__name__`` is used, or ``None`` when it has none.
        * ``pathlib.Path`` and ``Enum`` values are converted with ``str``.
        * Everything else is passed through unchanged.

        Args:
            params: Mapping of parameter names to arbitrary values.

        Returns:
            A dict with the same keys and the converted values.
        """

        def _to_loggable(value):
            # Non-callables: only paths and enums need converting.
            if not isinstance(value, Callable):
                if isinstance(value, (pathlib.Path, Enum)):
                    return str(value)
                return value

            # Callables: best-effort evaluation, falling back to the name.
            try:
                produced = value()
                if isinstance(produced, Callable):
                    return value.__name__
                return produced
            except Exception:
                return getattr(value, "__name__", None)

        cleaned = {}
        for name, raw in params.items():
            cleaned[name] = _to_loggable(raw)
        return cleaned
|
|
|
|
class LoggerCollection(Logger):
    """Logger that forwards every call to each logger in a collection."""

    def __init__(self, loggers):
        """Store ``loggers``, the iterable of loggers to forward calls to."""
        super().__init__()
        self.loggers = loggers

    def __getitem__(self, index):
        """Return the logger(s) at ``index`` (an int or slice) in iteration order."""
        # Materialise first so that any iterable of loggers can be indexed.
        return list(self.loggers)[index]

    @rank_zero_only
    def log_hyperparams(self, params):
        """Pass ``params`` to ``log_hyperparams`` of every wrapped logger."""
        for backend in self.loggers:
            backend.log_hyperparams(params)

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Pass ``metrics`` and ``step`` to ``log_metrics`` of every wrapped logger."""
        for backend in self.loggers:
            backend.log_metrics(metrics, step)
|
|
|
|
class DLLogger(Logger):
    """Logger that writes through the ``dllogger`` package to a JSON stream file."""

    def __init__(self, save_dir: pathlib.Path, filename: str):
        """Initialise ``dllogger`` with a JSON stream backend.

        Setup runs only when ``torch.distributed`` is not initialised or
        the current process has rank 0; other processes skip it.

        Args:
            save_dir: Directory for the log file; created if missing.
            filename: Name of the log file inside ``save_dir``.
        """
        super().__init__()
        # Processes other than rank 0 of an initialised group do no setup.
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        save_dir.mkdir(parents=True, exist_ok=True)
        json_backend = dllogger.JSONStreamBackend(Verbosity.DEFAULT, str(save_dir / filename))
        dllogger.init(backends=[json_backend])

    @rank_zero_only
    def log_hyperparams(self, params):
        """Log the sanitized ``params`` under the ``"PARAMETER"`` step."""
        dllogger.log(step="PARAMETER", data=self._sanitize_params(params))

    @rank_zero_only
    def log_metrics(self, metrics, step=None):
        """Log ``metrics`` at ``step``; an empty tuple is used when ``step`` is None."""
        dllogger.log(step=tuple() if step is None else step, data=metrics)
|
|
|
|
class WandbLogger(Logger):
    """Logger that reports hyperparameters and metrics to a Weights & Biases run."""

    def __init__(
            self,
            name: str,
            save_dir: pathlib.Path,
            id: Optional[str] = None,
            project: Optional[str] = None
    ):
        """Start (or resume) a wandb run and keep it in ``self.experiment``.

        Setup runs only when ``torch.distributed`` is not initialised or
        the current process has rank 0. On other processes
        ``self.experiment`` is never assigned.

        Args:
            name: Run name passed to ``wandb.init``.
            save_dir: Directory passed to ``wandb.init`` as ``dir``; created if missing.
            id: Optional run id passed to ``wandb.init``.
            project: Optional project name passed to ``wandb.init``.
        """
        super().__init__()
        # Processes other than rank 0 of an initialised group do no setup.
        if dist.is_initialized() and dist.get_rank() != 0:
            return

        save_dir.mkdir(parents=True, exist_ok=True)
        init_kwargs = {
            'name': name,
            'project': project,
            'id': id,
            'dir': str(save_dir),
            'resume': 'allow',
            'anonymous': 'must',
        }
        self.experiment = wandb.init(**init_kwargs)

    @rank_zero_only
    def log_hyperparams(self, params: Dict[str, Any]) -> None:
        """Update the run config with the sanitized ``params``, allowing value changes."""
        self.experiment.config.update(self._sanitize_params(params), allow_val_change=True)

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        """Log ``metrics`` to the run.

        When ``step`` is given it is added to the logged payload under the
        ``'epoch'`` key, overriding any ``'epoch'`` entry in ``metrics``.
        """
        payload = metrics if step is None else {**metrics, 'epoch': step}
        self.experiment.log(payload)
|
|