| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the CC-by-NC license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import logging | |
| import os | |
| from logging import Logger | |
| from pathlib import Path | |
| from typing import Optional | |
| import torch | |
| import wandb | |
| from omegaconf import OmegaConf | |
| def get_logger(log_path: str, rank: int): | |
| if rank != 0: | |
| return logging.getLogger("dummy") | |
| logger = logging.getLogger() | |
| default_level = logging.INFO | |
| if logger.hasHandlers(): | |
| logger.handlers.clear() | |
| logger.setLevel(default_level) | |
| formatter = logging.Formatter( | |
| "%(levelname)s | %(asctime)s | %(message)s", "%Y-%m-%d %H:%M:%S" | |
| ) | |
| info_file_handler = logging.FileHandler(log_path, mode="a") | |
| info_file_handler.setLevel(default_level) | |
| info_file_handler.setFormatter(formatter) | |
| logger.addHandler(info_file_handler) | |
| console_handler = logging.StreamHandler() | |
| console_handler.setLevel(default_level) | |
| console_handler.setFormatter(formatter) | |
| logger.addHandler(console_handler) | |
| return logger | |
| class TrainLogger: | |
| def __init__(self, log_dir: Path, rank: int, cfg: bool = False): | |
| self.log_dir = log_dir | |
| self.cfg = cfg | |
| self._init_text_logger(rank=rank) | |
| self.enable_wandb = self.cfg.logging.enable_wandb and (rank == 0) | |
| if self.enable_wandb: | |
| self._init_wandb() | |
| def _init_text_logger(self, rank: int): | |
| log_path = self.log_dir / self.cfg.logging.log_file_name | |
| self._logger = get_logger(log_path=log_path, rank=rank) | |
| def _init_wandb( | |
| self, | |
| ): | |
| wandb_run_id_path = self.log_dir / "wandb_run.id" | |
| try: | |
| wandb_run_id = wandb_run_id_path.read_text() | |
| except FileNotFoundError: | |
| wandb_run_id = wandb.util.generate_id() | |
| wandb_run_id_path.write_text(wandb_run_id) | |
| self.wandb_logger = wandb.init( | |
| id=wandb_run_id, | |
| project=self.cfg.logging.project, | |
| group=self.cfg.logging.group, | |
| dir=self.log_dir, | |
| entity=self.cfg.logging.entity, | |
| resume="allow", | |
| config=OmegaConf.to_container(self.cfg, resolve=True), | |
| ) | |
| def log_metric(self, value: float, name: str, stage: bool, step: int) -> None: | |
| self._logger.info(f"[{step}] {stage} {name}: {value:.3f}") | |
| if self.enable_wandb: | |
| self.wandb_logger.log(data={f"{stage}/{name}": value}, step=step) | |
| def log_lr(self, value: float, step: int) -> None: | |
| if self.enable_wandb: | |
| self.wandb_logger.log(data={"Optimization/LR": value}, step=step) | |
| def info(self, msg: str, step: Optional[int] = None) -> None: | |
| step_str = f"[{step}] " if step else "" | |
| self._logger.info(f"{step_str}{msg}") | |
| def warning(self, msg: str) -> None: | |
| self._logger.warning(msg) | |
| def finish(self) -> None: | |
| for handler in self._logger.handlers: | |
| if isinstance(handler, logging.FileHandler): | |
| handler.close() | |
| if self.enable_wandb: | |
| wandb.finish() | |
| def log_devices(device: torch.device, logger: Logger) -> None: | |
| if device.type == "cuda": | |
| logger.info("Found {} CUDA devices.".format(torch.cuda.device_count())) | |
| for i in range(torch.cuda.device_count()): | |
| props = torch.cuda.get_device_properties(i) | |
| logger.info( | |
| "{} \t Memory: {:.2f}GB".format( | |
| props.name, props.total_memory / (1024**3) | |
| ) | |
| ) | |
| else: | |
| logger.warning("WARNING: Using device {}".format(device)) | |
| logger.info(f"Found {os.cpu_count()} total number of CPUs.") | |
Xet Storage Details
- Size:
- 3.84 kB
- Xet hash:
- ddb56c582248cd862acdf24ecd7e09da8a3e9fb3f2008143f0f77d10a0c04c38
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.