Spaces:
Running
Running
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| # Copyright (c) Megvii Inc. All rights reserved. | |
| import inspect | |
| import os | |
| import sys | |
| from loguru import logger | |
| import torch | |
def get_caller_name(depth=0):
    """Return the ``__name__`` of the module that called this function.

    Args:
        depth (int): depth of caller context; 0 means the direct caller.
            Default value: 0.

    Returns:
        str: module name of the caller.
    """
    # Walking frame objects directly is a bit faster than inspect.stack().
    frame = inspect.currentframe().f_back
    while depth > 0:
        frame = frame.f_back
        depth -= 1
    return frame.f_globals["__name__"]
class StreamToLoguru:
    """
    Stream object that redirects writes to a loguru logger instance.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level(str): log level string of loguru. Default value: "INFO".
            caller_names(tuple): caller names of redirected module.
                Default value: (apex, pycocotools).
        """
        self.level = level
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        # Only output originating from the configured modules is redirected
        # to loguru; everything else passes through to the real stdout.
        full_name = get_caller_name(depth=1)
        module_name = full_name.rsplit(".", maxsplit=-1)[0]
        if module_name not in self.caller_names:
            sys.__stdout__.write(buf)
            return
        for line in buf.rstrip().splitlines():
            # use caller level log
            logger.opt(depth=2).log(self.level, line.rstrip())

    def flush(self):
        # loguru manages its own flushing; this is a no-op for API compatibility.
        pass
def redirect_sys_output(log_level="INFO"):
    """Replace ``sys.stdout`` and ``sys.stderr`` with a StreamToLoguru
    instance logging at *log_level*."""
    stream = StreamToLoguru(log_level)
    sys.stdout = stream
    sys.stderr = stream
def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """Setup a loguru logger for training and testing.

    Configures loguru to write to stderr and to a file under *save_dir*,
    then redirects stdout/stderr through loguru. Returns nothing (the
    original docstring promised a logger instance that was never returned).

    Args:
        save_dir (str): location to save the log file.
        distributed_rank (int): device rank in a multi-gpu environment;
            only rank 0 keeps active log sinks.
        filename (str): log save name.
        mode (str): log file write mode, "a" (append) or "o" (override).
            Default is "a".
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    )

    logger.remove()
    save_file = os.path.join(save_dir, filename)
    # "o" (override) mode starts from a fresh file.
    if mode == "o" and os.path.exists(save_file):
        os.remove(save_file)
    # only keep logger in rank0 process
    if distributed_rank == 0:
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
class WandbLogger(object):
    """
    Log training runs, datasets, models, and predictions to Weights & Biases.

    This logger sends information to W&B at wandb.ai.
    By default, this information includes hyperparameters,
    system configuration and metrics, model metrics,
    and basic data metrics and analyses.

    For more information, please refer to:
    https://docs.wandb.ai/guides/track
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """
        Args:
            project (str): wandb project name.
            name (str): wandb run name.
            id (str): wandb run id.
            entity (str): wandb entity name.
            save_dir (str): save directory.
            config (dict): config dict.
            **kwargs: other kwargs passed through to ``wandb.init``.

        Raises:
            ModuleNotFoundError: if the ``wandb`` package is not installed.
        """
        try:
            import wandb
            self.wandb = wandb
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                "wandb is not installed. "
                "Please install wandb using pip install wandb"
            )

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow"
        )
        self._wandb_init.update(**kwargs)

        # Eagerly create (or attach to) the underlying wandb run.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)
        self.run.define_metric("epoch")
        self.run.define_metric("val/", step_metric="epoch")

    # NOTE: must be a property — __init__ and the methods below access
    # `self.run.config`, `self.run.log`, `self.run.id`, which would fail
    # on a bound method object.
    @property
    def run(self):
        """Lazily initialize and return the underlying wandb run."""
        if self._run is None:
            if self.wandb.run is not None:
                logger.info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()` "
                    "before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, step=None):
        """
        Args:
            metrics (dict): metrics dict.
            step (int): step number.
        """
        # wandb cannot serialize tensors; convert scalar tensors in place.
        for k, v in metrics.items():
            if isinstance(v, torch.Tensor):
                metrics[k] = v.item()

        if step is not None:
            self.run.log(metrics, step=step)
        else:
            self.run.log(metrics)

    def save_checkpoint(self, save_dir, model_name, is_best):
        """Upload a checkpoint file as a wandb model artifact.

        Args:
            save_dir (str): save directory.
            model_name (str): model name.
            is_best (bool): whether the model is the best model.
        """
        filename = os.path.join(save_dir, model_name + "_ckpt.pth")
        artifact = self.wandb.Artifact(
            name=f"model-{self.run.id}",
            type="model"
        )
        artifact.add_file(filename, name="model_ckpt.pth")
        aliases = ["latest"]
        if is_best:
            aliases.append("best")
        self.run.log_artifact(artifact, aliases=aliases)

    def finish(self):
        """Finish the wandb run (flushes pending data and closes it)."""
        self.run.finish()