Spaces:
Runtime error
Runtime error
| """ | |
| Some simple logging functionality, inspired by rllab's logging. | |
| Logs to a tab-separated-values file (path/to/output_directory/progress.txt) | |
| """ | |
| import atexit | |
| import os | |
| import os.path as osp | |
| import time | |
| import warnings | |
| import joblib | |
| import numpy as np | |
| import torch | |
| import wandb | |
# ANSI foreground color codes used by ``colorize`` (add 10 for the
# highlight/background variant). NOTE(review): 38 is not a standard
# foreground color code, but it is kept as-is for compatibility.
color2num = {
    "gray": 30,
    "red": 31,
    "green": 32,
    "yellow": 33,
    "blue": 34,
    "magenta": 35,
    "cyan": 36,
    "white": 37,
    "crimson": 38,
}
def setup_logger_kwargs(exp_name, seed=None, data_dir=None, datestamp=True):
    """
    Build the output directory for a logger and return a logger kwargs dict.

    If no seed is given and datestamp is false::

        output_dir = data_dir/exp_name

    If a seed is given and datestamp is false::

        output_dir = data_dir/exp_name/exp_name_s[seed]

    If datestamp is true, the experiment folder gets a ``YY-MM-DD_`` prefix
    and the seed subfolder gets a ``YY-MM-DD_HH-MM-SS-`` prefix.

    Args:
        exp_name (string): Name for experiment.
        seed (int): Seed for random number generators used by experiment.
        data_dir (string): Path to folder where results should be saved.
            Defaults to a ``logs`` folder three levels above this file.
        datestamp (bool): Whether to include a date and timestamp in the
            name of the save directory.

    Returns:
        dict with keys ``output_dir`` and ``exp_name``.
    """
    if data_dir is None:
        # Default: <repo root>/logs, three directory levels above this file.
        data_dir = osp.join(osp.abspath(osp.dirname(osp.dirname(osp.dirname(__file__)))), "logs")

    # Experiment folder, optionally prefixed with the current date.
    date_prefix = time.strftime("%Y-%m-%d_") if datestamp else ""
    relpath = date_prefix + exp_name

    if seed is not None:
        # Seed-specific subfolder inside the experiment directory.
        if datestamp:
            stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
            subfolder = f"{stamp}-{exp_name}_s{seed}"
        else:
            subfolder = f"{exp_name}_s{seed}"
        relpath = osp.join(relpath, subfolder)

    return dict(output_dir=osp.join(data_dir, relpath), exp_name=exp_name)
def colorize(string, color, bold=False, highlight=False):
    """
    Wrap ``string`` in ANSI escape codes for terminal coloring.

    This function was originally written by John Schulman.

    Args:
        string: Text to colorize.
        color (string): A key of ``color2num`` (e.g. ``"green"``).
        bold (bool): If True, also apply the bold attribute.
        highlight (bool): If True, use the highlight variant (+10).

    Returns:
        The string wrapped in an ANSI SGR sequence and a reset code.
    """
    code = color2num[color] + (10 if highlight else 0)
    attrs = [str(code)]
    if bold:
        attrs.append("1")
    return "\x1b[{}m{}\x1b[0m".format(";".join(attrs), string)
class Logger:
    """
    A general-purpose logger.

    Makes it easy to save diagnostics, hyperparameter configurations, the
    state of a training run, and the trained model.
    """

    def __init__(
        self,
        log_to_wandb=False,
        verbose=False,
        output_dir=None,
        output_fname="progress.csv",
        delimeter=",",
        exp_name=None,
        wandbcommit=1,
    ):
        """
        Initialize a Logger.

        Args:
            log_to_wandb (bool): If True, logger will also log to wandb.
            verbose (bool): If True, ``dump_tabular`` pretty-prints each row
                to stdout.
            output_dir (string): A directory for saving results to. If
                ``None``, nothing is written to disk.
            output_fname (string): Name for the delimited file containing
                metrics logged throughout a training run. Defaults to
                ``progress.csv``.
            delimeter (string): Used to separate logged values saved in
                output_fname.
            exp_name (string): Experiment name. If you run multiple training
                runs and give them all the same ``exp_name``, the plotter
                will know to group them.
            wandbcommit (int): Commit to wandb once every ``wandbcommit``
                calls to ``dump_tabular``.
        """
        self.verbose = verbose
        self.log_to_wandb = log_to_wandb
        self.delimeter = delimeter
        self.wandbcommit = wandbcommit
        self.log_iter = 1
        # We assume that there's no multiprocessing.
        if output_dir is not None:
            # Fall back to a temp dir if an empty string was passed.
            self.output_dir = output_dir or "/tmp/experiments/%i" % int(time.time())
            if osp.exists(self.output_dir):
                print("Warning: Log dir %s already exists! Storing info there anyway." % self.output_dir)
            else:
                os.makedirs(self.output_dir)
            # "w+" so log_tabular can later rewrite the header line in place.
            self.output_file = open(osp.join(self.output_dir, output_fname), "w+")
            atexit.register(self.output_file.close)
            print(colorize("Logging data to %s" % self.output_file.name, "green", bold=True))
        else:
            # BUGFIX: self.output_dir was previously left undefined in this
            # branch, making save_state()/_pytorch_simple_save() raise
            # AttributeError. Define it explicitly so those methods can skip
            # saving gracefully.
            self.output_dir = None
            self.output_file = None
        self.first_row = True
        self.log_headers = []
        self.log_current_row = {}
        self.exp_name = exp_name

    def log(self, msg, color="green"):
        """Print a colorized message to stdout."""
        print(colorize(msg, color, bold=True))

    def log_tabular(self, key, val):
        """
        Log a value of some diagnostic.

        Call this only once for each diagnostic quantity, each iteration.
        After using ``log_tabular`` to store values for each diagnostic,
        make sure to call ``dump_tabular`` to write them out to file and
        stdout (otherwise they will not get saved anywhere).
        """
        if self.first_row:
            self.log_headers.append(key)
        else:
            if key not in self.log_headers:
                # A new key appeared after the first row: extend the header
                # list and rewrite the output file's header line to match.
                self.log_headers.append(key)
                if self.output_file is not None:
                    # move pointer to the beginning of the file
                    self.output_file.seek(0)
                    # skip the old header line
                    self.output_file.readline()
                    # keep the rest of the file (previously written rows)
                    logs = self.output_file.read()
                    # clear the file
                    self.output_file.truncate(0)
                    self.output_file.seek(0)
                    # write the new header
                    self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
                    # write back the stored rows
                    self.output_file.write(logs)
                    # move pointer to end of file so future writes append
                    # (a redundant seek(0) before this was removed)
                    self.output_file.seek(0, 2)
        assert key not in self.log_current_row, (
            "You already set %s this iteration. Maybe you forgot to call dump_tabular()" % key
        )
        self.log_current_row[key] = val

    def save_state(self, state_dict, itr=None):
        """
        Saves the state of an experiment.

        To be clear: this is about saving *state*, not logging diagnostics.
        All diagnostic logging is separate from this function. This function
        will save whatever is in ``state_dict``---usually just a copy of the
        environment---and the most recent parameters for the model you
        previously set up saving for with ``setup_pytorch_saver``.

        Call with any frequency you prefer. If you only want to maintain a
        single state and overwrite it at each call with the most recent
        version, leave ``itr=None``. If you want to keep all of the states you
        save, provide unique (increasing) values for 'itr'.

        Args:
            state_dict (dict): Dictionary containing essential elements to
                describe the current state of training.
            itr: An int, or None. Current iteration of training.
        """
        if self.output_dir is None:
            # Nothing to do when the logger was created without an output dir.
            self.log("Warning: no output_dir configured; skipping save_state.", color="red")
            return
        fname = "vars.pkl" if itr is None else "vars%d.pkl" % itr
        try:
            joblib.dump(state_dict, osp.join(self.output_dir, fname))
        except Exception:
            # BUGFIX: was a bare except; keep best-effort behavior but don't
            # swallow SystemExit/KeyboardInterrupt.
            self.log("Warning: could not pickle state_dict.", color="red")
        if hasattr(self, "pytorch_saver_elements"):
            self._pytorch_simple_save(itr)

    def setup_pytorch_saver(self, what_to_save):
        """
        Set up easy model saving for a single PyTorch model.

        Because PyTorch saving and loading is especially painless, this is
        very minimal; we just need references to whatever we would like to
        pickle. This is integrated into the logger because the logger
        knows where the user would like to save information about this
        training run.

        Args:
            what_to_save: Any PyTorch model or serializable object containing
                PyTorch models.
        """
        self.pytorch_saver_elements = what_to_save

    def _pytorch_simple_save(self, itr=None):
        """
        Saves the PyTorch model (or models) to ``output_dir/pyt_save``.
        """
        assert hasattr(self, "pytorch_saver_elements"), "First have to setup saving with self.setup_pytorch_saver"
        if self.output_dir is None:
            # No output dir configured: skip instead of crashing on osp.join.
            self.log("Warning: no output_dir configured; skipping model save.", color="red")
            return
        fpath = "pyt_save"
        fpath = osp.join(self.output_dir, fpath)
        fname = "model" + ("%d" % itr if itr is not None else "") + ".pt"
        fname = osp.join(fpath, fname)
        os.makedirs(fpath, exist_ok=True)
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            # We are using a non-recommended way of saving PyTorch models,
            # by pickling whole objects (which are dependent on the exact
            # directory structure at the time of saving) as opposed to
            # just saving network weights. This works sufficiently well
            # for the purposes of Spinning Up, but you may want to do
            # something different for your personal PyTorch project.
            # We use a catch_warnings() context to avoid the warnings about
            # not being able to save the source code.
            torch.save(self.pytorch_saver_elements, fname)

    def dump_tabular(self):
        """
        Write all of the diagnostics from the current iteration.

        Writes to stdout (when ``verbose``), to the output file (when
        configured), and to wandb (when ``log_to_wandb``).

        Returns:
            dict mapping each known header to its value for this row
            (missing diagnostics map to ``""``).
        """
        # BUGFIX: max() raised ValueError when dump_tabular() was called
        # before any log_tabular() call; use default=0.
        key_lens = [len(key) for key in self.log_headers]
        max_key_len = max(15, max(key_lens, default=0))
        keystr = "%" + "%d" % max_key_len
        fmt = "| " + keystr + "s | %15s |"
        n_slashes = 22 + max_key_len
        step = self.log_current_row.get("total_env_steps")
        # BUGFIX: vals was previously collected only inside the verbose
        # branch, so with verbose=False the output file received empty rows.
        # Collect the row unconditionally; printing stays verbose-gated.
        vals = [self.log_current_row.get(key, "") for key in self.log_headers]
        if self.verbose:
            print("-" * n_slashes)
            for key, val in zip(self.log_headers, vals):
                valstr = "%8.3g" % val if isinstance(val, float) else val
                print(fmt % (key, valstr))
            print("-" * n_slashes, flush=True)
        if self.output_file is not None:
            if self.first_row:
                self.output_file.write(self.delimeter.join(self.log_headers) + "\n")
            self.output_file.write(self.delimeter.join(map(str, vals)) + "\n")
            self.output_file.flush()
        key_val_dict = {key: self.log_current_row.get(key, "") for key in self.log_headers}
        if self.log_to_wandb:
            # Only commit (finalize the wandb step) every `wandbcommit` dumps.
            if self.log_iter % self.wandbcommit == 0:
                wandb.log(key_val_dict, step=step, commit=True)
            else:
                wandb.log(key_val_dict, step=step, commit=False)
        self.log_current_row.clear()
        self.first_row = False
        self.log_iter += 1
        return key_val_dict
class EpochLogger(Logger):
    """
    A variant of Logger tailored for tracking average values over epochs.

    Typical use case: there is some quantity which is calculated many times
    throughout an epoch, and at the end of the epoch, you would like to
    report the average / std / min / max value of that quantity.

    With an EpochLogger, each time the quantity is calculated, you would use

    .. code-block:: python

        epoch_logger.store({"NameOfQuantity": quantity_value})

    to load it into the EpochLogger's state. Then at the end of the epoch,
    you would use

    .. code-block:: python

        epoch_logger.log_tabular(NameOfQuantity, **options)

    to record the desired values.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Maps diagnostic name -> list of values stored during current epoch.
        self.epoch_dict = dict()

    def store(self, d):
        """
        Save something into the epoch_logger's current state.

        Args:
            d (dict): Maps diagnostic names to numerical values (or numpy
                arrays). NOTE: unlike the classic spinup API, this takes a
                single dict, not keyword arguments.
        """
        for k, v in d.items():
            # setdefault replaces the manual "create list if missing" dance.
            self.epoch_dict.setdefault(k, []).append(v)

    def log_tabular(self, key, val=None, with_min_and_max=False, with_median=False, with_sum=False, average_only=False):
        """
        Log a value or possibly the mean/std/min/max values of a diagnostic.

        Args:
            key (string): The name of the diagnostic. If you are logging a
                diagnostic whose state has previously been saved with
                ``store``, the key here has to match the key you used there.
            val: A value for the diagnostic. If you have previously saved
                values for this key via ``store``, do *not* provide a ``val``
                here.
            with_min_and_max (bool): If true, log min and max values of the
                diagnostic over the epoch.
            with_median (bool): If true, also log the median.
            with_sum (bool): If true, also log the sum.
            average_only (bool): If true, do not log the standard deviation
                of the diagnostic over the epoch (key is logged unsuffixed).
        """
        if val is not None:
            super().log_tabular(key, val)
        else:
            stats = self.get_stats(key)
            super().log_tabular(key if average_only else key + "/avg", stats[0])
            if not average_only:
                super().log_tabular(key + "/std", stats[1])
            if with_min_and_max:
                super().log_tabular(key + "/max", stats[3])
                super().log_tabular(key + "/min", stats[2])
            if with_median:
                super().log_tabular(key + "/med", stats[4])
            if with_sum:
                super().log_tabular(key + "/sum", stats[5])
        # Reset the stored values for this key for the next epoch.
        self.epoch_dict[key] = []

    def get_stats(self, key):
        """
        Lets an algorithm ask the logger for statistics of a diagnostic.

        Returns:
            list: ``[mean, std, min, max, median, sum]`` over the values
            stored for ``key`` this epoch, or six NaNs if nothing is stored.
        """
        v = self.epoch_dict.get(key)
        if not v:
            # BUGFIX: previously returned only 4 NaNs, so log_tabular crashed
            # with IndexError on stats[4]/stats[5] when with_median or
            # with_sum was requested for an empty key.
            return [np.nan] * 6
        # Concatenate only genuine (non-scalar) arrays; otherwise treat the
        # stored values as a flat list of scalars.
        vals = np.concatenate(v) if isinstance(v[0], np.ndarray) and len(v[0].shape) > 0 else v
        return [np.mean(vals), np.std(vals), np.min(vals), np.max(vals), np.median(vals), np.sum(vals)]