| import os |
| import re |
| import numpy as np |
| import logging |
|
|
# (name, level) pairs that init_log has already configured; repeating a pair
# changes nothing.
logs = set()

# Logger name -> the StreamHandler init_log attached to it. A later call with a
# different level reuses this handler instead of stacking a duplicate one,
# which would print each record twice.
_handlers = {}


def init_log(name, level=logging.INFO):
    """Configure and return a console logger.

    On the first call for *name*, attach a ``StreamHandler`` with the format
    ``[time][LEVEL] message`` and set both the logger and handler to *level*.
    A later call with a different *level* updates both levels in place. A call
    with an already-seen ``(name, level)`` pair changes nothing.

    When the ``SLURM_PROCID`` environment variable is set, a filter is added to
    the logger so that records logged through it are kept only on rank 0.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied to the logger and its console handler.

    Returns:
        The ``logging.Logger`` for *name*, on every call including repeats.

    Raises:
        ValueError: If ``SLURM_PROCID`` is set but is not an integer. Nothing
            has been configured yet in that case, so a retry is safe.
    """
    logger = logging.getLogger(name)
    if (name, level) in logs:
        # Already configured: nothing to do, but still hand back the logger.
        return logger

    handler = _handlers.get(name)
    if handler is None:
        # Parse the rank before mutating any state, so a bad value cannot
        # leave the name half-configured.
        if "SLURM_PROCID" in os.environ:
            rank = int(os.environ["SLURM_PROCID"])
            # Every SLURM task runs this; keep only rank 0's output.
            logger.addFilter(lambda record: rank == 0)
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("[%(asctime)s][%(levelname)8s] %(message)s")
        )
        logger.addHandler(handler)
        _handlers[name] = handler

    logs.add((name, level))
    logger.setLevel(level)
    handler.setLevel(level)
    return logger
|
|