Buckets:
| # SPDX-FileCopyrightText: Copyright (c) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| # | |
| # Redistribution and use in source and binary forms, with or without | |
| # modification, are permitted provided that the following conditions are met: | |
| # | |
| # 1. Redistributions of source code must retain the above copyright notice, this | |
| # list of conditions and the following disclaimer. | |
| # | |
| # 2. Redistributions in binary form must reproduce the above copyright notice, | |
| # this list of conditions and the following disclaimer in the documentation | |
| # and/or other materials provided with the distribution. | |
| # | |
| # 3. Neither the name of the copyright holder nor the names of its | |
| # contributors may be used to endorse or promote products derived from | |
| # this software without specific prior written permission. | |
| # | |
| # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
| # AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
| # IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
| # DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
| # FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
| # DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
| # SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
| # CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
| # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
| # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
| # | |
| # Copyright (c) 2021 ETH Zurich, Nikita Rudin | |
| import time | |
| import os | |
| from collections import deque | |
| import statistics | |
| # from torch.utils.tensorboard import SummaryWriter | |
| import torch | |
| import numpy as np | |
| from rsl_rl.algorithms import PPO | |
| from rsl_rl.modules import * | |
| # from rsl_rl.env import VecEnv | |
| import IPython; e = IPython.embed | |
| import wandb | |
| class OnPolicyRunner: | |
| def __init__(self, | |
| env, | |
| train_cfg, | |
| log_dir=None, | |
| device='cpu', | |
| **kwargs): | |
| self.cfg=train_cfg["runner"] | |
| self.alg_cfg = train_cfg["algorithm"] | |
| self.policy_cfg = train_cfg["policy"] | |
| self.device = device | |
| self.env = env | |
| if self.env.num_privileged_obs is not None: | |
| num_critic_obs = self.env.num_privileged_obs | |
| else: | |
| num_critic_obs = self.env.num_obs | |
| ############################################################################################################ | |
| self.use_vision = self.env.cfg.sensor.enable_sensor | |
| ############################################################################################################ | |
| actor_critic_class = eval(self.cfg["policy_class_name"]) # ActorCritic | |
| # if self.env has attribute obs_context_len | |
| if hasattr(self.env, 'obs_context_len'): | |
| obs_context_len = self.env.obs_context_len | |
| else: | |
| obs_context_len = 1 | |
| args = kwargs['args'] | |
| actor_critic = actor_critic_class( | |
| self.env.num_obs, | |
| num_critic_obs, | |
| self.env.num_actions, | |
| obs_context_len=obs_context_len, | |
| **self.policy_cfg, | |
| device=self.device, | |
| args=args, | |
| ).to(self.device) | |
| alg_class = eval(self.cfg["algorithm_class_name"]) # PPO | |
| self.alg = alg_class( | |
| actor_critic, device=self.device, **self.alg_cfg) | |
| self.num_steps_per_env = self.cfg["num_steps_per_env"] | |
| self.save_interval = self.cfg["save_interval"] | |
| # init storage and model | |
| if self.use_vision: | |
| obs_vision_shape = [obs_context_len, 3, self.env.cfg.sensor.camera.height, self.env.cfg.sensor.camera.width] if obs_context_len != 1 else [3, self.env.cfg.sensor.camera.height, self.env.cfg.sensor.camera.width] | |
| else: | |
| obs_vision_shape = None | |
| obs_shape = [obs_context_len, self.env.num_obs] if obs_context_len != 1 else [self.env.num_obs] | |
| self.alg.init_storage(self.env.num_envs, self.num_steps_per_env, obs_shape, obs_vision_shape, [self.env.num_privileged_obs], [self.env.num_actions]) | |
| # Log | |
| self.log_dir = log_dir | |
| self.writer = None | |
| self.tot_timesteps = 0 | |
| self.tot_time = 0 | |
| self.current_learning_iteration = 0 | |
| _, _ = self.env.reset() | |
| def learn(self, num_learning_iterations, init_at_random_ep_len=False): | |
| # # initialize writer | |
| # if self.log_dir is not None and self.writer is None: | |
| # self.writer = SummaryWriter(log_dir=self.log_dir, flush_secs=10) | |
| if init_at_random_ep_len: | |
| self.env.episode_length_buf = torch.randint_like(self.env.episode_length_buf, high=int(self.env.max_episode_length)) | |
| obs = self.env.get_observations() | |
| if self.use_vision: | |
| obs_vision = self.env.get_visual_observations().to(self.device) | |
| privileged_obs = self.env.get_privileged_observations() | |
| critic_obs = privileged_obs if privileged_obs is not None else obs | |
| obs, critic_obs = obs.to(self.device), critic_obs.to(self.device) | |
| self.alg.actor_critic.train() # switch to train mode (for dropout for example) | |
| ep_infos = [] | |
| ep_metrics = [] | |
| rewbuffer = deque(maxlen=100) | |
| lenbuffer = deque(maxlen=100) | |
| donebuffer = deque(maxlen=100) | |
| cur_reward_sum = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) | |
| cur_episode_length = torch.zeros(self.env.num_envs, dtype=torch.float, device=self.device) | |
| self.tot_iter = self.current_learning_iteration + num_learning_iterations # starting from current and train for num_learning_iterations | |
| self.start_iter = self.current_learning_iteration | |
| for it in range(self.start_iter, self.tot_iter): | |
| start = time.time() | |
| # Rollout | |
| with torch.inference_mode(): | |
| for i in range(self.num_steps_per_env): | |
| if self.use_vision: | |
| actions = self.alg.act((obs, obs_vision), (critic_obs, obs_vision)) | |
| else: | |
| actions = self.alg.act(obs, critic_obs) | |
| obs, privileged_obs, rewards, dones, infos = self.env.step(actions) | |
| if self.use_vision: | |
| obs_vision = self.env.get_visual_observations().to(self.device) | |
| critic_obs = privileged_obs if privileged_obs is not None else obs | |
| obs, critic_obs, rewards, dones = obs.to(self.device), critic_obs.to(self.device), rewards.to(self.device), dones.to(self.device) | |
| self.alg.process_env_step(rewards, dones, infos) | |
| if self.log_dir is not None: | |
| # Book keeping | |
| if 'episode' in infos: | |
| ep_infos.append(infos['episode']) | |
| if 'episode_metrics' in infos: | |
| ep_metrics.append(infos['episode_metrics']) | |
| cur_reward_sum += rewards | |
| cur_episode_length += 1 | |
| new_ids = (dones > 0).nonzero(as_tuple=False) | |
| rewbuffer.extend(cur_reward_sum[new_ids][:, 0].cpu().numpy().tolist()) | |
| lenbuffer.extend(cur_episode_length[new_ids][:, 0].cpu().numpy().tolist()) | |
| donebuffer.append(len(new_ids) / self.env.num_envs) | |
| cur_reward_sum[new_ids] = 0 | |
| cur_episode_length[new_ids] = 0 | |
| stop = time.time() | |
| collection_time = stop - start | |
| # Learning step | |
| start = stop | |
| if self.use_vision: | |
| self.alg.compute_returns((critic_obs, obs_vision)) | |
| else: | |
| self.alg.compute_returns(critic_obs) | |
| mean_value_loss, mean_surrogate_loss = self.alg.update() | |
| stop = time.time() | |
| learn_time = stop - start | |
| if self.log_dir is not None: | |
| self.log(locals()) | |
| if it % self.save_interval == 0: | |
| self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(it))) | |
| ep_infos.clear() | |
| ep_metrics.clear() | |
| self.current_learning_iteration += 1 | |
| self.save(os.path.join(self.log_dir, 'model_{}.pt'.format(self.current_learning_iteration))) | |
| def log(self, locs, width=80, pad=35): | |
| self.tot_timesteps += self.num_steps_per_env * self.env.num_envs | |
| self.tot_time += locs['collection_time'] + locs['learn_time'] | |
| iteration_time = locs['collection_time'] + locs['learn_time'] | |
| ep_string = f'' | |
| wandb_dict = {} | |
| if locs['ep_infos']: | |
| for key in locs['ep_infos'][0]: | |
| infotensor = torch.tensor([], device=self.device) | |
| for ep_info in locs['ep_infos']: | |
| # handle scalar and zero dimensional tensor infos | |
| if not isinstance(ep_info[key], torch.Tensor): | |
| ep_info[key] = torch.Tensor([ep_info[key]]) | |
| if len(ep_info[key].shape) == 0: | |
| ep_info[key] = ep_info[key].unsqueeze(0) | |
| infotensor = torch.cat((infotensor, ep_info[key].to(self.device))) | |
| value = torch.mean(infotensor) | |
| # wandb.log({'Episode/' + key: value}, step=locs['it']) | |
| wandb_dict['Episode/' + key] = value | |
| ep_string += f"""{f'Mean episode {key}:':>{pad}} {value:.4f}\n""" | |
| if locs['ep_metrics']: | |
| for key in locs['ep_metrics'][0]: | |
| info = [] | |
| for ep_metric in locs['ep_metrics']: | |
| info.append(ep_metric[key]) | |
| value = np.mean(info) | |
| # wandb.log({'Episode/' + key: value}, step=locs['it']) | |
| wandb_dict['Metric/' + key] = value | |
| ep_string += f"""{f'Mean episode metric {key}:':>{pad}} {value:.4f}\n""" | |
| std = self.alg.actor_critic.std.cpu().detach().numpy() | |
| mean_std = std.mean() | |
| entropy = self.alg.actor_critic.entropy.detach().mean().item() | |
| fps = int(self.num_steps_per_env * self.env.num_envs / (locs['collection_time'] + locs['learn_time'])) | |
| wandb_dict['Loss/value_function'] = locs['mean_value_loss'] | |
| wandb_dict['Loss/surrogate'] = locs['mean_surrogate_loss'] | |
| wandb_dict['Loss/entropy'] = entropy | |
| wandb_dict['Loss/learning_rate'] = self.alg.learning_rate | |
| wandb_dict['Perf/total_fps'] = fps | |
| wandb_dict['Perf/collection time'] = locs['collection_time'] | |
| wandb_dict['Perf/learning_time'] = locs['learn_time'] | |
| wandb_dict['Std/mean_std'] = mean_std | |
| # log all dim of the std | |
| for i, std in enumerate(self.alg.actor_critic.std): | |
| wandb_dict[f'Std/std_dim_{i}'] = std | |
| if len(locs['rewbuffer']) > 0: | |
| wandb_dict['Train/mean_reward'] = statistics.mean(locs['rewbuffer']) | |
| # wandb_dict['Train/mean_arm_reward'] = statistics.mean(locs['armrewbuffer']) | |
| wandb_dict['Train/mean_episode_length'] = statistics.mean(locs['lenbuffer']) | |
| wandb_dict['Train/dones'] = statistics.mean(locs['donebuffer']) | |
| # wandb.log({'Train/mean_reward/time': statistics.mean(locs['rewbuffer'])}, step=self.tot_time) | |
| # wandb.log({'Train/mean_episode_length/time': statistics.mean(locs['lenbuffer'])}, step=self.tot_time) | |
| wandb.log(wandb_dict, step=locs['it']) | |
| str = f" \033[1m Learning iteration {locs['it']}/{self.tot_iter} \033[0m " | |
| if len(locs['rewbuffer']) > 0: | |
| log_string = (f"""{'#' * width}\n""" | |
| f"""{str.center(width, ' ')}\n\n""" | |
| f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ | |
| 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" | |
| f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" | |
| f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" | |
| f"""{'Mean action noise std:':>{pad}} {mean_std:.2f}\n""" | |
| f"""{'Mean reward:':>{pad}} {statistics.mean(locs['rewbuffer']):.2f}\n""" | |
| f"""{'Mean episode length:':>{pad}} {statistics.mean(locs['lenbuffer']):.2f}\n""") | |
| # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" | |
| # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") | |
| else: | |
| log_string = (f"""{'#' * width}\n""" | |
| f"""{str.center(width, ' ')}\n\n""" | |
| f"""{'Computation:':>{pad}} {fps:.0f} steps/s (collection: {locs[ | |
| 'collection_time']:.3f}s, learning {locs['learn_time']:.3f}s)\n""" | |
| f"""{'Value function loss:':>{pad}} {locs['mean_value_loss']:.4f}\n""" | |
| f"""{'Surrogate loss:':>{pad}} {locs['mean_surrogate_loss']:.4f}\n""" | |
| f"""{'Mean action noise std:':>{pad}} {mean_std:.2f}\n""") | |
| # f"""{'Mean reward/step:':>{pad}} {locs['mean_reward']:.2f}\n""" | |
| # f"""{'Mean episode length/episode:':>{pad}} {locs['mean_trajectory_length']:.2f}\n""") | |
| log_string += ep_string | |
| eta = self.tot_time / (locs['it'] + 1 - self.start_iter) * (locs['num_learning_iterations'] - (locs['it'] - self.start_iter)) | |
| eta_hrs, eta_mins, eta_secs = eta // 3600, (eta % 3600) // 60, eta % 60 | |
| tot_hrs, tot_mins, tot_secs = self.tot_time // 3600, (self.tot_time % 3600) // 60, self.tot_time % 60 | |
| log_string += (f"""{'-' * width}\n""" | |
| f"""{'Experiment name:':>{pad}} {self.cfg['experiment_name']}\n""" | |
| f"""{'Run name:':>{pad}} {self.cfg['run_name']}\n""" | |
| f"""{'Progress:':>{pad}} {self.start_iter}+{locs['it']-self.start_iter}/{self.tot_iter-self.start_iter}+{self.start_iter}\n""" | |
| f"""{'Device:':>{pad}} {self.device}\n""" | |
| f"""{'Total timesteps:':>{pad}} {self.tot_timesteps}\n""" | |
| f"""{'Iteration time:':>{pad}} {iteration_time:.2f}s\n""" | |
| f"""{'Total time:':>{pad}} {tot_hrs:.0f} hrs {tot_mins:.0f} mins {tot_secs:.1f} s\n""" | |
| f"""{'ETA:':>{pad}} {eta_hrs:.0f} hrs {eta_mins:.0f} mins {eta_secs:.1f} s\n""") | |
| print(log_string) | |
| def save(self, path, infos=None): | |
| torch.save({ | |
| 'model_state_dict': self.alg.actor_critic.state_dict(), | |
| 'optimizer_state_dict': self.alg.optimizer.state_dict(), | |
| 'iter': self.current_learning_iteration, | |
| 'infos': infos, | |
| }, path) | |
| def load(self, path, load_optimizer=True): | |
| try: | |
| loaded_dict = torch.load(path) | |
| except: | |
| loaded_dict = torch.load(path, map_location="cuda:0") | |
| self.alg.actor_critic.load_state_dict(loaded_dict['model_state_dict'], strict=False) | |
| if load_optimizer: | |
| self.alg.optimizer.load_state_dict(loaded_dict['optimizer_state_dict']) | |
| self.current_learning_iteration = loaded_dict['iter'] | |
| return loaded_dict['infos'] | |
| def get_inference_policy(self, device=None, hrl=False): | |
| self.alg.actor_critic.eval() # switch to evaluation mode (dropout for example) | |
| if device is not None: | |
| self.alg.actor_critic.to(device) | |
| if hrl: | |
| return self.alg.actor_critic.act_inference_hrl | |
| else: | |
| return self.alg.actor_critic.act_inference | |
Xet Storage Details
- Size:
- 16.1 kB
- Xet hash:
- 91996c9a57a7f6b3a7b252034fd9df7e848c507a102c9f511e5fe8df6ecf22a7
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.