from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union
from ditk import logging
from easydict import EasyDict
from matplotlib import pyplot as plt
from matplotlib import animation
import os
import numpy as np
import torch
import wandb
import pickle
import treetensor.numpy as tnp
from ding.framework import task
from ding.envs import BaseEnvManagerV2
from ding.utils import DistributedWriter
from ding.torch_utils import to_ndarray
from ding.utils.default_helper import one_time_warning

if TYPE_CHECKING:
    from ding.framework import OnlineRLContext, OfflineRLContext


def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an online RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - record_train_iter (:obj:`bool`): Whether to record training iteration. Default is False.
        - train_show_freq (:obj:`int`): Frequency of showing training logs. Default is 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1

    def _logger(ctx: "OnlineRLContext"):
        if task.finish:
            writer.close()
        nonlocal last_train_show_iter
        if not np.isinf(ctx.eval_value):
            if record_train_iter:
                writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step)
                writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
            else:
                writer.add_scalar('basic/eval_episode_return_mean', ctx.eval_value, ctx.env_step)
        if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
            last_train_show_iter = ctx.train_iter
            if isinstance(ctx.train_output, List):
                output = ctx.train_output.pop()  # only use the latest output for some algorithms, such as PPO
            else:
                output = ctx.train_output
            for k, v in output.items():
                if k in ['priority', 'td_error_priority']:
                    continue
                if "[scalars]" in k:
                    new_k = k.split(']')[-1]
                    raise NotImplementedError
                elif "[histogram]" in k:
                    new_k = k.split(']')[-1]
                    writer.add_histogram(new_k, v, ctx.env_step)
                    if record_train_iter:
                        writer.add_histogram(new_k, v, ctx.train_iter)
                else:
                    if record_train_iter:
                        writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)
                        writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step)
                    else:
                        writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step)

    return _logger
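
# A minimal sketch of how this logger is typically wired into a task pipeline.
# `ding_init` and the surrounding middleware are assumptions drawn from common
# DI-engine usage, not definitions in this file:
#
#     from ding.framework import ding_init, task, OnlineRLContext
#     ding_init(cfg)  # creates the DistributedWriter instance used above
#     with task.start(ctx=OnlineRLContext()):
#         task.use(online_logger(record_train_iter=True, train_show_freq=500))
#         task.run()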


def offline_logger(train_show_freq: int = 100) -> Callable:
    """
    Overview:
        Create an offline RL tensorboard logger for recording training and evaluation metrics.
    Arguments:
        - train_show_freq (:obj:`int`): Frequency of showing training logs. Defaults to 100.
    Returns:
        - _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
    Raises:
        - RuntimeError: If writer is None.
        - NotImplementedError: If the key of train_output is not supported, such as "scalars".
    Examples:
        >>> task.use(offline_logger(train_show_freq=1000))
    """
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    writer = DistributedWriter.get_instance()
    if writer is None:
        raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
    last_train_show_iter = -1

    def _logger(ctx: "OfflineRLContext"):
        nonlocal last_train_show_iter
        if task.finish:
            writer.close()
        if not np.isinf(ctx.eval_value):
            writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
        if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
            last_train_show_iter = ctx.train_iter
            output = ctx.train_output
            for k, v in output.items():
                if k in ['priority']:
                    continue
                if "[scalars]" in k:
                    new_k = k.split(']')[-1]
                    raise NotImplementedError
                elif "[histogram]" in k:
                    new_k = k.split(']')[-1]
                    writer.add_histogram(new_k, v, ctx.train_iter)
                else:
                    writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)

    return _logger


# Four utility functions for the wandb logger below.
def softmax(logit: np.ndarray) -> np.ndarray:
    # Subtract the row-wise max before exponentiating for numerical stability.
    v = np.exp(logit - logit.max(axis=-1, keepdims=True))
    return v / v.sum(axis=-1, keepdims=True)


def action_prob(num, probs, ln):
    # Animation update function: set the bar heights to the action probabilities of frame `num`.
    ax = plt.gca()
    ax.set_ylim([0, 1])
    for rect, x in zip(ln, probs[num]):
        rect.set_height(x)
    return ln


def return_prob(num, return_prob, ln):
    # Animation update function for a static bar chart: the bars are drawn once, so nothing changes per frame.
    return ln


def return_distribution(episode_return):
    # Histogram the episode returns into 5 equal-width bins spanning [min - 50, max + 50].
    num = len(episode_return)
    max_return = max(episode_return)
    min_return = min(episode_return)
    hist, bins = np.histogram(episode_return, bins=np.linspace(min_return - 50, max_return + 50, 6))
    gap = (max_return - min_return + 100) / 5
    x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)]
    return hist / num, x_dim
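
# Quick self-check of the helpers above (illustrative values only):
#     >>> softmax(np.array([1.0, 1.0]))
#     array([0.5, 0.5])
#     >>> hist, x_dim = return_distribution([100.0, 110.0, 90.0])
#     >>> len(hist) == len(x_dim) == 5  # 5 bins spanning [min - 50, max + 50]
#     True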


def wandb_online_logger(
        record_path: Optional[str] = None,
        cfg: Optional[Union[dict, EasyDict]] = None,
        exp_config: Optional[Union[dict, EasyDict]] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: Optional[str] = None,
        wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of the simulation.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of the following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendered video replay.
            - action_logger: boolean. Whether to track the action (`q_value` or action probability).
            - return_logger: boolean. Whether to track the return value.
        - metric_list (:obj:`Optional[List[str]]`): The list of logged metrics, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Whether to open wandb's anonymous mode, which allows visualization \
            of data without a wandb account.
        - project_name (:obj:`str`): The name of the wandb project.
        - run_name (:obj:`str`): The name of the wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
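    Examples:
        >>> # Illustrative usage; the config values and variable names below are placeholders.
        >>> task.use(wandb_online_logger(cfg=EasyDict(dict(plot_logger=True)), env=evaluator_env, model=model))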
| """ | |
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
    # Initialize wandb with default settings.
    # These settings can be overridden by calling wandb.init() at the top of the script.
    init_kwargs = {'project': project_name}
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)
    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False
    # The visualizer is called to save the replay of the simulation,
    # which will be uploaded to wandb later.
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )
    first_plot = True

    def _plot(ctx: "OnlineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}

        if cfg.plot_logger:
            for metric in metric_list:
                if isinstance(ctx.train_output, Dict) and metric in ctx.train_output:
                    if isinstance(ctx.train_output[metric], torch.Tensor):
                        info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()})
                    else:
                        info_for_logging.update({metric: ctx.train_output[metric]})
                elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]:
                    metric_value_list = []
                    for item in ctx.train_output:
                        if isinstance(item[metric], torch.Tensor):
                            metric_value_list.append(item[metric].cpu().detach().numpy())
                        else:
                            metric_value_list.append(item[metric])
                    metric_value = np.mean(metric_value_list)
                    info_for_logging.update({metric: metric_value})
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        if ctx.eval_value != -np.inf:
            if hasattr(ctx, "eval_value_min"):
                info_for_logging.update({"episode return min": ctx.eval_value_min})
            if hasattr(ctx, "eval_value_max"):
                info_for_logging.update({"episode return max": ctx.eval_value_max})
            if hasattr(ctx, "eval_value_std"):
                info_for_logging.update({"episode return std": ctx.eval_value_std})
            if hasattr(ctx, "eval_value"):
                info_for_logging.update({"episode return mean": ctx.eval_value})
            if hasattr(ctx, "train_iter"):
                info_for_logging.update({"train iter": ctx.train_iter})
            if hasattr(ctx, "env_step"):
                info_for_logging.update({"env step": ctx.env_step})

            eval_output = ctx.eval_output['output']
            episode_return = ctx.eval_output['episode_return']
            episode_return = np.array(episode_return)
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    # Upload the replay video array directly. wandb requires a 4- or 5-dimensional
                    # numpy tensor: (time, channel, height, width) or (batch, time, channel, height, width).
                    video_images = ctx.eval_output['replay_video']
                    video_images = video_images.astype(np.uint8)
                    info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
                elif record_path is not None:
                    file_list = []
                    for p in os.listdir(record_path):
                        if os.path.splitext(p)[-1] == ".mp4":
                            file_list.append(p)
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # Use the second-newest file: the newest one may still be being written.
                    video_path = os.path.join(record_path, file_list[-2])
                    info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    plt.ylim([-1, 1])
                    action_dim = len(action_probs[1])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    # The local variable is named `action_probs` so that `action_prob` below resolves
                    # to the module-level animation update function instead of being shadowed.
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging.update({"action": wandb.Video(action_path, format="gif")})
                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                ax = plt.gca()
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

        if bool(info_for_logging):
            wandb.log(data=info_for_logging, step=ctx.env_step)
        plt.clf()

    return _plot
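
# A minimal sketch of a full logger config, assuming the five flags documented in the
# docstring above; the values and variable names are placeholders, not defaults:
#
#     example_cfg = EasyDict(
#         dict(
#             gradient_logger=True,   # needs `model`, enables wandb.watch
#             plot_logger=True,       # scalar metrics taken from ctx.train_output
#             video_logger=False,     # needs `env` and `record_path` for replays
#             action_logger=False,    # needs policy outputs with 'logit' or 'action'
#             return_logger=False,    # return-distribution gif, needs `record_path`
#         )
#     )
#     task.use(wandb_online_logger(record_path='./replay', cfg=example_cfg, env=env, model=model))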


def wandb_offline_logger(
        record_path: Optional[str] = None,
        cfg: Optional[Union[dict, EasyDict]] = None,
        exp_config: Optional[Union[dict, EasyDict]] = None,
        metric_list: Optional[List[str]] = None,
        env: Optional[BaseEnvManagerV2] = None,
        model: Optional[torch.nn.Module] = None,
        anonymous: bool = False,
        project_name: str = 'default-project',
        run_name: Optional[str] = None,
        wandb_sweep: bool = False,
) -> Callable:
    """
    Overview:
        Wandb visualizer to track the experiment.
    Arguments:
        - record_path (:obj:`str`): The path to save the replay of the simulation.
        - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of the following settings:
            - gradient_logger: boolean. Whether to track the gradient.
            - plot_logger: boolean. Whether to track metrics like reward and loss.
            - video_logger: boolean. Whether to upload the rendered video replay.
            - action_logger: boolean. Whether to track the action (`q_value` or action probability).
            - return_logger: boolean. Whether to track the return value.
            - vis_dataset: boolean. Whether to visualize the dataset.
        - metric_list (:obj:`Optional[List[str]]`): The list of logged metrics, specialized by different policies.
        - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
        - model (:obj:`nn.Module`): Policy neural network model.
        - anonymous (:obj:`bool`): Whether to open wandb's anonymous mode, which allows visualization \
            of data without a wandb account.
        - project_name (:obj:`str`): The name of the wandb project.
        - run_name (:obj:`str`): The name of the wandb run.
        - wandb_sweep (:obj:`bool`): Whether to use wandb sweep.
    Returns:
        - _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
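    Examples:
        >>> # Illustrative usage; the config values and variable names below are placeholders.
        >>> task.use(wandb_offline_logger(cfg=EasyDict(dict(vis_dataset=False)), exp_config=cfg, model=model))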
| """ | |
    if task.router.is_active and not task.has_role(task.role.LEARNER):
        return task.void()
    color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
    if metric_list is None:
        metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
    # Initialize wandb with default settings.
    # These settings can be overridden by calling wandb.init() at the top of the script.
    init_kwargs = {'project': project_name}
    if exp_config:
        init_kwargs['config'] = exp_config
    if not wandb_sweep:
        init_kwargs['reinit'] = True
    if run_name is not None:
        init_kwargs['name'] = run_name
    if anonymous:
        init_kwargs['anonymous'] = "must"
    wandb.init(**init_kwargs)
    plt.switch_backend('agg')
    if cfg is None:
        cfg = EasyDict(
            dict(
                gradient_logger=False,
                plot_logger=True,
                video_logger=False,
                action_logger=False,
                return_logger=False,
                vis_dataset=True,
            )
        )
    else:
        if not isinstance(cfg, EasyDict):
            cfg = EasyDict(cfg)
        for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
            if key not in cfg.keys():
                cfg[key] = False
    # The visualizer is called to save the replay of the simulation,
    # which will be uploaded to wandb later.
    if env is not None and cfg.video_logger is True and record_path is not None:
        env.enable_save_replay(replay_path=record_path)
    if cfg.gradient_logger:
        wandb.watch(model, log="all", log_freq=100, log_graph=True)
    else:
        one_time_warning(
            "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
        )
    first_plot = True

    def _vis_dataset(datasetpath: str):
        try:
            from sklearn.manifold import TSNE
        except ImportError:
            import sys
            logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.")
            sys.exit(1)
        try:
            import h5py
        except ImportError:
            import sys
            logging.warning("Please install h5py first, such as `pip3 install h5py`.")
            sys.exit(1)
        assert os.path.splitext(datasetpath)[-1] in ['.pkl', '.h5', '.hdf5']
        if os.path.splitext(datasetpath)[-1] == '.pkl':
            with open(datasetpath, 'rb') as f:
                data = pickle.load(f)
            obs = []
            action = []
            reward = []
            for i in range(len(data)):
                obs.extend(data[i]['observations'])
                action.extend(data[i]['actions'])
                reward.extend(data[i]['rewards'])
        elif os.path.splitext(datasetpath)[-1] in ['.h5', '.hdf5']:
            with h5py.File(datasetpath, 'r') as f:
                obs = f['obs'][()]
                action = f['action'][()]
                reward = f['reward'][()]
        cmap = plt.cm.hsv
        obs = np.array(obs)
        reward = np.array(reward)
        obs_action = np.hstack((obs, np.array(action)))
        # Min-max normalize rewards and embeddings to [0, 1] so they are valid colormap inputs.
        reward = (reward - min(reward)) / (max(reward) - min(reward))
        embedded_obs = TSNE(n_components=2).fit_transform(obs)
        embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action)
        x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0)
        embedded_obs = (embedded_obs - x_min) / (x_max - x_min)
        x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0)
        embedded_obs_action = (embedded_obs_action - x_min) / (x_max - x_min)
        f, axes = plt.subplots(nrows=1, ncols=3)
        axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward))
        axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action))
        axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward))
        axes[0].set_title('state-reward')
        axes[1].set_title('state-action')
        axes[2].set_title('stateAction-reward')
        plt.savefig('dataset.png')
        wandb.log({"dataset": wandb.Image("dataset.png")})

    if cfg.vis_dataset is True:
        _vis_dataset(exp_config.dataset_path)

    def _plot(ctx: "OfflineRLContext"):
        nonlocal first_plot
        if first_plot:
            first_plot = False
            ctx.wandb_url = wandb.run.get_project_url()

        info_for_logging = {}

        if cfg.plot_logger:
            for metric in metric_list:
                if isinstance(ctx.train_output, Dict) and metric in ctx.train_output:
                    if isinstance(ctx.train_output[metric], torch.Tensor):
                        info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()})
                    else:
                        info_for_logging.update({metric: ctx.train_output[metric]})
                elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]:
                    metric_value_list = []
                    for item in ctx.train_output:
                        if isinstance(item[metric], torch.Tensor):
                            metric_value_list.append(item[metric].cpu().detach().numpy())
                        else:
                            metric_value_list.append(item[metric])
                    metric_value = np.mean(metric_value_list)
                    info_for_logging.update({metric: metric_value})
        else:
            one_time_warning(
                "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
            )

        if ctx.eval_value != -np.inf:
            if hasattr(ctx, "eval_value_min"):
                info_for_logging.update({"episode return min": ctx.eval_value_min})
            if hasattr(ctx, "eval_value_max"):
                info_for_logging.update({"episode return max": ctx.eval_value_max})
            if hasattr(ctx, "eval_value_std"):
                info_for_logging.update({"episode return std": ctx.eval_value_std})
            if hasattr(ctx, "eval_value"):
                info_for_logging.update({"episode return mean": ctx.eval_value})
            if hasattr(ctx, "train_iter"):
                info_for_logging.update({"train iter": ctx.train_iter})
            if hasattr(ctx, "train_epoch"):
                info_for_logging.update({"train_epoch": ctx.train_epoch})

            eval_output = ctx.eval_output['output']
            episode_return = ctx.eval_output['episode_return']
            episode_return = np.array(episode_return)
            if len(episode_return.shape) == 2:
                episode_return = episode_return.squeeze(1)

            if cfg.video_logger:
                if 'replay_video' in ctx.eval_output:
                    # Upload the replay video array directly. wandb requires a 4- or 5-dimensional
                    # numpy tensor: (time, channel, height, width) or (batch, time, channel, height, width).
                    video_images = ctx.eval_output['replay_video']
                    video_images = video_images.astype(np.uint8)
                    info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
                elif record_path is not None:
                    file_list = []
                    for p in os.listdir(record_path):
                        if os.path.splitext(p)[-1] == ".mp4":
                            file_list.append(p)
                    file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
                    # Use the second-newest file: the newest one may still be being written.
                    video_path = os.path.join(record_path, file_list[-2])
                    info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})

            if cfg.action_logger:
                action_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_action.gif"))
                if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
                    if isinstance(eval_output, tnp.ndarray):
                        action_probs = softmax(eval_output.logit)
                    else:
                        action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
                    fig, ax = plt.subplots()
                    plt.ylim([-1, 1])
                    action_dim = len(action_probs[1])
                    x_range = [str(x + 1) for x in range(action_dim)]
                    ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
                    # The local variable is named `action_probs` so that `action_prob` below resolves
                    # to the module-level animation update function instead of being shadowed.
                    ani = animation.FuncAnimation(
                        fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
                    )
                    ani.save(action_path, writer='pillow')
                    info_for_logging.update({"action": wandb.Video(action_path, format="gif")})
                elif all(['action' in v for v in eval_output[0]]):
                    for i, action_trajectory in enumerate(eval_output):
                        fig, ax = plt.subplots()
                        fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
                        steps = fig_data[:, 0]
                        actions = fig_data[:, 1:]
                        plt.ylim([-1, 1])
                        for j in range(actions.shape[1]):
                            ax.scatter(steps, actions[:, j])
                        info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})

            if cfg.return_logger:
                return_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_return.gif"))
                fig, ax = plt.subplots()
                ax = plt.gca()
                ax.set_ylim([0, 1])
                hist, x_dim = return_distribution(episode_return)
                assert len(hist) == len(x_dim)
                ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
                ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
                ani.save(return_path, writer='pillow')
                info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})

        if bool(info_for_logging):
            wandb.log(data=info_for_logging, step=ctx.trained_env_step)
        plt.clf()

    return _plot
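
# A minimal sketch for the offline variant, assuming `exp_config.dataset_path` points at a
# .pkl or .h5/.hdf5 dataset when vis_dataset=True; all names below are placeholders:
#
#     task.use(wandb_offline_logger(cfg=EasyDict(dict(vis_dataset=True)), exp_config=cfg, model=model))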