Spaces:
Sleeping
Sleeping
File size: 8,355 Bytes
96da58e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 |
"""
This file contains utility classes and functions for logging to stdout, stderr,
and to tensorboard.
"""
import os
import sys
import numpy as np
from datetime import datetime
from contextlib import contextmanager
import textwrap
import time
from tqdm import tqdm
from termcolor import colored
import robomimic
# Global buffer of formatted warning messages. log_warning() appends each new
# message to it, and flush_warnings() prints every buffered message and then
# resets it to an empty list.
WARNINGS_BUFFER = []
class PrintLogger(object):
    """
    File-like object that duplicates everything written to it to both the
    stream that was sys.stdout at construction time and a log file.

    NOTE(review): presumably the caller assigns an instance to sys.stdout so
    that print statements are forked to the file - this class does not do
    that assignment itself.
    """
    def __init__(self, log_file):
        """
        Args:
            log_file (str): path of the file that output is appended to
        """
        # remember the current stdout so writes can still reach the console
        self.terminal = sys.stdout
        print('STDOUT will be forked to %s' % log_file)
        self.log_file = open(log_file, "a")

    def fileno(self):
        """
        Returns:
            fd (int): file descriptor of the wrapped terminal stream
        """
        return self.terminal.fileno()

    def write(self, message):
        """
        Write @message to the terminal and to the log file. The log file is
        flushed after every write so it stays up to date on disk.

        Args:
            message (str): text to write
        """
        self.terminal.write(message)
        self.log_file.write(message)
        self.log_file.flush()

    def flush(self):
        """
        Flush both underlying streams. Forwarding the flush (instead of
        ignoring it) lets callers that explicitly flush stdout actually push
        buffered text out to the terminal.
        """
        self.terminal.flush()
        # the log file may already have been closed through close()
        if not self.log_file.closed:
            self.log_file.flush()

    def close(self):
        """
        Close the log file. The wrapped terminal stream is left open since
        this object does not own it.
        """
        if not self.log_file.closed:
            self.log_file.close()
class DataLogger(object):
    """
    Logging class to log metrics to tensorboard and / or wandb, and to retrieve
    running statistics about logged scalar data.
    """
    def __init__(self, log_dir, config, log_tb=True, log_wandb=False):
        """
        Args:
            log_dir (str): base path to store logs. Tensorboard logs are written
                to the "tb" subdirectory, and the path is also passed to wandb
                as its run directory.
            config: experiment config. Only read when @log_wandb is True, where
                config.experiment.logging.wandb_proj_name, config.experiment.name,
                config.meta and config.algo_name are used to set up the run.
            log_tb (bool): whether to use tensorboard logging
            log_wandb (bool): whether to use wandb logging
        """
        self._tb_logger = None
        self._wandb_logger = None
        self._data = {}  # store all the scalar data logged so far

        if log_tb:
            from tensorboardX import SummaryWriter
            self._tb_logger = SummaryWriter(os.path.join(log_dir, 'tb'))

        if log_wandb:
            import wandb
            import robomimic.macros as Macros

            # set up wandb api key if specified in macros
            if Macros.WANDB_API_KEY is not None:
                os.environ["WANDB_API_KEY"] = Macros.WANDB_API_KEY

            assert Macros.WANDB_ENTITY is not None, "WANDB_ENTITY macro is set to None." \
                "\nSet this macro in {base_path}/macros_private.py" \
                "\nIf this file does not exist, first run python {base_path}/scripts/setup_macros.py".format(base_path=robomimic.__path__[0])

            # attempt to set up wandb 10 times. If unsuccessful after these trials, don't use wandb
            num_attempts = 10
            for attempt in range(num_attempts):
                is_last_attempt = (attempt == num_attempts - 1)
                try:
                    # set up wandb
                    self._wandb_logger = wandb

                    self._wandb_logger.init(
                        entity=Macros.WANDB_ENTITY,
                        project=config.experiment.logging.wandb_proj_name,
                        name=config.experiment.name,
                        dir=log_dir,
                        # the final attempt falls back to offline mode
                        mode=("offline" if is_last_attempt else "online"),
                    )

                    # set up info for identifying experiment
                    wandb_config = {k: v for (k, v) in config.meta.items() if k not in ["hp_keys", "hp_values"]}
                    for (k, v) in zip(config.meta["hp_keys"], config.meta["hp_values"]):
                        wandb_config[k] = v
                    if "algo" not in wandb_config:
                        wandb_config["algo"] = config.algo_name
                    self._wandb_logger.config.update(wandb_config)

                    break
                except Exception as e:
                    log_warning("wandb initialization error (attempt #{}): {}".format(attempt + 1, e))
                    self._wandb_logger = None
                    # only wait if there is another attempt left - sleeping after
                    # the final failure would just delay construction
                    if not is_last_attempt:
                        time.sleep(30)

    def record(self, k, v, epoch, data_type='scalar', log_stats=False):
        """
        Record data with logger.

        Args:
            k (str): key string
            v (float or image): value to store
            epoch: current epoch number
            data_type (str): the type of data. either 'scalar' or 'image'.
                Images are passed to tensorboard with the "NHWC" data format.
                Image logging is not implemented for wandb - it only produces
                a warning there.
            log_stats (bool): whether to store the mean/max/min/std for all data logged so far with key k
        """
        assert data_type in ['scalar', 'image']

        if data_type == 'scalar':
            # maybe update internal cache if logging stats for this key
            if log_stats or k in self._data:  # any key that we're logging or previously logged
                if k not in self._data:
                    self._data[k] = []
                self._data[k].append(v)

        # maybe log to tensorboard
        if self._tb_logger is not None:
            if data_type == 'scalar':
                self._tb_logger.add_scalar(k, v, epoch)
                if log_stats:
                    stats = self.get_stats(k)
                    for (stat_k, stat_v) in stats.items():
                        stat_k_name = '{}-{}'.format(k, stat_k)
                        self._tb_logger.add_scalar(stat_k_name, stat_v, epoch)
            elif data_type == 'image':
                self._tb_logger.add_images(k, img_tensor=v, global_step=epoch, dataformats="NHWC")

        # maybe log to wandb. Failures here are reported as warnings instead of
        # being raised, so a wandb problem does not interrupt the caller.
        if self._wandb_logger is not None:
            try:
                if data_type == 'scalar':
                    self._wandb_logger.log({k: v}, step=epoch)
                    if log_stats:
                        stats = self.get_stats(k)
                        for (stat_k, stat_v) in stats.items():
                            self._wandb_logger.log({"{}/{}".format(k, stat_k): stat_v}, step=epoch)
                elif data_type == 'image':
                    raise NotImplementedError
            except Exception as e:
                log_warning("wandb logging: {}".format(e))

    def get_stats(self, k):
        """
        Computes running statistics for a particular key.

        Args:
            k (str): key string

        Returns:
            stats (dict): dictionary of statistics, with keys "mean", "std",
                "min" and "max" computed over every value cached for @k

        Raises:
            KeyError: if no values have been cached for @k
        """
        values = self._data[k]
        return {
            'mean': np.mean(values),
            'std': np.std(values),
            'min': np.min(values),
            'max': np.max(values),
        }

    def close(self):
        """
        Run before terminating to make sure all logs are flushed
        """
        if self._tb_logger is not None:
            self._tb_logger.close()

        if self._wandb_logger is not None:
            self._wandb_logger.finish()
class custom_tqdm(tqdm):
    """
    Thin tqdm subclass whose only change from the default behavior is the
    output stream: progress bars are written to stdout rather than tqdm's
    default of stderr.
    """
    def __init__(self, *args, **kwargs):
        # the output stream is fixed by this class, so callers may not pass one
        assert "file" not in kwargs
        kwargs["file"] = sys.stdout
        super().__init__(*args, **kwargs)
@contextmanager
def silence_stdout():
    """
    Context manager that points sys.stdout at os.devnull for the duration of
    the with-block, so that nothing printed inside reaches the terminal. The
    devnull file object is yielded, and the previous sys.stdout is put back
    on exit. Adapted from the link below:
        https://stackoverflow.com/questions/6735917/redirecting-stdout-to-nothing-in-python
    """
    saved_stdout = sys.stdout
    try:
        with open(os.devnull, "w") as devnull:
            sys.stdout = devnull
            yield devnull
    finally:
        # runs even if the body raised, so stdout is always restored
        sys.stdout = saved_stdout
def log_warning(message, color="yellow", print_now=True):
    """
    Format a warning message and append it to the global WARNINGS_BUFFER.
    Buffered messages stay there until flush_warnings is called, which prints
    them to the terminal again and empties the buffer.

    Args:
        message (str): warning message to display
        color (str): color of message - defaults to "yellow"
        print_now (bool): if True (default), will also print the formatted
            message to the terminal immediately, in addition to buffering it
    """
    indented_message = textwrap.indent(message, " ")
    buffer_message = colored("ROBOMIMIC WARNING(\n{}\n)".format(indented_message), color)
    # the buffer is only mutated in place, never rebound, in this function
    WARNINGS_BUFFER.append(buffer_message)
    if print_now:
        print(buffer_message)
def flush_warnings():
    """
    Print every message currently held in the global WARNINGS_BUFFER to the
    terminal, in the order they were logged, then reset the buffer to an
    empty list.
    """
    global WARNINGS_BUFFER
    for buffered_warning in WARNINGS_BUFFER:
        print(buffered_warning)
    # rebind the module-level name to a fresh list
    WARNINGS_BUFFER = []
|