|
|
import os |
|
|
import logging |
|
|
|
|
|
import torch.distributed as dist |
|
|
|
|
|
class RankFilter(logging.Filter):
    """Logging filter that passes records only on one ``torch.distributed`` rank.

    Args:
        rank: The process rank whose records should be kept; records logged
            on any other rank are dropped.
    """

    def __init__(self, rank):
        super().__init__()
        self.rank = rank

    def filter(self, record):
        """Return True when the current process's rank equals ``self.rank``.

        If ``torch.distributed`` is unavailable in this build, or the default
        process group has not been initialized (e.g. a single-process run or
        logging before ``init_process_group``), the current process is treated
        as rank 0. Without this, ``dist.get_rank()`` raises, and because
        logging invokes filters outside any error handling, the exception
        would propagate out of every ``logger.info(...)`` call.
        """
        if dist.is_available() and dist.is_initialized():
            current_rank = dist.get_rank()
        else:
            current_rank = 0
        return current_rank == self.rank
|
|
|
|
|
def create_logger(log_path):
    """Configure the root logger to log INFO-level records to a file and the console.

    The root logger's level is set to INFO. Two handlers sharing the format
    ``'%(asctime)s - %(levelname)s - %(message)s'`` are appended to it:

    * a ``logging.FileHandler`` writing to ``log_path``;
    * a ``logging.StreamHandler`` (stderr by default) at level INFO.

    Args:
        log_path: Path of the log file. Missing parent directories are created.

    Returns:
        The root ``logging.Logger``.

    Note:
        NOTE(review): handlers are appended on every call, so calling this more
        than once duplicates output. ``RankFilter`` is not attached here, so in a
        multi-process job every rank writes to both handlers. Confirm both are
        intended.
    """
    # Create the parent directory only when there is one: a bare filename has an
    # empty dirname (current directory), and os.makedirs('') would raise.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    file_handler = logging.FileHandler(log_path)
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)

    stream_handler = logging.StreamHandler()
    stream_handler.setLevel(logging.INFO)
    stream_handler.setFormatter(formatter)
    logger.addHandler(stream_handler)

    return logger