Spaces:
Runtime error
Runtime error
| import os | |
| import logging | |
| import torch.distributed as dist | |
class RankFilter(logging.Filter):
    """Logging filter that lets records through on a single distributed rank.

    Attach an instance to a logger or handler to suppress output from every
    process whose ``torch.distributed`` rank differs from ``rank``.
    """

    def __init__(self, rank):
        """Remember the rank whose records should be kept.

        Args:
            rank: Value compared against ``dist.get_rank()`` for each record.
        """
        super(RankFilter, self).__init__()
        self.rank = rank

    def filter(self, record):
        """Return True when the calling process's rank equals ``self.rank``.

        The record itself is not inspected; the decision depends only on the
        rank reported by ``dist.get_rank()`` at the time of the call.
        """
        current_rank = dist.get_rank()
        return current_rank == self.rank
def create_logger(log_path):
    """Configure the root logger to log to a file and to the console.

    Sets the root logger's level to INFO and attaches two handlers sharing
    the format ``'%(asctime)s - %(levelname)s - %(message)s'``: a
    ``FileHandler`` writing to ``log_path`` and a ``StreamHandler``.

    Args:
        log_path: Path of the log file. Its parent directory is created
            when it does not exist yet.

    Returns:
        The root logger (``logging.getLogger()``).

    Note:
        Each call adds a fresh pair of handlers to the root logger, so
        calling this more than once makes every record be emitted once per
        attached handler.
    """
    # Create the log directory if needed. The original check was inverted
    # (it only called makedirs when the directory already existed), so a
    # missing directory made FileHandler fail. A bare filename has an empty
    # dirname, which os.makedirs rejects, hence the truthiness guard.
    log_dir = os.path.dirname(log_path)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Root logger, INFO and above
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')

    # File handler writing to log_path
    fh = logging.FileHandler(log_path)
    fh.setFormatter(formatter)
    logger.addHandler(fh)

    # Stream handler so messages also show up on the console
    sh = logging.StreamHandler()
    sh.setLevel(logging.INFO)
    sh.setFormatter(formatter)
    logger.addHandler(sh)

    return logger