diff --git a/.gitattributes b/.gitattributes index db23df24a51b016a716ec6e56057a782e5fe1690..c9bca540b7213f83530a2a192c42adbeff2ffc39 100644 --- a/.gitattributes +++ b/.gitattributes @@ -178,3 +178,4 @@ tuning-competition-baseline/.venv/lib/python3.11/site-packages/torch/_inductor/_ .venv/lib/python3.11/site-packages/ray/_private/thirdparty/pynvml/__pycache__/pynvml.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text .venv/lib/python3.11/site-packages/ray/rllib/algorithms/__pycache__/algorithm_config.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text +.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc filter=lfs diff=lfs merge=lfs -text diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a84ec966ed86cfcf6ad53e0c8dd7dcaabf72247b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/env/__pycache__/multi_agent_episode.cpython-311.pyc @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:70ee04d5ba78d502ad5d58d83cd6ec52ed3635c4af63ccc12837f71debf75e54 +size 115849 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..08f5bd48be3dbc03584483ed0066e5deac1efe2a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__init__.py @@ -0,0 +1,20 @@ +from ray.rllib.evaluation.rollout_worker import RolloutWorker +from ray.rllib.evaluation.sample_batch_builder import ( + SampleBatchBuilder, + MultiAgentSampleBatchBuilder, +) +from ray.rllib.evaluation.sampler import 
SyncSampler +from ray.rllib.evaluation.postprocessing import compute_advantages +from ray.rllib.evaluation.metrics import collect_metrics +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch + +__all__ = [ + "RolloutWorker", + "SampleBatch", + "MultiAgentBatch", + "SampleBatchBuilder", + "MultiAgentSampleBatchBuilder", + "SyncSampler", + "compute_advantages", + "collect_metrics", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1add60d0e155e1598328fde218b7ff4bc71c13ac Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/env_runner_v2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/env_runner_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7c373e5524e14e9ced85970b709bcc20ca51571 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/env_runner_v2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/episode_v2.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/episode_v2.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..7e629d89647ecc9978b06a749c39d537a3f8355a Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/episode_v2.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/metrics.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/metrics.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..ffe7d7fab90d29f638a9113298a30a1cc706e968 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/metrics.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/postprocessing.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/postprocessing.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6df92031ea94654ecc4383437aa500426d2cc123 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/postprocessing.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/rollout_worker.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/rollout_worker.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cc3f46f11180c587f67aa2ff1843f5711874d192 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/rollout_worker.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sampler.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sampler.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..da7c66c19878e4cc49e8def1e2fbeae30069dfd5 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/__pycache__/sampler.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py new file mode 100644 index 0000000000000000000000000000000000000000..a301f61ec0df78cfb0bdefa3ca4d44927ed9e168 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/collectors/simple_list_collector.py @@ -0,0 +1,698 
@@ +import collections +from gymnasium.spaces import Space +import logging +import numpy as np +import tree # pip install dm_tree +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING, Union + +from ray.rllib.env.base_env import _DUMMY_AGENT_ID +from ray.rllib.evaluation.collectors.sample_collector import SampleCollector +from ray.rllib.evaluation.collectors.agent_collector import AgentCollector +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.policy_map import PolicyMap +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch, concat_samples +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.debug import summarize +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.spaces.space_utils import get_dummy_batch_for_space +from ray.rllib.utils.typing import ( + AgentID, + EpisodeID, + EnvID, + PolicyID, + TensorType, + ViewRequirementsDict, +) +from ray.util.debug import log_once + +_, tf, _ = try_import_tf() +torch, _ = try_import_torch() + +if TYPE_CHECKING: + from ray.rllib.callbacks.callbacks import RLlibCallback + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class _PolicyCollector: + """Collects already postprocessed (single agent) samples for one policy. + + Samples come in through already postprocessed SampleBatches, which + contain single episode/trajectory data for a single agent and are then + appended to this policy's buffers. + """ + + def __init__(self, policy: Policy): + """Initializes a _PolicyCollector instance. + + Args: + policy: The policy object. + """ + + self.batches = [] + self.policy = policy + # The total timestep count for all agents that use this policy. + # NOTE: This is not an env-step count (across n agents). AgentA and + # agentB, both using this policy, acting in the same episode and both + # doing n steps would increase the count by 2*n. 
+ self.agent_steps = 0 + + def add_postprocessed_batch_for_training( + self, batch: SampleBatch, view_requirements: ViewRequirementsDict + ) -> None: + """Adds a postprocessed SampleBatch (single agent) to our buffers. + + Args: + batch: An individual agent's (one trajectory) + SampleBatch to be added to the Policy's buffers. + view_requirements: The view + requirements for the policy. This is so we know, whether a + view-column needs to be copied at all (not needed for + training). + """ + # Add the agent's trajectory length to our count. + self.agent_steps += batch.count + # And remove columns not needed for training. + for view_col, view_req in view_requirements.items(): + if view_col in batch and not view_req.used_for_training: + del batch[view_col] + self.batches.append(batch) + + def build(self): + """Builds a SampleBatch for this policy from the collected data. + + Also resets all buffers for further sample collection for this policy. + + Returns: + SampleBatch: The SampleBatch with all thus-far collected data for + this policy. + """ + # Create batch from our buffers. + batch = concat_samples(self.batches) + # Clear batches for future samples. + self.batches = [] + # Reset agent steps to 0. + self.agent_steps = 0 + # Add num_grad_updates counter to the policy's batch. + batch.num_grad_updates = self.policy.num_grad_updates + + return batch + + +class _PolicyCollectorGroup: + def __init__(self, policy_map): + self.policy_collectors = {} + # Total env-steps (1 env-step=up to N agents stepped). + self.env_steps = 0 + # Total agent steps (1 agent-step=1 individual agent (out of N) + # stepped). + self.agent_steps = 0 + + +@OldAPIStack +class SimpleListCollector(SampleCollector): + """Util to build SampleBatches for each policy in a multi-agent env. + + Input data is per-agent, while output data is per-policy. There is an M:N + mapping between agents and policies. We retain one local batch builder + per agent. 
When an agent is done, then its local batch is appended into the + corresponding policy batch for the agent's policy. + """ + + def __init__( + self, + policy_map: PolicyMap, + clip_rewards: Union[bool, float], + callbacks: "RLlibCallback", + multiple_episodes_in_batch: bool = True, + rollout_fragment_length: int = 200, + count_steps_by: str = "env_steps", + ): + """Initializes a SimpleListCollector instance.""" + + super().__init__( + policy_map, + clip_rewards, + callbacks, + multiple_episodes_in_batch, + rollout_fragment_length, + count_steps_by, + ) + + self.large_batch_threshold: int = ( + max(1000, self.rollout_fragment_length * 10) + if self.rollout_fragment_length != float("inf") + else 5000 + ) + + # Whenever we observe a new episode+agent, add a new + # _SingleTrajectoryCollector. + self.agent_collectors: Dict[Tuple[EpisodeID, AgentID], AgentCollector] = {} + # Internal agent-key-to-policy-id map. + self.agent_key_to_policy_id = {} + # Pool of used/unused PolicyCollectorGroups (attached to episodes for + # across-episode multi-agent sample collection). + self.policy_collector_groups = [] + + # Agents to collect data from for the next forward pass (per policy). + self.forward_pass_agent_keys = {pid: [] for pid in self.policy_map.keys()} + self.forward_pass_size = {pid: 0 for pid in self.policy_map.keys()} + + # Maps episode ID to the (non-built) env steps taken in this episode. + self.episode_steps: Dict[EpisodeID, int] = collections.defaultdict(int) + # Maps episode ID to the (non-built) individual agent steps in this + # episode. + self.agent_steps: Dict[EpisodeID, int] = collections.defaultdict(int) + # Maps episode ID to Episode. 
+ self.episodes = {} + + @override(SampleCollector) + def episode_step(self, episode) -> None: + episode_id = episode.episode_id + # In the rase case that an "empty" step is taken at the beginning of + # the episode (none of the agents has an observation in the obs-dict + # and thus does not take an action), we have seen the episode before + # and have to add it here to our registry. + if episode_id not in self.episodes: + self.episodes[episode_id] = episode + else: + assert episode is self.episodes[episode_id] + self.episode_steps[episode_id] += 1 + episode.length += 1 + + # In case of "empty" env steps (no agent is stepping), the builder + # object may still be None. + if episode.batch_builder: + env_steps = episode.batch_builder.env_steps + num_individual_observations = sum( + c.agent_steps for c in episode.batch_builder.policy_collectors.values() + ) + + if num_individual_observations > self.large_batch_threshold and log_once( + "large_batch_warning" + ): + logger.warning( + "More than {} observations in {} env steps for " + "episode {} ".format( + num_individual_observations, env_steps, episode_id + ) + + "are buffered in the sampler. If this is more than you " + "expected, check that that you set a horizon on your " + "environment correctly and that it terminates at some " + "point. Note: In multi-agent environments, " + "`rollout_fragment_length` sets the batch size based on " + "(across-agents) environment steps, not the steps of " + "individual agents, which can result in unexpectedly " + "large batches." + + ( + "Also, you may be waiting for your Env to " + "terminate (batch_mode=`complete_episodes`). Make sure " + "it does at some point." 
+ if not self.multiple_episodes_in_batch + else "" + ) + ) + + @override(SampleCollector) + def add_init_obs( + self, + *, + episode, + agent_id: AgentID, + env_id: EnvID, + policy_id: PolicyID, + init_obs: TensorType, + init_infos: Optional[Dict[str, TensorType]] = None, + t: int = -1, + ) -> None: + # Make sure our mappings are up to date. + agent_key = (episode.episode_id, agent_id) + self.agent_key_to_policy_id[agent_key] = policy_id + policy = self.policy_map[policy_id] + + # Add initial obs to Trajectory. + assert agent_key not in self.agent_collectors + # TODO: determine exact shift-before based on the view-req shifts. + + # get max_seq_len value (Default is 1) + try: + max_seq_len = policy.config["model"]["max_seq_len"] + except KeyError: + max_seq_len = 1 + + self.agent_collectors[agent_key] = AgentCollector( + policy.view_requirements, + max_seq_len=max_seq_len, + disable_action_flattening=policy.config.get( + "_disable_action_flattening", False + ), + intial_states=policy.get_initial_state(), + is_policy_recurrent=policy.is_recurrent(), + ) + self.agent_collectors[agent_key].add_init_obs( + episode_id=episode.episode_id, + agent_index=episode._agent_index(agent_id), + env_id=env_id, + init_obs=init_obs, + init_infos=init_infos or {}, + t=t, + ) + + self.episodes[episode.episode_id] = episode + if episode.batch_builder is None: + episode.batch_builder = ( + self.policy_collector_groups.pop() + if self.policy_collector_groups + else _PolicyCollectorGroup(self.policy_map) + ) + + self._add_to_next_inference_call(agent_key) + + @override(SampleCollector) + def add_action_reward_next_obs( + self, + episode_id: EpisodeID, + agent_id: AgentID, + env_id: EnvID, + policy_id: PolicyID, + agent_done: bool, + values: Dict[str, TensorType], + ) -> None: + # Make sure, episode/agent already has some (at least init) data. 
+ agent_key = (episode_id, agent_id) + assert self.agent_key_to_policy_id[agent_key] == policy_id + assert agent_key in self.agent_collectors + + self.agent_steps[episode_id] += 1 + + # Include the current agent id for multi-agent algorithms. + if agent_id != _DUMMY_AGENT_ID: + values["agent_id"] = agent_id + + # Add action/reward/next-obs (and other data) to Trajectory. + self.agent_collectors[agent_key].add_action_reward_next_obs(values) + + if not agent_done: + self._add_to_next_inference_call(agent_key) + + @override(SampleCollector) + def total_env_steps(self) -> int: + # Add the non-built ongoing-episode env steps + the already built + # env-steps. + return sum(self.episode_steps.values()) + sum( + pg.env_steps for pg in self.policy_collector_groups.values() + ) + + @override(SampleCollector) + def total_agent_steps(self) -> int: + # Add the non-built ongoing-episode agent steps (still in the agent + # collectors) + the already built agent steps. + return sum(a.agent_steps for a in self.agent_collectors.values()) + sum( + pg.agent_steps for pg in self.policy_collector_groups.values() + ) + + @override(SampleCollector) + def get_inference_input_dict(self, policy_id: PolicyID) -> Dict[str, TensorType]: + policy = self.policy_map[policy_id] + keys = self.forward_pass_agent_keys[policy_id] + batch_size = len(keys) + + # Return empty batch, if no forward pass to do. + if batch_size == 0: + return SampleBatch() + + buffers = {} + for k in keys: + collector = self.agent_collectors[k] + buffers[k] = collector.buffers + # Use one agent's buffer_structs (they should all be the same). + buffer_structs = self.agent_collectors[keys[0]].buffer_structs + + input_dict = {} + for view_col, view_req in policy.view_requirements.items(): + # Not used for action computations. + if not view_req.used_for_compute_actions: + continue + + # Create the batch of data from the different buffers. 
+ data_col = view_req.data_col or view_col + delta = ( + -1 + if data_col + in [ + SampleBatch.OBS, + SampleBatch.INFOS, + SampleBatch.ENV_ID, + SampleBatch.EPS_ID, + SampleBatch.AGENT_INDEX, + SampleBatch.T, + ] + else 0 + ) + # Range of shifts, e.g. "-100:0". Note: This includes index 0! + if view_req.shift_from is not None: + time_indices = (view_req.shift_from + delta, view_req.shift_to + delta) + # Single shift (e.g. -1) or list of shifts, e.g. [-4, -1, 0]. + else: + time_indices = view_req.shift + delta + + # Loop through agents and add up their data (batch). + data = None + for k in keys: + # Buffer for the data does not exist yet: Create dummy + # (zero) data. + if data_col not in buffers[k]: + if view_req.data_col is not None: + space = policy.view_requirements[view_req.data_col].space + else: + space = view_req.space + + if isinstance(space, Space): + fill_value = get_dummy_batch_for_space( + space, + batch_size=0, + ) + else: + fill_value = space + + self.agent_collectors[k]._build_buffers({data_col: fill_value}) + + if data is None: + data = [[] for _ in range(len(buffers[keys[0]][data_col]))] + + # `shift_from` and `shift_to` are defined: User wants a + # view with some time-range. + if isinstance(time_indices, tuple): + # `shift_to` == -1: Until the end (including(!) the + # last item). + if time_indices[1] == -1: + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices[0] :]) + # `shift_to` != -1: "Normal" range. + else: + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices[0] : time_indices[1] + 1]) + # Single index. 
+ else: + for d, b in zip(data, buffers[k][data_col]): + d.append(b[time_indices]) + + np_data = [np.array(d) for d in data] + if data_col in buffer_structs: + input_dict[view_col] = tree.unflatten_as( + buffer_structs[data_col], np_data + ) + else: + input_dict[view_col] = np_data[0] + + self._reset_inference_calls(policy_id) + + return SampleBatch( + input_dict, + seq_lens=np.ones(batch_size, dtype=np.int32) + if "state_in_0" in input_dict + else None, + ) + + @override(SampleCollector) + def postprocess_episode( + self, + episode, + is_done: bool = False, + check_dones: bool = False, + build: bool = False, + ) -> Union[None, SampleBatch, MultiAgentBatch]: + episode_id = episode.episode_id + policy_collector_group = episode.batch_builder + + # Build SampleBatches for the given episode. + pre_batches = {} + for (eps_id, agent_id), collector in self.agent_collectors.items(): + # Build only if there is data and agent is part of given episode. + if collector.agent_steps == 0 or eps_id != episode_id: + continue + pid = self.agent_key_to_policy_id[(eps_id, agent_id)] + policy = self.policy_map[pid] + pre_batch = collector.build_for_training(policy.view_requirements) + pre_batches[agent_id] = (policy, pre_batch) + + # Apply reward clipping before calling postprocessing functions. + if self.clip_rewards is True: + for _, (_, pre_batch) in pre_batches.items(): + pre_batch["rewards"] = np.sign(pre_batch["rewards"]) + elif self.clip_rewards: + for _, (_, pre_batch) in pre_batches.items(): + pre_batch["rewards"] = np.clip( + pre_batch["rewards"], + a_min=-self.clip_rewards, + a_max=self.clip_rewards, + ) + + post_batches = {} + for agent_id, (_, pre_batch) in pre_batches.items(): + # Entire episode is said to be done. + # Error if no DONE at end of this agent's trajectory. 
+ if is_done and check_dones and not pre_batch.is_terminated_or_truncated(): + raise ValueError( + "Episode {} terminated for all agents, but we still " + "don't have a last observation for agent {} (policy " + "{}). ".format( + episode_id, + agent_id, + self.agent_key_to_policy_id[(episode_id, agent_id)], + ) + + "Please ensure that you include the last observations " + "of all live agents when setting truncated[__all__] or " + "terminated[__all__] to True." + ) + + # Skip a trajectory's postprocessing (and thus using it for training), + # if its agent's info exists and contains the training_enabled=False + # setting (used by our PolicyClients). + last_info = episode.last_info_for(agent_id) + if last_info and not last_info.get("training_enabled", True): + if is_done: + agent_key = (episode_id, agent_id) + del self.agent_key_to_policy_id[agent_key] + del self.agent_collectors[agent_key] + continue + + if len(pre_batches) > 1: + other_batches = pre_batches.copy() + del other_batches[agent_id] + else: + other_batches = {} + pid = self.agent_key_to_policy_id[(episode_id, agent_id)] + policy = self.policy_map[pid] + if not pre_batch.is_single_trajectory(): + raise ValueError( + "Batches sent to postprocessing must be from a single trajectory! " + "TERMINATED & TRUNCATED need to be False everywhere, except the " + "last timestep, which can be either True or False for those keys)!", + pre_batch, + ) + elif len(set(pre_batch[SampleBatch.EPS_ID])) > 1: + episode_ids = set(pre_batch[SampleBatch.EPS_ID]) + raise ValueError( + "Batches sent to postprocessing must only contain steps " + "from a single episode! Your trajectory contains data from " + f"{len(episode_ids)} episodes ({list(episode_ids)}).", + pre_batch, + ) + # Call the Policy's Exploration's postprocess method. 
+ post_batches[agent_id] = pre_batch + if getattr(policy, "exploration", None) is not None: + policy.exploration.postprocess_trajectory( + policy, post_batches[agent_id], policy.get_session() + ) + post_batches[agent_id].set_get_interceptor(None) + post_batches[agent_id] = policy.postprocess_trajectory( + post_batches[agent_id], other_batches, episode + ) + + if log_once("after_post"): + logger.info( + "Trajectory fragment after postprocess_trajectory():\n\n{}\n".format( + summarize(post_batches) + ) + ) + + # Append into policy batches and reset. + from ray.rllib.evaluation.rollout_worker import get_global_worker + + for agent_id, post_batch in sorted(post_batches.items()): + agent_key = (episode_id, agent_id) + pid = self.agent_key_to_policy_id[agent_key] + policy = self.policy_map[pid] + self.callbacks.on_postprocess_trajectory( + worker=get_global_worker(), + episode=episode, + agent_id=agent_id, + policy_id=pid, + policies=self.policy_map, + postprocessed_batch=post_batch, + original_batches=pre_batches, + ) + + # Add the postprocessed SampleBatch to the policy collectors for + # training. + # PID may be a newly added policy. Just confirm we have it in our + # policy map before proceeding with adding a new _PolicyCollector() + # to the group. 
+ if pid not in policy_collector_group.policy_collectors: + assert pid in self.policy_map + policy_collector_group.policy_collectors[pid] = _PolicyCollector(policy) + policy_collector_group.policy_collectors[ + pid + ].add_postprocessed_batch_for_training(post_batch, policy.view_requirements) + + if is_done: + del self.agent_key_to_policy_id[agent_key] + del self.agent_collectors[agent_key] + + if policy_collector_group: + env_steps = self.episode_steps[episode_id] + policy_collector_group.env_steps += env_steps + agent_steps = self.agent_steps[episode_id] + policy_collector_group.agent_steps += agent_steps + + if is_done: + del self.episode_steps[episode_id] + del self.episodes[episode_id] + + if episode_id in self.agent_steps: + del self.agent_steps[episode_id] + else: + assert ( + len(pre_batches) == 0 + ), "Expected the batch to be empty since the episode_id is missing." + # if the key does not exist it means that throughout the episode all + # observations were empty (i.e. there was no agent in the env) + msg = ( + f"Data from episode {episode_id} does not show any agent " + f"interactions. Hint: Make sure for at least one timestep in the " + f"episode, env.step() returns non-empty values." + ) + raise ValueError(msg) + + # Make PolicyCollectorGroup available for more agent batches in + # other episodes. Do not reset count to 0. + if policy_collector_group: + self.policy_collector_groups.append(policy_collector_group) + else: + self.episode_steps[episode_id] = self.agent_steps[episode_id] = 0 + + # Build a MultiAgentBatch from the episode and return. + if build: + return self._build_multi_agent_batch(episode) + + def _build_multi_agent_batch(self, episode) -> Union[MultiAgentBatch, SampleBatch]: + + ma_batch = {} + for pid, collector in episode.batch_builder.policy_collectors.items(): + if collector.agent_steps > 0: + ma_batch[pid] = collector.build() + + # TODO(sven): We should always return the same type here (MultiAgentBatch), + # no matter what. 
Just have to unify our `training_step` methods, then. This + # will reduce a lot of confusion about what comes out of the sampling process. + # Create the batch. + ma_batch = MultiAgentBatch.wrap_as_needed( + ma_batch, env_steps=episode.batch_builder.env_steps + ) + + # PolicyCollectorGroup is empty. + episode.batch_builder.env_steps = 0 + episode.batch_builder.agent_steps = 0 + + return ma_batch + + @override(SampleCollector) + def try_build_truncated_episode_multi_agent_batch( + self, + ) -> List[Union[MultiAgentBatch, SampleBatch]]: + batches = [] + # Loop through ongoing episodes and see whether their length plus + # what's already in the policy collectors reaches the fragment-len + # (abiding to the unit used: env-steps or agent-steps). + for episode_id, episode in self.episodes.items(): + # Measure batch size in env-steps. + if self.count_steps_by == "env_steps": + built_steps = ( + episode.batch_builder.env_steps if episode.batch_builder else 0 + ) + ongoing_steps = self.episode_steps[episode_id] + # Measure batch-size in agent-steps. + else: + built_steps = ( + episode.batch_builder.agent_steps if episode.batch_builder else 0 + ) + ongoing_steps = self.agent_steps[episode_id] + + # Reached the fragment-len -> We should build an MA-Batch. + if built_steps + ongoing_steps >= self.rollout_fragment_length: + if self.count_steps_by == "env_steps": + assert built_steps + ongoing_steps == self.rollout_fragment_length + # If we reached the fragment-len only because of `episode_id` + # (still ongoing) -> postprocess `episode_id` first. + if built_steps < self.rollout_fragment_length: + self.postprocess_episode(episode, is_done=False) + # If there is a builder for this episode, + # build the MA-batch and add to return values. + if episode.batch_builder: + batch = self._build_multi_agent_batch(episode=episode) + batches.append(batch) + # No batch-builder: + # We have reached the rollout-fragment length w/o any agent + # steps! 
Warn that the environment may never request any + # actions from any agents. + elif log_once("no_agent_steps"): + logger.warning( + "Your environment seems to be stepping w/o ever " + "emitting agent observations (agents are never " + "requested to act)!" + ) + + return batches + + def _add_to_next_inference_call(self, agent_key: Tuple[EpisodeID, AgentID]) -> None: + """Adds an Agent key (episode+agent IDs) to the next inference call. + + This makes sure that the agent's current data (in the trajectory) is + used for generating the next input_dict for a + `Policy.compute_actions()` call. + + Args: + agent_key (Tuple[EpisodeID, AgentID]: A unique agent key (across + vectorized environments). + """ + pid = self.agent_key_to_policy_id[agent_key] + + # PID may be a newly added policy (added on the fly during training). + # Just confirm we have it in our policy map before proceeding with + # forward_pass_size=0. + if pid not in self.forward_pass_size: + assert pid in self.policy_map + self.forward_pass_size[pid] = 0 + self.forward_pass_agent_keys[pid] = [] + + idx = self.forward_pass_size[pid] + assert idx >= 0 + if idx == 0: + self.forward_pass_agent_keys[pid].clear() + + self.forward_pass_agent_keys[pid].append(agent_key) + self.forward_pass_size[pid] += 1 + + def _reset_inference_calls(self, policy_id: PolicyID) -> None: + """Resets internal inference input-dict registries. + + Calling `self.get_inference_input_dict()` after this method is called + would return an empty input-dict. + + Args: + policy_id: The policy ID for which to reset the + inference pointers. 
+ """ + self.forward_pass_size[policy_id] = 0 diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..f052ee7915573de385ef446d32a266d6b7ff475b --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/env_runner_v2.py @@ -0,0 +1,1232 @@ +from collections import defaultdict +import logging +import time +import tree # pip install dm_tree +from typing import TYPE_CHECKING, Dict, Iterator, List, Optional, Set, Tuple, Union +import numpy as np + +from ray.rllib.env.base_env import ASYNC_RESET_RETURN, BaseEnv +from ray.rllib.env.external_env import ExternalEnvWrapper +from ray.rllib.env.wrappers.atari_wrappers import MonitorEnv, get_wrapper_by_cls +from ray.rllib.evaluation.collectors.simple_list_collector import _PolicyCollectorGroup +from ray.rllib.evaluation.episode_v2 import EpisodeV2 +from ray.rllib.evaluation.metrics import RolloutMetrics +from ray.rllib.models.preprocessors import Preprocessor +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import MultiAgentBatch, SampleBatch, concat_samples +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.filter import Filter +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.spaces.space_utils import unbatch, get_original_space +from ray.rllib.utils.typing import ( + ActionConnectorDataType, + AgentConnectorDataType, + AgentID, + EnvActionType, + EnvID, + EnvInfoDict, + EnvObsType, + MultiAgentDict, + MultiEnvDict, + PolicyID, + PolicyOutputType, + SampleBatchType, + StateBatches, + TensorStructType, +) +from ray.util.debug import log_once + +if TYPE_CHECKING: + from gymnasium.envs.classic_control.rendering import SimpleImageViewer + + from ray.rllib.callbacks.callbacks import RLlibCallback + from ray.rllib.evaluation.rollout_worker import RolloutWorker + + 
@OldAPIStack
class _PerfStats:
    """Sampler performance counters that are reported with rollout metrics.

    Counters are kept in seconds and converted to milliseconds on read.
    """

    # (reported-metric-name, attribute-name) pairs shared by both getters.
    _FIELDS = (
        ("mean_raw_obs_processing_ms", "raw_obs_processing_time"),
        ("mean_inference_ms", "inference_time"),
        ("mean_action_processing_ms", "action_processing_time"),
        ("mean_env_wait_ms", "env_wait_time"),
        ("mean_env_render_ms", "env_render_time"),
    )

    def __init__(self, ema_coef: Optional[float] = None):
        # If `ema_coef` is not None, counters are exponentially averaged:
        #   updated = (1 - ema_coef) * old + ema_coef * new
        # which gives more responsive stats. Otherwise raw sums are kept
        # and divided by `iters` when read.
        # TODO(jungong) : make ema the default (only) mode if it works well.
        self.ema_coef = ema_coef

        self.iters = 0
        self.raw_obs_processing_time = 0.0
        self.inference_time = 0.0
        self.action_processing_time = 0.0
        self.env_wait_time = 0.0
        self.env_render_time = 0.0

    def incr(self, field: str, value: Union[int, float]):
        """Add `value` to counter `field` (sum mode or EMA mode)."""
        if field == "iters":
            self.iters += value
        elif self.ema_coef is None:
            # Global-average mode: plain accumulation, averaged on read.
            setattr(self, field, getattr(self, field) + value)
        else:
            previous = getattr(self, field)
            setattr(
                self,
                field,
                (1.0 - self.ema_coef) * previous + self.ema_coef * value,
            )

    def _get_avg(self):
        # Divide accumulated seconds by iteration count and convert to ms.
        factor = MS_TO_SEC / self.iters
        return {name: getattr(self, attr) * factor for name, attr in self._FIELDS}

    def _get_ema(self):
        # EMA values are already averaged; only the sec -> ms conversion
        # is needed here.
        return {name: getattr(self, attr) * MS_TO_SEC for name, attr in self._FIELDS}

    def get(self):
        """Return current stats as a metric-name -> milliseconds dict."""
        return self._get_avg() if self.ema_coef is None else self._get_ema()


@OldAPIStack
class _NewDefaultDict(defaultdict):
    """defaultdict variant whose factory receives the missing key."""

    def __missing__(self, env_id):
        # Unlike plain defaultdict, pass the key to the factory so builders
        # can be created per env ID.
        value = self.default_factory(env_id)
        self[env_id] = value
        return value
@OldAPIStack
def _build_multi_agent_batch(
    episode_id: int,
    batch_builder: _PolicyCollectorGroup,
    large_batch_threshold: int,
    multiple_episodes_in_batch: bool,
) -> MultiAgentBatch:
    """Build MultiAgentBatch from a dict of _PolicyCollectors.

    Args:
        episode_id: ID of the episode that produced the collected data
            (only used in the large-batch warning message).
        batch_builder: _PolicyCollectorGroup holding the per-policy
            collectors to build SampleBatches from.
        large_batch_threshold: Number of buffered agent steps above which
            a (once-only) warning about unusually large batches is logged.
        multiple_episodes_in_batch: Whether several episodes may be packed
            into one batch; only affects the warning text.

    Returns:
        Always returns a sample batch in MultiAgentBatch format.
    """
    ma_batch = {}
    for pid, collector in batch_builder.policy_collectors.items():
        # Skip policies that collected no agent steps this round.
        if collector.agent_steps <= 0:
            continue

        if batch_builder.agent_steps > large_batch_threshold and log_once(
            "large_batch_warning"
        ):
            logger.warning(
                "More than {} observations in {} env steps for "
                "episode {} ".format(
                    batch_builder.agent_steps, batch_builder.env_steps, episode_id
                )
                + "are buffered in the sampler. If this is more than you "
                "expected, check that you set a horizon on your "
                "environment correctly and that it terminates at some "
                "point. Note: In multi-agent environments, "
                "`rollout_fragment_length` sets the batch size based on "
                "(across-agents) environment steps, not the steps of "
                "individual agents, which can result in unexpectedly "
                "large batches."
                + (
                    "Also, you may be waiting for your Env to "
                    "terminate (batch_mode=`complete_episodes`). Make sure "
                    "it does at some point."
                    if not multiple_episodes_in_batch
                    else ""
                )
            )

        batch = collector.build()

        ma_batch[pid] = batch

    # Create the multi agent batch.
    return MultiAgentBatch(policy_batches=ma_batch, env_steps=batch_builder.env_steps)


@OldAPIStack
def _batch_inference_sample_batches(eval_data: List[SampleBatch]) -> SampleBatch:
    """Batch a list of input SampleBatches into a single SampleBatch.

    Args:
        eval_data: list of SampleBatches.

    Returns:
        single batched SampleBatch.
    """
    inference_batch = concat_samples(eval_data)
    if "state_in_0" in inference_batch:
        # Recurrent policy: each input batch is a single step, so every
        # sequence in the concatenated batch has length 1.
        batch_size = len(eval_data)
        inference_batch[SampleBatch.SEQ_LENS] = np.ones(batch_size, dtype=np.int32)
    return inference_batch
+ count_steps_by: One of "env_steps" (default) or "agent_steps". + Use "agent_steps", if you want rollout lengths to be counted + by individual agent steps. In a multi-agent env, + a single env_step contains one or more agent_steps, depending + on how many agents are present at any given time in the + ongoing episode. + render: Whether to try to render the environment after each + step. + """ + self._worker = worker + if isinstance(base_env, ExternalEnvWrapper): + raise ValueError( + "Policies using the new Connector API do not support ExternalEnv." + ) + self._base_env = base_env + self._multiple_episodes_in_batch = multiple_episodes_in_batch + self._callbacks = callbacks + self._perf_stats = perf_stats + self._rollout_fragment_length = rollout_fragment_length + self._count_steps_by = count_steps_by + self._render = render + + # May be populated for image rendering. + self._simple_image_viewer: Optional[ + "SimpleImageViewer" + ] = self._get_simple_image_viewer() + + # Keeps track of active episodes. + self._active_episodes: Dict[EnvID, EpisodeV2] = {} + self._batch_builders: Dict[EnvID, _PolicyCollectorGroup] = _NewDefaultDict( + self._new_batch_builder + ) + + self._large_batch_threshold: int = ( + max(MIN_LARGE_BATCH_THRESHOLD, self._rollout_fragment_length * 10) + if self._rollout_fragment_length != float("inf") + else DEFAULT_LARGE_BATCH_THRESHOLD + ) + + def _get_simple_image_viewer(self): + """Maybe construct a SimpleImageViewer instance for episode rendering.""" + # Try to render the env, if required. + if not self._render: + return None + + try: + from gymnasium.envs.classic_control.rendering import SimpleImageViewer + + return SimpleImageViewer() + except (ImportError, ModuleNotFoundError): + self._render = False # disable rendering + logger.warning( + "Could not import gymnasium.envs.classic_control." + "rendering! Try `pip install gymnasium[all]`." 
+ ) + + return None + + def _call_on_episode_start(self, episode, env_id): + # Call each policy's Exploration.on_episode_start method. + # Note: This may break the exploration (e.g. ParameterNoise) of + # policies in the `policy_map` that have not been recently used + # (and are therefore stashed to disk). However, we certainly do not + # want to loop through all (even stashed) policies here as that + # would counter the purpose of the LRU policy caching. + for p in self._worker.policy_map.cache.values(): + if getattr(p, "exploration", None) is not None: + p.exploration.on_episode_start( + policy=p, + environment=self._base_env, + episode=episode, + tf_sess=p.get_session(), + ) + # Call `on_episode_start()` callback. + self._callbacks.on_episode_start( + worker=self._worker, + base_env=self._base_env, + policies=self._worker.policy_map, + env_index=env_id, + episode=episode, + ) + + def _new_batch_builder(self, _) -> _PolicyCollectorGroup: + """Create a new batch builder. + + We create a _PolicyCollectorGroup based on the full policy_map + as the batch builder. + """ + return _PolicyCollectorGroup(self._worker.policy_map) + + def run(self) -> Iterator[SampleBatchType]: + """Samples and yields training episodes continuously. + + Yields: + Object containing state, action, reward, terminal condition, + and other fields as dictated by `policy`. + """ + while True: + outputs = self.step() + for o in outputs: + yield o + + def step(self) -> List[SampleBatchType]: + """Samples training episodes by stepping through environments.""" + + self._perf_stats.incr("iters", 1) + + t0 = time.time() + # Get observations from all ready agents. + # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ... + ( + unfiltered_obs, + rewards, + terminateds, + truncateds, + infos, + off_policy_actions, + ) = self._base_env.poll() + env_poll_time = time.time() - t0 + + # Process observations and prepare for policy evaluation. 
    def step(self) -> List[SampleBatchType]:
        """Samples training episodes by stepping through environments.

        One call = one poll of all sub-environments, one batched policy
        forward pass, and one send_actions() back to the envs. Returns the
        SampleBatches/RolloutMetrics completed during this step (may be empty).
        """

        self._perf_stats.incr("iters", 1)

        t0 = time.time()
        # Get observations from all ready agents.
        # types: MultiEnvDict, MultiEnvDict, MultiEnvDict, MultiEnvDict, ...
        (
            unfiltered_obs,
            rewards,
            terminateds,
            truncateds,
            infos,
            off_policy_actions,
        ) = self._base_env.poll()
        # Poll time is credited to env_wait_time below, together with the
        # send_actions() time.
        env_poll_time = time.time() - t0

        # Process observations and prepare for policy evaluation.
        t1 = time.time()
        # types: Set[EnvID], Dict[PolicyID, List[AgentConnectorDataType]],
        # List[Union[RolloutMetrics, SampleBatchType]]
        active_envs, to_eval, outputs = self._process_observations(
            unfiltered_obs=unfiltered_obs,
            rewards=rewards,
            terminateds=terminateds,
            truncateds=truncateds,
            infos=infos,
        )
        self._perf_stats.incr("raw_obs_processing_time", time.time() - t1)

        # Do batched policy eval (across vectorized envs).
        t2 = time.time()
        # types: Dict[PolicyID, Tuple[TensorStructType, StateBatch, dict]]
        eval_results = self._do_policy_eval(to_eval=to_eval)
        self._perf_stats.incr("inference_time", time.time() - t2)

        # Process results and update episode state.
        t3 = time.time()
        actions_to_send: Dict[
            EnvID, Dict[AgentID, EnvActionType]
        ] = self._process_policy_eval_results(
            active_envs=active_envs,
            to_eval=to_eval,
            eval_results=eval_results,
            off_policy_actions=off_policy_actions,
        )
        self._perf_stats.incr("action_processing_time", time.time() - t3)

        # Return computed actions to ready envs. We also send to envs that have
        # taken off-policy actions; those envs are free to ignore the action.
        t4 = time.time()
        self._base_env.send_actions(actions_to_send)
        self._perf_stats.incr("env_wait_time", env_poll_time + time.time() - t4)

        self._maybe_render()

        return outputs
+ atari_metrics: List[RolloutMetrics] = _fetch_atari_metrics(self._base_env) + if atari_metrics is not None: + for m in atari_metrics: + m._replace(custom_metrics=episode.custom_metrics) + return atari_metrics + # Create connector metrics + connector_metrics = {} + active_agents = episode.get_agents() + for agent in active_agents: + policy_id = episode.policy_for(agent) + policy = episode.policy_map[policy_id] + connector_metrics[policy_id] = policy.get_connector_metrics() + # Otherwise, return RolloutMetrics for the episode. + return [ + RolloutMetrics( + episode_length=episode.length, + episode_reward=episode.total_reward, + agent_rewards=dict(episode.agent_rewards), + custom_metrics=episode.custom_metrics, + perf_stats={}, + hist_data=episode.hist_data, + media=episode.media, + connector_metrics=connector_metrics, + ) + ] + + def _process_observations( + self, + unfiltered_obs: MultiEnvDict, + rewards: MultiEnvDict, + terminateds: MultiEnvDict, + truncateds: MultiEnvDict, + infos: MultiEnvDict, + ) -> Tuple[ + Set[EnvID], + Dict[PolicyID, List[AgentConnectorDataType]], + List[Union[RolloutMetrics, SampleBatchType]], + ]: + """Process raw obs from env. + + Group data for active agents by policy. Reset environments that are done. + + Args: + unfiltered_obs: The unfiltered, raw observations from the BaseEnv + (vectorized, possibly multi-agent). Dict of dict: By env index, + then agent ID, then mapped to actual obs. + rewards: The rewards MultiEnvDict of the BaseEnv. + terminateds: The `terminated` flags MultiEnvDict of the BaseEnv. + truncateds: The `truncated` flags MultiEnvDict of the BaseEnv. + infos: The MultiEnvDict of infos dicts of the BaseEnv. + + Returns: + A tuple of: + A list of envs that were active during this step. + AgentConnectorDataType for active agents for policy evaluation. + SampleBatches and RolloutMetrics for completed agents for output. + """ + # Output objects. 
    def _process_observations(
        self,
        unfiltered_obs: MultiEnvDict,
        rewards: MultiEnvDict,
        terminateds: MultiEnvDict,
        truncateds: MultiEnvDict,
        infos: MultiEnvDict,
    ) -> Tuple[
        Set[EnvID],
        Dict[PolicyID, List[AgentConnectorDataType]],
        List[Union[RolloutMetrics, SampleBatchType]],
    ]:
        """Process raw obs from env.

        Group data for active agents by policy. Reset environments that are done.

        Args:
            unfiltered_obs: The unfiltered, raw observations from the BaseEnv
                (vectorized, possibly multi-agent). Dict of dict: By env index,
                then agent ID, then mapped to actual obs.
            rewards: The rewards MultiEnvDict of the BaseEnv.
            terminateds: The `terminated` flags MultiEnvDict of the BaseEnv.
            truncateds: The `truncated` flags MultiEnvDict of the BaseEnv.
            infos: The MultiEnvDict of infos dicts of the BaseEnv.

        Returns:
            A tuple of:
                A set of envs that were active during this step.
                AgentConnectorDataType for active agents for policy evaluation.
                SampleBatches and RolloutMetrics for completed agents for output.
        """
        # Output objects.
        # Note that we need to track envs that are active during this round explicitly,
        # just to be confident which envs require us to send at least an empty action
        # dict to.
        # We can not get this from the _active_episode or to_eval lists because
        # 1. All envs are not required to step during every single step. And
        # 2. to_eval only contains data for the agents that are still active. An env may
        # be active but all agents are done during the step.
        active_envs: Set[EnvID] = set()
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]] = defaultdict(list)
        outputs: List[Union[RolloutMetrics, SampleBatchType]] = []

        # For each (vectorized) sub-environment.
        # types: EnvID, Dict[AgentID, EnvObsType]
        for env_id, env_obs in unfiltered_obs.items():
            # Check for env_id having returned an error instead of a multi-agent
            # obs dict. This is how our BaseEnv can tell the caller to `poll()` that
            # one of its sub-environments is faulty and should be restarted (and the
            # ongoing episode should not be used for training).
            if isinstance(env_obs, Exception):
                assert terminateds[env_id]["__all__"] is True, (
                    f"ERROR: When a sub-environment (env-id {env_id}) returns an error "
                    "as observation, the terminateds[__all__] flag must also be set to "
                    "True!"
                )
                # all_agents_obs is an Exception here.
                # Drop this episode and skip to next.
                self._handle_done_episode(
                    env_id=env_id,
                    env_obs_or_exception=env_obs,
                    is_done=True,
                    active_envs=active_envs,
                    to_eval=to_eval,
                    outputs=outputs,
                )
                continue

            # Lazily create an episode for envs we see for the first time.
            if env_id not in self._active_episodes:
                episode: EpisodeV2 = self.create_episode(env_id)
                self._active_episodes[env_id] = episode
            else:
                episode: EpisodeV2 = self._active_episodes[env_id]
            # If this episode is brand-new, call the episode start callback(s).
            # Note: EpisodeV2s are initialized with length=-1 (before the reset).
            if not episode.has_init_obs():
                self._call_on_episode_start(episode, env_id)

            # Check episode termination conditions.
            if terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"]:
                all_agents_done = True
            else:
                all_agents_done = False
                active_envs.add(env_id)

            # Special handling of common info dict.
            episode.set_last_info("__common__", infos[env_id].get("__common__", {}))

            # Agent sample batches grouped by policy. Each set of sample batches will
            # go through agent connectors together.
            sample_batches_by_policy = defaultdict(list)
            # Whether an agent is terminated or truncated.
            agent_terminateds = {}
            agent_truncateds = {}
            for agent_id, obs in env_obs.items():
                assert agent_id != "__all__"

                policy_id: PolicyID = episode.policy_for(agent_id)

                # Per-agent done flags also honor the env-wide __all__ flags.
                agent_terminated = bool(
                    terminateds[env_id]["__all__"] or terminateds[env_id].get(agent_id)
                )
                agent_terminateds[agent_id] = agent_terminated
                agent_truncated = bool(
                    truncateds[env_id]["__all__"]
                    or truncateds[env_id].get(agent_id, False)
                )
                agent_truncateds[agent_id] = agent_truncated

                # A completely new agent is already done -> Skip entirely.
                if not episode.has_init_obs(agent_id) and (
                    agent_terminated or agent_truncated
                ):
                    continue

                values_dict = {
                    SampleBatch.T: episode.length,  # Episodes start at -1 before we
                    # add the initial obs. After that, we infer from initial obs at
                    # t=0 since that will be our new episode.length.
                    SampleBatch.ENV_ID: env_id,
                    SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                    # Last action (SampleBatch.ACTIONS) column will be populated by
                    # StateBufferConnector.
                    # Reward received after taking action at timestep t.
                    SampleBatch.REWARDS: rewards[env_id].get(agent_id, 0.0),
                    # After taking action=a, did we reach terminal?
                    SampleBatch.TERMINATEDS: agent_terminated,
                    # Was the episode truncated artificially
                    # (e.g. b/c of some time limit)?
                    SampleBatch.TRUNCATEDS: agent_truncated,
                    SampleBatch.INFOS: infos[env_id].get(agent_id, {}),
                    SampleBatch.NEXT_OBS: obs,
                }

                # Queue this obs sample for connector preprocessing.
                sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # The entire episode is done.
            if all_agents_done:
                # Let's check to see if there are any agents that haven't got the
                # last obs yet. If there are, we have to create fake-last
                # observations for them. (the environment is not required to do so if
                # terminateds[__all__]==True or truncateds[__all__]==True).
                for agent_id in episode.get_agents():
                    # If the latest obs we got for this agent is done, or if its
                    # episode state is already done, nothing to do.
                    if (
                        agent_terminateds.get(agent_id, False)
                        or agent_truncateds.get(agent_id, False)
                        or episode.is_done(agent_id)
                    ):
                        continue

                    policy_id: PolicyID = episode.policy_for(agent_id)
                    policy = self._worker.policy_map[policy_id]

                    # Create a fake observation by sampling the original env
                    # observation space.
                    obs_space = get_original_space(policy.observation_space)
                    # Although there is no obs for this agent, there may be
                    # good rewards and info dicts for it.
                    # This is the case for e.g. OpenSpiel games, where a reward
                    # is only earned with the last step, but the obs for that
                    # step is {}.
                    reward = rewards[env_id].get(agent_id, 0.0)
                    info = infos[env_id].get(agent_id, {})
                    values_dict = {
                        SampleBatch.T: episode.length,
                        SampleBatch.ENV_ID: env_id,
                        SampleBatch.AGENT_INDEX: episode.agent_index(agent_id),
                        # TODO(sven): These should be the summed-up(!) rewards since the
                        # last observation received for this agent.
                        SampleBatch.REWARDS: reward,
                        SampleBatch.TERMINATEDS: True,
                        SampleBatch.TRUNCATEDS: truncateds[env_id].get(agent_id, False),
                        SampleBatch.INFOS: info,
                        SampleBatch.NEXT_OBS: obs_space.sample(),
                    }

                    # Queue these fake obs for connector preprocessing too.
                    sample_batches_by_policy[policy_id].append((agent_id, values_dict))

            # Run agent connectors.
            for policy_id, batches in sample_batches_by_policy.items():
                policy: Policy = self._worker.policy_map[policy_id]
                # Collected full MultiAgentDicts for this environment.
                # Run agent connectors.
                assert (
                    policy.agent_connectors
                ), "EnvRunnerV2 requires agent connectors to work."

                acd_list: List[AgentConnectorDataType] = [
                    AgentConnectorDataType(env_id, agent_id, data)
                    for agent_id, data in batches
                ]

                # For all agents mapped to policy_id, run their data
                # through agent_connectors.
                processed = policy.agent_connectors(acd_list)

                for d in processed:
                    # Record transition info if applicable.
                    if not episode.has_init_obs(d.agent_id):
                        episode.add_init_obs(
                            agent_id=d.agent_id,
                            init_obs=d.data.raw_dict[SampleBatch.NEXT_OBS],
                            init_infos=d.data.raw_dict[SampleBatch.INFOS],
                            t=d.data.raw_dict[SampleBatch.T],
                        )
                    else:
                        episode.add_action_reward_done_next_obs(
                            d.agent_id, d.data.raw_dict
                        )

                    # Need to evaluate next actions.
                    if not (
                        all_agents_done
                        or agent_terminateds.get(d.agent_id, False)
                        or agent_truncateds.get(d.agent_id, False)
                        or episode.is_done(d.agent_id)
                    ):
                        # Add to eval set if env is not done and this particular agent
                        # is also not done.
                        item = AgentConnectorDataType(d.env_id, d.agent_id, d.data)
                        to_eval[policy_id].append(item)

            # Finished advancing episode by 1 step, mark it so.
            episode.step()

            # Exception: The very first env.poll() call causes the env to get reset
            # (no step taken yet, just a single starting observation logged).
            # We need to skip this callback in this case.
            if episode.length > 0:
                # Invoke the `on_episode_step` callback after the step is logged
                # to the episode.
                self._callbacks.on_episode_step(
                    worker=self._worker,
                    base_env=self._base_env,
                    policies=self._worker.policy_map,
                    episode=episode,
                    env_index=env_id,
                )

            # Episode is terminated/truncated for all agents
            # (terminateds[__all__] == True or truncateds[__all__] == True).
            if all_agents_done:
                # _handle_done_episode will build a MultiAgentBatch for all
                # the agents that are done during this step of rollout in
                # the case of _multiple_episodes_in_batch=False.
                self._handle_done_episode(
                    env_id,
                    env_obs,
                    terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"],
                    active_envs,
                    to_eval,
                    outputs,
                )

            # Try to build something.
            if self._multiple_episodes_in_batch:
                sample_batch = self._try_build_truncated_episode_multi_agent_batch(
                    self._batch_builders[env_id], episode
                )
                if sample_batch:
                    outputs.append(sample_batch)

                    # SampleBatch built from data collected by batch_builder.
                    # Clean up and delete the batch_builder.
                    del self._batch_builders[env_id]

        return active_envs, to_eval, outputs
+ self._callbacks.on_episode_step( + worker=self._worker, + base_env=self._base_env, + policies=self._worker.policy_map, + episode=episode, + env_index=env_id, + ) + + # Episode is terminated/truncated for all agents + # (terminateds[__all__] == True or truncateds[__all__] == True). + if all_agents_done: + # _handle_done_episode will build a MultiAgentBatch for all + # the agents that are done during this step of rollout in + # the case of _multiple_episodes_in_batch=False. + self._handle_done_episode( + env_id, + env_obs, + terminateds[env_id]["__all__"] or truncateds[env_id]["__all__"], + active_envs, + to_eval, + outputs, + ) + + # Try to build something. + if self._multiple_episodes_in_batch: + sample_batch = self._try_build_truncated_episode_multi_agent_batch( + self._batch_builders[env_id], episode + ) + if sample_batch: + outputs.append(sample_batch) + + # SampleBatch built from data collected by batch_builder. + # Clean up and delete the batch_builder. + del self._batch_builders[env_id] + + return active_envs, to_eval, outputs + + def _build_done_episode( + self, + env_id: EnvID, + is_done: bool, + outputs: List[SampleBatchType], + ): + """Builds a MultiAgentSampleBatch from the episode and adds it to outputs. + + Args: + env_id: The env id. + is_done: Whether the env is done. + outputs: The list of outputs to add the + """ + episode: EpisodeV2 = self._active_episodes[env_id] + batch_builder = self._batch_builders[env_id] + + episode.postprocess_episode( + batch_builder=batch_builder, + is_done=is_done, + check_dones=is_done, + ) + + # If, we are not allowed to pack the next episode into the same + # SampleBatch (batch_mode=complete_episodes) -> Build the + # MultiAgentBatch from a single episode and add it to "outputs". + # Otherwise, just postprocess and continue collecting across + # episodes. 
+ if not self._multiple_episodes_in_batch: + ma_sample_batch = _build_multi_agent_batch( + episode.episode_id, + batch_builder, + self._large_batch_threshold, + self._multiple_episodes_in_batch, + ) + if ma_sample_batch: + outputs.append(ma_sample_batch) + + # SampleBatch built from data collected by batch_builder. + # Clean up and delete the batch_builder. + del self._batch_builders[env_id] + + def __process_resetted_obs_for_eval( + self, + env_id: EnvID, + obs: Dict[EnvID, Dict[AgentID, EnvObsType]], + infos: Dict[EnvID, Dict[AgentID, EnvInfoDict]], + episode: EpisodeV2, + to_eval: Dict[PolicyID, List[AgentConnectorDataType]], + ): + """Process resetted obs through agent connectors for policy eval. + + Args: + env_id: The env id. + obs: The Resetted obs. + episode: New episode. + to_eval: List of agent connector data for policy eval. + """ + per_policy_resetted_obs: Dict[PolicyID, List] = defaultdict(list) + # types: AgentID, EnvObsType + for agent_id, raw_obs in obs[env_id].items(): + policy_id: PolicyID = episode.policy_for(agent_id) + per_policy_resetted_obs[policy_id].append((agent_id, raw_obs)) + + for policy_id, agents_obs in per_policy_resetted_obs.items(): + policy = self._worker.policy_map[policy_id] + acd_list: List[AgentConnectorDataType] = [ + AgentConnectorDataType( + env_id, + agent_id, + { + SampleBatch.NEXT_OBS: obs, + SampleBatch.INFOS: infos, + SampleBatch.T: episode.length, + SampleBatch.AGENT_INDEX: episode.agent_index(agent_id), + }, + ) + for agent_id, obs in agents_obs + ] + # Call agent connectors on these initial obs. 
    def _handle_done_episode(
        self,
        env_id: EnvID,
        env_obs_or_exception: MultiAgentDict,
        is_done: bool,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        outputs: List[SampleBatchType],
    ) -> None:
        """Handle an all-finished episode.

        Add collected SampleBatch to batch builder. Reset corresponding env, etc.

        Args:
            env_id: Environment ID.
            env_obs_or_exception: Last per-environment observation, or the
                Exception raised by a faulty sub-environment.
            is_done: If all agents are done.
            active_envs: Set of active env ids.
            to_eval: Output container for policy eval data.
            outputs: Output container for collected sample batches.
        """
        if isinstance(env_obs_or_exception, Exception):
            episode_or_exception: Exception = env_obs_or_exception
            # Tell the sampler we have got a faulty episode.
            outputs.append(RolloutMetrics(episode_faulty=True))
        else:
            episode_or_exception: EpisodeV2 = self._active_episodes[env_id]
            # Add rollout metrics.
            outputs.extend(
                self._get_rollout_metrics(
                    episode_or_exception, policy_map=self._worker.policy_map
                )
            )
            # Output the collected episode after adding rollout metrics so that we
            # always fetch metrics with RolloutWorker before we fetch samples.
            # This is because we need to behave like env_runner() for now.
            self._build_done_episode(env_id, is_done, outputs)

        # Clean up and deleted the post-processed episode now that we have collected
        # its data.
        self.end_episode(env_id, episode_or_exception)
        # Create a new episode instance (before we reset the sub-environment).
        new_episode: EpisodeV2 = self.create_episode(env_id)

        # The sub environment at index `env_id` might throw an exception
        # during the following `try_reset()` attempt. If configured with
        # `restart_failed_sub_environments=True`, the BaseEnv will restart
        # the affected sub environment (create a new one using its c'tor) and
        # must reset the recreated sub env right after that.
        # Should the sub environment fail indefinitely during these
        # repeated reset attempts, the entire worker will be blocked.
        # This would be ok, b/c the alternative would be the worker crashing
        # entirely.
        while True:
            resetted_obs, resetted_infos = self._base_env.try_reset(env_id)

            # Exit the retry loop on a successful reset, an async reset
            # (result arrives in a future poll), or no reset support (None).
            if (
                resetted_obs is None
                or resetted_obs == ASYNC_RESET_RETURN
                or not isinstance(resetted_obs[env_id], Exception)
            ):
                break
            else:
                # Report a faulty episode.
                outputs.append(RolloutMetrics(episode_faulty=True))

        # Reset connector state if this is a hard reset.
        for p in self._worker.policy_map.cache.values():
            p.agent_connectors.reset(env_id)

        # Creates a new episode if this is not async return.
        # If reset is async, we will get its result in some future poll.
        if resetted_obs is not None and resetted_obs != ASYNC_RESET_RETURN:
            self._active_episodes[env_id] = new_episode
            self._call_on_episode_start(new_episode, env_id)

            self.__process_resetted_obs_for_eval(
                env_id,
                resetted_obs,
                resetted_infos,
                new_episode,
                to_eval,
            )

            # Step after adding initial obs. This will give us 0 env and agent step.
            new_episode.step()
            active_envs.add(env_id)
+ assert env_id not in self._active_episodes + + # Create a new episode under the same `env_id` and call the + # `on_episode_created` callbacks. + new_episode = EpisodeV2( + env_id, + self._worker.policy_map, + self._worker.policy_mapping_fn, + worker=self._worker, + callbacks=self._callbacks, + ) + + # Call `on_episode_created()` callback. + self._callbacks.on_episode_created( + worker=self._worker, + base_env=self._base_env, + policies=self._worker.policy_map, + env_index=env_id, + episode=new_episode, + ) + return new_episode + + def end_episode( + self, env_id: EnvID, episode_or_exception: Union[EpisodeV2, Exception] + ): + """Cleans up an episode that has finished. + + Args: + env_id: Env ID. + episode_or_exception: Instance of an episode if it finished successfully. + Otherwise, the exception that was thrown, + """ + # Signal the end of an episode, either successfully with an Episode or + # unsuccessfully with an Exception. + self._callbacks.on_episode_end( + worker=self._worker, + base_env=self._base_env, + policies=self._worker.policy_map, + episode=episode_or_exception, + env_index=env_id, + ) + + # Call each (in-memory) policy's Exploration.on_episode_end + # method. + # Note: This may break the exploration (e.g. ParameterNoise) of + # policies in the `policy_map` that have not been recently used + # (and are therefore stashed to disk). However, we certainly do not + # want to loop through all (even stashed) policies here as that + # would counter the purpose of the LRU policy caching. + for p in self._worker.policy_map.cache.values(): + if getattr(p, "exploration", None) is not None: + p.exploration.on_episode_end( + policy=p, + environment=self._base_env, + episode=episode_or_exception, + tf_sess=p.get_session(), + ) + + if isinstance(episode_or_exception, EpisodeV2): + episode = episode_or_exception + if episode.total_agent_steps == 0: + # if the key does not exist it means that throughout the episode all + # observations were empty (i.e. 
there was no agent in the env) + msg = ( + f"Data from episode {episode.episode_id} does not show any agent " + f"interactions. Hint: Make sure for at least one timestep in the " + f"episode, env.step() returns non-empty values." + ) + raise ValueError(msg) + + # Clean up the episode and batch_builder for this env id. + if env_id in self._active_episodes: + del self._active_episodes[env_id] + + def _try_build_truncated_episode_multi_agent_batch( + self, batch_builder: _PolicyCollectorGroup, episode: EpisodeV2 + ) -> Union[None, SampleBatch, MultiAgentBatch]: + # Measure batch size in env-steps. + if self._count_steps_by == "env_steps": + built_steps = batch_builder.env_steps + ongoing_steps = episode.active_env_steps + # Measure batch-size in agent-steps. + else: + built_steps = batch_builder.agent_steps + ongoing_steps = episode.active_agent_steps + + # Reached the fragment-len -> We should build an MA-Batch. + if built_steps + ongoing_steps >= self._rollout_fragment_length: + if self._count_steps_by != "agent_steps": + assert built_steps + ongoing_steps == self._rollout_fragment_length, ( + f"built_steps ({built_steps}) + ongoing_steps ({ongoing_steps}) != " + f"rollout_fragment_length ({self._rollout_fragment_length})." + ) + + # If we reached the fragment-len only because of `episode_id` + # (still ongoing) -> postprocess `episode_id` first. + if built_steps < self._rollout_fragment_length: + episode.postprocess_episode(batch_builder=batch_builder, is_done=False) + + # If builder has collected some data, + # build the MA-batch and add to return values. + if batch_builder.agent_steps > 0: + return _build_multi_agent_batch( + episode.episode_id, + batch_builder, + self._large_batch_threshold, + self._multiple_episodes_in_batch, + ) + # No batch-builder: + # We have reached the rollout-fragment length w/o any agent + # steps! Warn that the environment may never request any + # actions from any agents. 
+ elif log_once("no_agent_steps"): + logger.warning( + "Your environment seems to be stepping w/o ever " + "emitting agent observations (agents are never " + "requested to act)!" + ) + + return None + + def _do_policy_eval( + self, + to_eval: Dict[PolicyID, List[AgentConnectorDataType]], + ) -> Dict[PolicyID, PolicyOutputType]: + """Call compute_actions on collected episode data to get next action. + + Args: + to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects + (items in these lists will be the batch's items for the model + forward pass). + + Returns: + Dict mapping PolicyIDs to compute_actions_from_input_dict() outputs. + """ + policies = self._worker.policy_map + + # In case policy map has changed, try to find the new policy that + # should handle all these per-agent eval data. + # Throws exception if these agents are mapped to multiple different + # policies now. + def _try_find_policy_again(eval_data: AgentConnectorDataType): + policy_id = None + for d in eval_data: + episode = self._active_episodes[d.env_id] + # Force refresh policy mapping on the episode. + pid = episode.policy_for(d.agent_id, refresh=True) + if policy_id is not None and pid != policy_id: + raise ValueError( + "Policy map changed. The list of eval data that was handled " + f"by a same policy is now handled by policy {pid} " + "and {policy_id}. " + "Please don't do this in the middle of an episode." + ) + policy_id = pid + return _get_or_raise(self._worker.policy_map, policy_id) + + eval_results: Dict[PolicyID, TensorStructType] = {} + for policy_id, eval_data in to_eval.items(): + # In case the policyID has been removed from this worker, we need to + # re-assign policy_id and re-lookup the Policy object to use. + try: + policy: Policy = _get_or_raise(policies, policy_id) + except ValueError: + # policy_mapping_fn from the worker may have already been + # changed (mapping fn not staying constant within one episode). 
    def _process_policy_eval_results(
        self,
        active_envs: Set[EnvID],
        to_eval: Dict[PolicyID, List[AgentConnectorDataType]],
        eval_results: Dict[PolicyID, PolicyOutputType],
        off_policy_actions: MultiEnvDict,
    ):
        """Process the output of policy neural network evaluation.

        Records policy evaluation results into agent connectors and
        returns replies to send back to agents in the env.

        Args:
            active_envs: Set of env IDs that are still active.
            to_eval: Mapping of policy IDs to lists of AgentConnectorDataType objects.
            eval_results: Mapping of policy IDs to list of
                actions, rnn-out states, extra-action-fetches dicts.
            off_policy_actions: Doubly keyed dict of env-ids -> agent ids ->
                off-policy-action, returned by a `BaseEnv.poll()` call.

        Returns:
            Nested dict of env id -> agent id -> actions to be sent to
            Env (np.ndarrays).
        """
        actions_to_send: Dict[EnvID, Dict[AgentID, EnvActionType]] = defaultdict(dict)

        # Every active env gets an entry, even if no agent acts this step.
        for env_id in active_envs:
            actions_to_send[env_id] = {}  # at minimum send empty dict

        # types: PolicyID, List[AgentConnectorDataType]
        for policy_id, eval_data in to_eval.items():
            actions: TensorStructType = eval_results[policy_id][0]
            actions = convert_to_numpy(actions)

            rnn_out: StateBatches = eval_results[policy_id][1]
            extra_action_out: dict = eval_results[policy_id][2]

            # In case actions is a list (representing the 0th dim of a batch of
            # primitive actions), try converting it first.
            if isinstance(actions, list):
                actions = np.array(actions)
            # Split action-component batches into single action rows.
            actions: List[EnvActionType] = unbatch(actions)

            policy: Policy = _get_or_raise(self._worker.policy_map, policy_id)
            assert (
                policy.agent_connectors and policy.action_connectors
            ), "EnvRunnerV2 requires action connectors to work."

            # Row i of `actions` corresponds to item i of `eval_data` (same
            # order as the forward-pass batch was assembled).
            # types: int, EnvActionType
            for i, action in enumerate(actions):
                env_id: int = eval_data[i].env_id
                agent_id: AgentID = eval_data[i].agent_id
                input_dict: TensorStructType = eval_data[i].data.raw_dict

                # Pick this row's slice out of the batched RNN state outputs.
                rnn_states: List[StateBatches] = tree.map_structure(
                    lambda x, i=i: x[i], rnn_out
                )

                # extra_action_out could be a nested dict
                fetches: Dict = tree.map_structure(
                    lambda x, i=i: x[i], extra_action_out
                )

                # Post-process policy output by running it through action connectors.
                ac_data = ActionConnectorDataType(
                    env_id, agent_id, input_dict, (action, rnn_states, fetches)
                )

                action_to_send, rnn_states, fetches = policy.action_connectors(
                    ac_data
                ).output

                # The action we want to buffer is the direct output of
                # compute_actions_from_input_dict() here. This is because we want to
                # send the unsquashed actions to the environment while learning and
                # possibly basing subsequent actions on the squashed actions.
                action_to_buffer = (
                    action
                    if env_id not in off_policy_actions
                    or agent_id not in off_policy_actions[env_id]
                    else off_policy_actions[env_id][agent_id]
                )

                # Notify agent connectors with this new policy output.
                # Necessary for state buffering agent connectors, for example.
                ac_data: ActionConnectorDataType = ActionConnectorDataType(
                    env_id,
                    agent_id,
                    input_dict,
                    (action_to_buffer, rnn_states, fetches),
                )
                policy.agent_connectors.on_policy_output(ac_data)

                assert agent_id not in actions_to_send[env_id]
                actions_to_send[env_id][agent_id] = action_to_send

        return actions_to_send
+ if not self._render or not self._simple_image_viewer: + return + + t5 = time.time() + + # Render can either return an RGB image (uint8 [w x h x 3] numpy + # array) or take care of rendering itself (returning True). + rendered = self._base_env.try_render() + # Rendering returned an image -> Display it in a SimpleImageViewer. + if isinstance(rendered, np.ndarray) and len(rendered.shape) == 3: + self._simple_image_viewer.imshow(rendered) + elif rendered not in [True, False, None]: + raise ValueError( + f"The env's ({self._base_env}) `try_render()` method returned an" + " unsupported value! Make sure you either return a " + "uint8/w x h x 3 (RGB) image or handle rendering in a " + "window and then return `True`." + ) + + self._perf_stats.incr("env_render_time", time.time() - t5) + + +def _fetch_atari_metrics(base_env: BaseEnv) -> List[RolloutMetrics]: + """Atari games have multiple logical episodes, one per life. + + However, for metrics reporting we count full episodes, all lives included. + """ + sub_environments = base_env.get_sub_environments() + if not sub_environments: + return None + atari_out = [] + for sub_env in sub_environments: + monitor = get_wrapper_by_cls(sub_env, MonitorEnv) + if not monitor: + return None + for eps_rew, eps_len in monitor.next_episode_results(): + atari_out.append(RolloutMetrics(eps_len, eps_rew)) + return atari_out + + +def _get_or_raise( + mapping: Dict[PolicyID, Union[Policy, Preprocessor, Filter]], policy_id: PolicyID +) -> Union[Policy, Preprocessor, Filter]: + """Returns an object under key `policy_id` in `mapping`. + + Args: + mapping (Dict[PolicyID, Union[Policy, Preprocessor, Filter]]): The + mapping dict from policy id (str) to actual object (Policy, + Preprocessor, etc.). + policy_id: The policy ID to lookup. + + Returns: + Union[Policy, Preprocessor, Filter]: The found object. + + Raises: + ValueError: If `policy_id` cannot be found in `mapping`. 
+ """ + if policy_id not in mapping: + raise ValueError( + "Could not find policy for agent: PolicyID `{}` not found " + "in policy map, whose keys are `{}`.".format(policy_id, mapping.keys()) + ) + return mapping[policy_id] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py new file mode 100644 index 0000000000000000000000000000000000000000..e894bee48a561b1279a81af06166dcf5d43a17a8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/episode_v2.py @@ -0,0 +1,378 @@ +import random +from collections import defaultdict +import numpy as np +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple + +from ray.rllib.env.base_env import _DUMMY_AGENT_ID +from ray.rllib.evaluation.collectors.simple_list_collector import ( + _PolicyCollector, + _PolicyCollectorGroup, +) +from ray.rllib.evaluation.collectors.agent_collector import AgentCollector +from ray.rllib.policy.policy_map import PolicyMap +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.typing import AgentID, EnvID, EnvInfoDict, PolicyID, TensorType + +if TYPE_CHECKING: + from ray.rllib.callbacks.callbacks import RLlibCallback + from ray.rllib.evaluation.rollout_worker import RolloutWorker + + +@OldAPIStack +class EpisodeV2: + """Tracks the current state of a (possibly multi-agent) episode.""" + + def __init__( + self, + env_id: EnvID, + policies: PolicyMap, + policy_mapping_fn: Callable[[AgentID, "EpisodeV2", "RolloutWorker"], PolicyID], + *, + worker: Optional["RolloutWorker"] = None, + callbacks: Optional["RLlibCallback"] = None, + ): + """Initializes an Episode instance. + + Args: + env_id: The environment's ID in which this episode runs. + policies: The PolicyMap object (mapping PolicyIDs to Policy + objects) to use for determining, which policy is used for + which agent. 
+ policy_mapping_fn: The mapping function mapping AgentIDs to + PolicyIDs. + worker: The RolloutWorker instance, in which this episode runs. + """ + # Unique id identifying this trajectory. + self.episode_id: int = random.randrange(int(1e18)) + # ID of the environment this episode is tracking. + self.env_id = env_id + # Summed reward across all agents in this episode. + self.total_reward: float = 0.0 + # Active (uncollected) # of env steps taken by this episode. + # Start from -1. After add_init_obs(), we will be at 0 step. + self.active_env_steps: int = -1 + # Total # of env steps taken by this episode. + # Start from -1, After add_init_obs(), we will be at 0 step. + self.total_env_steps: int = -1 + # Active (uncollected) agent steps. + self.active_agent_steps: int = 0 + # Total # of steps take by all agents in this env. + self.total_agent_steps: int = 0 + # Dict for user to add custom metrics. + # TODO (sven): We should probably unify custom_metrics, user_data, + # and hist_data into a single data container for user to track per-step. + # metrics and states. + self.custom_metrics: Dict[str, float] = {} + # Temporary storage. E.g. storing data in between two custom + # callbacks referring to the same episode. + self.user_data: Dict[str, Any] = {} + # Dict mapping str keys to List[float] for storage of + # per-timestep float data throughout the episode. + self.hist_data: Dict[str, List[float]] = {} + self.media: Dict[str, Any] = {} + + self.worker = worker + self.callbacks = callbacks + + self.policy_map: PolicyMap = policies + self.policy_mapping_fn: Callable[ + [AgentID, "EpisodeV2", "RolloutWorker"], PolicyID + ] = policy_mapping_fn + # Per-agent data collectors. + self._agent_to_policy: Dict[AgentID, PolicyID] = {} + self._agent_collectors: Dict[AgentID, AgentCollector] = {} + + self._next_agent_index: int = 0 + self._agent_to_index: Dict[AgentID, int] = {} + + # Summed rewards broken down by agent. 
+ self.agent_rewards: Dict[Tuple[AgentID, PolicyID], float] = defaultdict(float) + self._agent_reward_history: Dict[AgentID, List[int]] = defaultdict(list) + + self._has_init_obs: Dict[AgentID, bool] = {} + self._last_terminateds: Dict[AgentID, bool] = {} + self._last_truncateds: Dict[AgentID, bool] = {} + # Keep last info dict around, in case an environment tries to signal + # us something. + self._last_infos: Dict[AgentID, Dict] = {} + + def policy_for( + self, agent_id: AgentID = _DUMMY_AGENT_ID, refresh: bool = False + ) -> PolicyID: + """Returns and stores the policy ID for the specified agent. + + If the agent is new, the policy mapping fn will be called to bind the + agent to a policy for the duration of the entire episode (even if the + policy_mapping_fn is changed in the meantime!). + + Args: + agent_id: The agent ID to lookup the policy ID for. + + Returns: + The policy ID for the specified agent. + """ + + # Perform a new policy_mapping_fn lookup and bind AgentID for the + # duration of this episode to the returned PolicyID. + if agent_id not in self._agent_to_policy or refresh: + policy_id = self._agent_to_policy[agent_id] = self.policy_mapping_fn( + agent_id, # agent_id + self, # episode + worker=self.worker, + ) + # Use already determined PolicyID. + else: + policy_id = self._agent_to_policy[agent_id] + + # PolicyID not found in policy map -> Error. + if policy_id not in self.policy_map: + raise KeyError( + "policy_mapping_fn returned invalid policy id " f"'{policy_id}'!" + ) + return policy_id + + def get_agents(self) -> List[AgentID]: + """Returns list of agent IDs that have appeared in this episode. + + Returns: + The list of all agent IDs that have appeared so far in this + episode. + """ + return list(self._agent_to_index.keys()) + + def agent_index(self, agent_id: AgentID) -> int: + """Get the index of an agent among its environment. + + A new index will be created if an agent is seen for the first time. + + Args: + agent_id: ID of an agent. 
+ + Returns: + The index of this agent. + """ + if agent_id not in self._agent_to_index: + self._agent_to_index[agent_id] = self._next_agent_index + self._next_agent_index += 1 + return self._agent_to_index[agent_id] + + def step(self) -> None: + """Advance the episode forward by one step.""" + self.active_env_steps += 1 + self.total_env_steps += 1 + + def add_init_obs( + self, + *, + agent_id: AgentID, + init_obs: TensorType, + init_infos: Dict[str, TensorType], + t: int = -1, + ) -> None: + """Add initial env obs at the start of a new episode + + Args: + agent_id: Agent ID. + init_obs: Initial observations. + init_infos: Initial infos dicts. + t: timestamp. + """ + policy = self.policy_map[self.policy_for(agent_id)] + + # Add initial obs to Trajectory. + assert agent_id not in self._agent_collectors + + self._agent_collectors[agent_id] = AgentCollector( + policy.view_requirements, + max_seq_len=policy.config["model"]["max_seq_len"], + disable_action_flattening=policy.config.get( + "_disable_action_flattening", False + ), + is_policy_recurrent=policy.is_recurrent(), + intial_states=policy.get_initial_state(), + _enable_new_api_stack=False, + ) + self._agent_collectors[agent_id].add_init_obs( + episode_id=self.episode_id, + agent_index=self.agent_index(agent_id), + env_id=self.env_id, + init_obs=init_obs, + init_infos=init_infos, + t=t, + ) + + self._has_init_obs[agent_id] = True + + def add_action_reward_done_next_obs( + self, + agent_id: AgentID, + values: Dict[str, TensorType], + ) -> None: + """Add action, reward, info, and next_obs as a new step. + + Args: + agent_id: Agent ID. + values: Dict of action, reward, info, and next_obs. + """ + # Make sure, agent already has some (at least init) data. + assert agent_id in self._agent_collectors + + self.active_agent_steps += 1 + self.total_agent_steps += 1 + + # Include the current agent id for multi-agent algorithms. 
+ if agent_id != _DUMMY_AGENT_ID: + values["agent_id"] = agent_id + + # Add action/reward/next-obs (and other data) to Trajectory. + self._agent_collectors[agent_id].add_action_reward_next_obs(values) + + # Keep track of agent reward history. + reward = values[SampleBatch.REWARDS] + self.total_reward += reward + self.agent_rewards[(agent_id, self.policy_for(agent_id))] += reward + self._agent_reward_history[agent_id].append(reward) + + # Keep track of last terminated info for agent. + if SampleBatch.TERMINATEDS in values: + self._last_terminateds[agent_id] = values[SampleBatch.TERMINATEDS] + # Keep track of last truncated info for agent. + if SampleBatch.TRUNCATEDS in values: + self._last_truncateds[agent_id] = values[SampleBatch.TRUNCATEDS] + + # Keep track of last info dict if available. + if SampleBatch.INFOS in values: + self.set_last_info(agent_id, values[SampleBatch.INFOS]) + + def postprocess_episode( + self, + batch_builder: _PolicyCollectorGroup, + is_done: bool = False, + check_dones: bool = False, + ) -> None: + """Build and return currently collected training samples by policies. + + Clear agent collector states if this episode is done. + + Args: + batch_builder: _PolicyCollectorGroup for saving the collected per-agent + sample batches. + is_done: If this episode is done (terminated or truncated). + check_dones: Whether to make sure per-agent trajectories are actually done. + """ + # TODO: (sven) Once we implement multi-agent communication channels, + # we have to resolve the restriction of only sending other agent + # batches from the same policy to the postprocess methods. + # Build SampleBatches for the given episode. + pre_batches = {} + for agent_id, collector in self._agent_collectors.items(): + # Build only if there is data and agent is part of given episode. 
+ if collector.agent_steps == 0: + continue + pid = self.policy_for(agent_id) + policy = self.policy_map[pid] + pre_batch = collector.build_for_training(policy.view_requirements) + pre_batches[agent_id] = (pid, policy, pre_batch) + + for agent_id, (pid, policy, pre_batch) in pre_batches.items(): + # Entire episode is said to be done. + # Error if no DONE at end of this agent's trajectory. + if is_done and check_dones and not pre_batch.is_terminated_or_truncated(): + raise ValueError( + "Episode {} terminated for all agents, but we still " + "don't have a last observation for agent {} (policy " + "{}). ".format(self.episode_id, agent_id, self.policy_for(agent_id)) + + "Please ensure that you include the last observations " + "of all live agents when setting done[__all__] to " + "True." + ) + + # Skip a trajectory's postprocessing (and thus using it for training), + # if its agent's info exists and contains the training_enabled=False + # setting (used by our PolicyClients). + if not self._last_infos.get(agent_id, {}).get("training_enabled", True): + continue + + if ( + not pre_batch.is_single_trajectory() + or len(np.unique(pre_batch[SampleBatch.EPS_ID])) > 1 + ): + raise ValueError( + "Batches sent to postprocessing must only contain steps " + "from a single trajectory.", + pre_batch, + ) + + if len(pre_batches) > 1: + other_batches = pre_batches.copy() + del other_batches[agent_id] + else: + other_batches = {} + + # Call the Policy's Exploration's postprocess method. 
+ post_batch = pre_batch + if getattr(policy, "exploration", None) is not None: + policy.exploration.postprocess_trajectory( + policy, post_batch, policy.get_session() + ) + post_batch.set_get_interceptor(None) + post_batch = policy.postprocess_trajectory(post_batch, other_batches, self) + + from ray.rllib.evaluation.rollout_worker import get_global_worker + + self.callbacks.on_postprocess_trajectory( + worker=get_global_worker(), + episode=self, + agent_id=agent_id, + policy_id=pid, + policies=self.policy_map, + postprocessed_batch=post_batch, + original_batches=pre_batches, + ) + + # Append post_batch for return. + if pid not in batch_builder.policy_collectors: + batch_builder.policy_collectors[pid] = _PolicyCollector(policy) + batch_builder.policy_collectors[pid].add_postprocessed_batch_for_training( + post_batch, policy.view_requirements + ) + + batch_builder.agent_steps += self.active_agent_steps + batch_builder.env_steps += self.active_env_steps + + # AgentCollector cleared. + self.active_agent_steps = 0 + self.active_env_steps = 0 + + def has_init_obs(self, agent_id: AgentID = None) -> bool: + """Returns whether this episode has initial obs for an agent. + + If agent_id is None, return whether we have received any initial obs, + in other words, whether this episode is completely fresh. 
+ """ + if agent_id is not None: + return agent_id in self._has_init_obs and self._has_init_obs[agent_id] + else: + return any(list(self._has_init_obs.values())) + + def is_done(self, agent_id: AgentID) -> bool: + return self.is_terminated(agent_id) or self.is_truncated(agent_id) + + def is_terminated(self, agent_id: AgentID) -> bool: + return self._last_terminateds.get(agent_id, False) + + def is_truncated(self, agent_id: AgentID) -> bool: + return self._last_truncateds.get(agent_id, False) + + def set_last_info(self, agent_id: AgentID, info: Dict): + self._last_infos[agent_id] = info + + def last_info_for( + self, agent_id: AgentID = _DUMMY_AGENT_ID + ) -> Optional[EnvInfoDict]: + return self._last_infos.get(agent_id) + + @property + def length(self): + return self.total_env_steps diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/metrics.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..016ad2a86264389a131097f9d11a5132ab9c1a20 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/metrics.py @@ -0,0 +1,266 @@ +import collections +import logging +import numpy as np +from typing import List, Optional, TYPE_CHECKING + +from ray.rllib.policy.sample_batch import DEFAULT_POLICY_ID +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.metrics.learner_info import LEARNER_STATS_KEY +from ray.rllib.utils.typing import GradInfoDict, LearnerStatsDict, ResultDict + +if TYPE_CHECKING: + from ray.rllib.env.env_runner_group import EnvRunnerGroup + +logger = logging.getLogger(__name__) + +RolloutMetrics = OldAPIStack( + collections.namedtuple( + "RolloutMetrics", + [ + "episode_length", + "episode_reward", + "agent_rewards", + "custom_metrics", + "perf_stats", + "hist_data", + "media", + "episode_faulty", + "connector_metrics", + ], + ) +) +RolloutMetrics.__new__.__defaults__ = (0, 0, {}, {}, {}, {}, {}, False, {}) + + 
+@OldAPIStack +def get_learner_stats(grad_info: GradInfoDict) -> LearnerStatsDict: + """Return optimization stats reported from the policy. + + .. testcode:: + :skipif: True + + grad_info = worker.learn_on_batch(samples) + + # {"td_error": [...], "learner_stats": {"vf_loss": ..., ...}} + + print(get_stats(grad_info)) + + .. testoutput:: + + {"vf_loss": ..., "policy_loss": ...} + """ + if LEARNER_STATS_KEY in grad_info: + return grad_info[LEARNER_STATS_KEY] + + multiagent_stats = {} + for k, v in grad_info.items(): + if type(v) is dict: + if LEARNER_STATS_KEY in v: + multiagent_stats[k] = v[LEARNER_STATS_KEY] + + return multiagent_stats + + +@OldAPIStack +def collect_metrics( + workers: "EnvRunnerGroup", + remote_worker_ids: Optional[List[int]] = None, + timeout_seconds: int = 180, + keep_custom_metrics: bool = False, +) -> ResultDict: + """Gathers episode metrics from rollout worker set. + + Args: + workers: EnvRunnerGroup. + remote_worker_ids: Optional list of IDs of remote workers to collect + metrics from. + timeout_seconds: Timeout in seconds for collecting metrics from remote workers. + keep_custom_metrics: Whether to keep custom metrics in the result dict as + they are (True) or to aggregate them (False). + + Returns: + A result dict of metrics. + """ + episodes = collect_episodes( + workers, remote_worker_ids, timeout_seconds=timeout_seconds + ) + metrics = summarize_episodes( + episodes, episodes, keep_custom_metrics=keep_custom_metrics + ) + return metrics + + +@OldAPIStack +def collect_episodes( + workers: "EnvRunnerGroup", + remote_worker_ids: Optional[List[int]] = None, + timeout_seconds: int = 180, +) -> List[RolloutMetrics]: + """Gathers new episodes metrics tuples from the given RolloutWorkers. + + Args: + workers: EnvRunnerGroup. + remote_worker_ids: Optional list of IDs of remote workers to collect + metrics from. + timeout_seconds: Timeout in seconds for collecting metrics from remote workers. + + Returns: + List of RolloutMetrics. 
+ """ + # This will drop get_metrics() calls that are too slow. + # We can potentially make this an asynchronous call if this turns + # out to be a problem. + metric_lists = workers.foreach_env_runner( + lambda w: w.get_metrics(), + local_env_runner=True, + remote_worker_ids=remote_worker_ids, + timeout_seconds=timeout_seconds, + ) + if len(metric_lists) == 0: + logger.warning("WARNING: collected no metrics.") + + episodes = [] + for metrics in metric_lists: + episodes.extend(metrics) + + return episodes + + +@OldAPIStack +def summarize_episodes( + episodes: List[RolloutMetrics], + new_episodes: List[RolloutMetrics] = None, + keep_custom_metrics: bool = False, +) -> ResultDict: + """Summarizes a set of episode metrics tuples. + + Args: + episodes: List of most recent n episodes. This may include historical ones + (not newly collected in this iteration) in order to achieve the size of + the smoothing window. + new_episodes: All the episodes that were completed in this iteration. + keep_custom_metrics: Whether to keep custom metrics in the result dict as + they are (True) or to aggregate them (False). + + Returns: + A result dict of metrics. + """ + + if new_episodes is None: + new_episodes = episodes + + episode_rewards = [] + episode_lengths = [] + policy_rewards = collections.defaultdict(list) + custom_metrics = collections.defaultdict(list) + perf_stats = collections.defaultdict(list) + hist_stats = collections.defaultdict(list) + episode_media = collections.defaultdict(list) + connector_metrics = collections.defaultdict(list) + num_faulty_episodes = 0 + + for episode in episodes: + # Faulty episodes may still carry perf_stats data. + for k, v in episode.perf_stats.items(): + perf_stats[k].append(v) + # Continue if this is a faulty episode. + # There should be other meaningful stats to be collected. 
+ if episode.episode_faulty: + num_faulty_episodes += 1 + continue + + episode_lengths.append(episode.episode_length) + episode_rewards.append(episode.episode_reward) + for k, v in episode.custom_metrics.items(): + custom_metrics[k].append(v) + is_multi_agent = ( + len(episode.agent_rewards) > 1 + or DEFAULT_POLICY_ID not in episode.agent_rewards + ) + if is_multi_agent: + for (_, policy_id), reward in episode.agent_rewards.items(): + policy_rewards[policy_id].append(reward) + for k, v in episode.hist_data.items(): + hist_stats[k] += v + for k, v in episode.media.items(): + episode_media[k].append(v) + if hasattr(episode, "connector_metrics"): + # Group connector metrics by connector_metric name for all policies + for per_pipeline_metrics in episode.connector_metrics.values(): + for per_connector_metrics in per_pipeline_metrics.values(): + for connector_metric_name, val in per_connector_metrics.items(): + connector_metrics[connector_metric_name].append(val) + + if episode_rewards: + min_reward = min(episode_rewards) + max_reward = max(episode_rewards) + avg_reward = np.mean(episode_rewards) + else: + min_reward = float("nan") + max_reward = float("nan") + avg_reward = float("nan") + if episode_lengths: + avg_length = np.mean(episode_lengths) + else: + avg_length = float("nan") + + # Show as histogram distributions. + hist_stats["episode_reward"] = episode_rewards + hist_stats["episode_lengths"] = episode_lengths + + policy_reward_min = {} + policy_reward_mean = {} + policy_reward_max = {} + for policy_id, rewards in policy_rewards.copy().items(): + policy_reward_min[policy_id] = np.min(rewards) + policy_reward_mean[policy_id] = np.mean(rewards) + policy_reward_max[policy_id] = np.max(rewards) + + # Show as histogram distributions. 
+ hist_stats["policy_{}_reward".format(policy_id)] = rewards + + for k, v_list in custom_metrics.copy().items(): + filt = [v for v in v_list if not np.any(np.isnan(v))] + if keep_custom_metrics: + custom_metrics[k] = filt + else: + custom_metrics[k + "_mean"] = np.mean(filt) + if filt: + custom_metrics[k + "_min"] = np.min(filt) + custom_metrics[k + "_max"] = np.max(filt) + else: + custom_metrics[k + "_min"] = float("nan") + custom_metrics[k + "_max"] = float("nan") + del custom_metrics[k] + + for k, v_list in perf_stats.copy().items(): + perf_stats[k] = np.mean(v_list) + + mean_connector_metrics = dict() + for k, v_list in connector_metrics.items(): + mean_connector_metrics[k] = np.mean(v_list) + + return dict( + episode_reward_max=max_reward, + episode_reward_min=min_reward, + episode_reward_mean=avg_reward, + episode_len_mean=avg_length, + episode_media=dict(episode_media), + episodes_timesteps_total=sum(episode_lengths), + policy_reward_min=policy_reward_min, + policy_reward_max=policy_reward_max, + policy_reward_mean=policy_reward_mean, + custom_metrics=dict(custom_metrics), + hist_stats=dict(hist_stats), + sampler_perf=dict(perf_stats), + num_faulty_episodes=num_faulty_episodes, + connector_metrics=mean_connector_metrics, + # Added these (duplicate) values here for forward compatibility with the new API + # stack's metrics structure. This allows us to unify our test cases and keeping + # the new API stack clean of backward-compatible keys. 
+ num_episodes=len(new_episodes), + episode_return_max=max_reward, + episode_return_min=min_reward, + episode_return_mean=avg_reward, + episodes_this_iter=len(new_episodes), # deprecate in favor of `num_epsodes_...` + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/observation_function.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/observation_function.py new file mode 100644 index 0000000000000000000000000000000000000000..c670ed5192cfefe29b76bf6b5e557e2bcb46fea5 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/observation_function.py @@ -0,0 +1,87 @@ +from typing import Dict + +from ray.rllib.env import BaseEnv +from ray.rllib.policy import Policy +from ray.rllib.evaluation import RolloutWorker +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import TensorType +from ray.rllib.utils.typing import AgentID, PolicyID + + +@OldAPIStack +class ObservationFunction: + """Interceptor function for rewriting observations from the environment. + + These callbacks can be used for preprocessing of observations, especially + in multi-agent scenarios. + + Observation functions can be specified in the multi-agent config by + specifying ``{"observation_fn": your_obs_func}``. Note that + ``your_obs_func`` can be a plain Python function. + + This API is **experimental**. + """ + + def __call__( + self, + agent_obs: Dict[AgentID, TensorType], + worker: RolloutWorker, + base_env: BaseEnv, + policies: Dict[PolicyID, Policy], + episode, + **kw + ) -> Dict[AgentID, TensorType]: + """Callback run on each environment step to observe the environment. + + This method takes in the original agent observation dict returned by + a MultiAgentEnv, and returns a possibly modified one. It can be + thought of as a "wrapper" around the environment. + + TODO(ekl): allow end-to-end differentiation through the observation + function and policy losses. + + TODO(ekl): enable batch processing. 
+ + Args: + agent_obs: Dictionary of default observations from the + environment. The default implementation of observe() simply + returns this dict. + worker: Reference to the current rollout worker. + base_env: BaseEnv running the episode. The underlying + sub environment objects (BaseEnvs are vectorized) can be + retrieved by calling `base_env.get_sub_environments()`. + policies: Mapping of policy id to policy objects. In single + agent mode there will only be a single "default" policy. + episode: Episode state object. + kwargs: Forward compatibility placeholder. + + Returns: + new_agent_obs: copy of agent obs with updates. You can + rewrite or drop data from the dict if needed (e.g., the env + can have a dummy "global" observation, and the observer can + merge the global state into individual observations. + + .. testcode:: + :skipif: True + + # Observer that merges global state into individual obs. It is + # rewriting the discrete obs into a tuple with global state. + example_obs_fn1({"a": 1, "b": 2, "global_state": 101}, ...) + + .. testoutput:: + + {"a": [1, 101], "b": [2, 101]} + + .. testcode:: + :skipif: True + + # Observer for e.g., custom centralized critic model. It is + # rewriting the discrete obs into a dict with more data. + example_obs_fn2({"a": 1, "b": 2}, ...) + + .. 
testoutput:: + + {"a": {"self": 1, "other": 2}, "b": {"self": 2, "other": 1}} + """ + + return agent_obs diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/postprocessing.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/postprocessing.py new file mode 100644 index 0000000000000000000000000000000000000000..4b0a6c79bd60216408c9a772ab5630be01132321 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/postprocessing.py @@ -0,0 +1,328 @@ +import numpy as np +import scipy.signal +from typing import Dict, Optional + +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI, OldAPIStack +from ray.rllib.utils.numpy import convert_to_numpy +from ray.rllib.utils.typing import AgentID +from ray.rllib.utils.typing import TensorType + + +@DeveloperAPI +class Postprocessing: + """Constant definitions for postprocessing.""" + + ADVANTAGES = "advantages" + VALUE_TARGETS = "value_targets" + + +@OldAPIStack +def adjust_nstep(n_step: int, gamma: float, batch: SampleBatch) -> None: + """Rewrites `batch` to encode n-step rewards, terminateds, truncateds, and next-obs. + + Observations and actions remain unaffected. At the end of the trajectory, + n is truncated to fit in the traj length. + + Args: + n_step: The number of steps to look ahead and adjust. + gamma: The discount factor. + batch: The SampleBatch to adjust (in place). + + Examples: + n-step=3 + Trajectory=o0 r0 d0, o1 r1 d1, o2 r2 d2, o3 r3 d3, o4 r4 d4=True o5 + gamma=0.9 + Returned trajectory: + 0: o0 [r0 + 0.9*r1 + 0.9^2*r2 + 0.9^3*r3] d3 o0'=o3 + 1: o1 [r1 + 0.9*r2 + 0.9^2*r3 + 0.9^3*r4] d4 o1'=o4 + 2: o2 [r2 + 0.9*r3 + 0.9^2*r4] d4 o1'=o5 + 3: o3 [r3 + 0.9*r4] d4 o3'=o5 + 4: o4 r4 d4 o4'=o5 + """ + + assert ( + batch.is_single_trajectory() + ), "Unexpected terminated|truncated in middle of trajectory!" + + len_ = len(batch) + + # Shift NEXT_OBS, TERMINATEDS, and TRUNCATEDS. 
+ batch[SampleBatch.NEXT_OBS] = np.concatenate( + [ + batch[SampleBatch.OBS][n_step:], + np.stack([batch[SampleBatch.NEXT_OBS][-1]] * min(n_step, len_)), + ], + axis=0, + ) + batch[SampleBatch.TERMINATEDS] = np.concatenate( + [ + batch[SampleBatch.TERMINATEDS][n_step - 1 :], + np.tile(batch[SampleBatch.TERMINATEDS][-1], min(n_step - 1, len_)), + ], + axis=0, + ) + # Only fix `truncateds`, if present in the batch. + if SampleBatch.TRUNCATEDS in batch: + batch[SampleBatch.TRUNCATEDS] = np.concatenate( + [ + batch[SampleBatch.TRUNCATEDS][n_step - 1 :], + np.tile(batch[SampleBatch.TRUNCATEDS][-1], min(n_step - 1, len_)), + ], + axis=0, + ) + + # Change rewards in place. + for i in range(len_): + for j in range(1, n_step): + if i + j < len_: + batch[SampleBatch.REWARDS][i] += ( + gamma**j * batch[SampleBatch.REWARDS][i + j] + ) + + +@OldAPIStack +def compute_advantages( + rollout: SampleBatch, + last_r: float, + gamma: float = 0.9, + lambda_: float = 1.0, + use_gae: bool = True, + use_critic: bool = True, + rewards: TensorType = None, + vf_preds: TensorType = None, +): + """Given a rollout, compute its value targets and the advantages. + + Args: + rollout: SampleBatch of a single trajectory. + last_r: Value estimation for last observation. + gamma: Discount factor. + lambda_: Parameter for GAE. + use_gae: Using Generalized Advantage Estimation. + use_critic: Whether to use critic (value estimates). Setting + this to False will use 0 as baseline. + rewards: Override the reward values in rollout. + vf_preds: Override the value function predictions in rollout. + + Returns: + SampleBatch with experience from rollout and processed rewards. 
+ """ + assert ( + SampleBatch.VF_PREDS in rollout or not use_critic + ), "use_critic=True but values not found" + assert use_critic or not use_gae, "Can't use gae without using a value function" + last_r = convert_to_numpy(last_r) + + if rewards is None: + rewards = rollout[SampleBatch.REWARDS] + if vf_preds is None and use_critic: + vf_preds = rollout[SampleBatch.VF_PREDS] + + if use_gae: + vpred_t = np.concatenate([vf_preds, np.array([last_r])]) + delta_t = rewards + gamma * vpred_t[1:] - vpred_t[:-1] + # This formula for the advantage comes from: + # "Generalized Advantage Estimation": https://arxiv.org/abs/1506.02438 + rollout[Postprocessing.ADVANTAGES] = discount_cumsum(delta_t, gamma * lambda_) + rollout[Postprocessing.VALUE_TARGETS] = ( + rollout[Postprocessing.ADVANTAGES] + vf_preds + ).astype(np.float32) + else: + rewards_plus_v = np.concatenate([rewards, np.array([last_r])]) + discounted_returns = discount_cumsum(rewards_plus_v, gamma)[:-1].astype( + np.float32 + ) + + if use_critic: + rollout[Postprocessing.ADVANTAGES] = discounted_returns - vf_preds + rollout[Postprocessing.VALUE_TARGETS] = discounted_returns + else: + rollout[Postprocessing.ADVANTAGES] = discounted_returns + rollout[Postprocessing.VALUE_TARGETS] = np.zeros_like( + rollout[Postprocessing.ADVANTAGES] + ) + + rollout[Postprocessing.ADVANTAGES] = rollout[Postprocessing.ADVANTAGES].astype( + np.float32 + ) + + return rollout + + +@OldAPIStack +def compute_gae_for_sample_batch( + policy: Policy, + sample_batch: SampleBatch, + other_agent_batches: Optional[Dict[AgentID, SampleBatch]] = None, + episode=None, +) -> SampleBatch: + """Adds GAE (generalized advantage estimations) to a trajectory. + + The trajectory contains only data from one episode and from one agent. + - If `config.batch_mode=truncate_episodes` (default), sample_batch may + contain a truncated (at-the-end) episode, in case the + `config.rollout_fragment_length` was reached by the sampler. 
+ - If `config.batch_mode=complete_episodes`, sample_batch will contain + exactly one episode (no matter how long). + New columns can be added to sample_batch and existing ones may be altered. + + Args: + policy: The Policy used to generate the trajectory (`sample_batch`) + sample_batch: The SampleBatch to postprocess. + other_agent_batches: Optional dict of AgentIDs mapping to other + agents' trajectory data (from the same episode). + NOTE: The other agents use the same policy. + episode: Optional multi-agent episode object in which the agents + operated. + + Returns: + The postprocessed, modified SampleBatch (or a new one). + """ + # Compute the SampleBatch.VALUES_BOOTSTRAPPED column, which we'll need for the + # following `last_r` arg in `compute_advantages()`. + sample_batch = compute_bootstrap_value(sample_batch, policy) + + vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS]) + rewards = np.array(sample_batch[SampleBatch.REWARDS]) + # We need to squeeze out the time dimension if there is one + # Sanity check that both have the same shape + if len(vf_preds.shape) == 2: + assert vf_preds.shape == rewards.shape + vf_preds = np.squeeze(vf_preds, axis=1) + rewards = np.squeeze(rewards, axis=1) + squeezed = True + else: + squeezed = False + + # Adds the policy logits, VF preds, and advantages to the batch, + # using GAE ("generalized advantage estimation") or not. 
+ batch = compute_advantages( + rollout=sample_batch, + last_r=sample_batch[SampleBatch.VALUES_BOOTSTRAPPED][-1], + gamma=policy.config["gamma"], + lambda_=policy.config["lambda"], + use_gae=policy.config["use_gae"], + use_critic=policy.config.get("use_critic", True), + vf_preds=vf_preds, + rewards=rewards, + ) + + if squeezed: + # If we needed to squeeze rewards and vf_preds, we need to unsqueeze + # advantages again for it to have the same shape + batch[Postprocessing.ADVANTAGES] = np.expand_dims( + batch[Postprocessing.ADVANTAGES], axis=1 + ) + + return batch + + +@OldAPIStack +def compute_bootstrap_value(sample_batch: SampleBatch, policy: Policy) -> SampleBatch: + """Performs a value function computation at the end of a trajectory. + + If the trajectory is terminated (not truncated), will not use the value function, + but assume that the value of the last timestep is 0.0. + In all other cases, will use the given policy's value function to compute the + "bootstrapped" value estimate at the end of the given trajectory. To do so, the + very last observation (sample_batch[NEXT_OBS][-1]) and - if applicable - + the very last state output (sample_batch[STATE_OUT][-1]) wil be used as inputs to + the value function. + + The thus computed value estimate will be stored in a new column of the + `sample_batch`: SampleBatch.VALUES_BOOTSTRAPPED. Thereby, values at all timesteps + in this column are set to 0.0, except or the last timestep, which receives the + computed bootstrapped value. + This is done, such that in any loss function (which processes raw, intact + trajectories, such as those of IMPALA and APPO) can use this new column as follows: + + Example: numbers=ts in episode, '|'=episode boundary (terminal), + X=bootstrapped value (!= 0.0 b/c ts=12 is not a terminal). + ts=5 is NOT a terminal. + T: 8 9 10 11 12 <- no terminal + VF_PREDS: . . . . . 
+ VALUES_BOOTSTRAPPED: 0 0 0 0 X + + Args: + sample_batch: The SampleBatch (single trajectory) for which to compute the + bootstrap value at the end. This SampleBatch will be altered in place + (by adding a new column: SampleBatch.VALUES_BOOTSTRAPPED). + policy: The Policy object, whose value function to use. + + Returns: + The altered SampleBatch (with the extra SampleBatch.VALUES_BOOTSTRAPPED + column). + """ + # Trajectory is actually complete -> last r=0.0. + if sample_batch[SampleBatch.TERMINATEDS][-1]: + last_r = 0.0 + # Trajectory has been truncated -> last r=VF estimate of last obs. + else: + # Input dict is provided to us automatically via the Model's + # requirements. It's a single-timestep (last one in trajectory) + # input_dict. + # Create an input dict according to the Policy's requirements. + input_dict = sample_batch.get_single_step_input_dict( + policy.view_requirements, index="last" + ) + last_r = policy._value(**input_dict) + + vf_preds = np.array(sample_batch[SampleBatch.VF_PREDS]) + # We need to squeeze out the time dimension if there is one + if len(vf_preds.shape) == 2: + vf_preds = np.squeeze(vf_preds, axis=1) + squeezed = True + else: + squeezed = False + + # Set the SampleBatch.VALUES_BOOTSTRAPPED field to VF_PREDS[1:] + the + # very last timestep (where this bootstrapping value is actually needed), which + # we set to the computed `last_r`. + sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.concatenate( + [ + convert_to_numpy(vf_preds[1:]), + np.array([convert_to_numpy(last_r)], dtype=np.float32), + ], + axis=0, + ) + + if squeezed: + sample_batch[SampleBatch.VF_PREDS] = np.expand_dims(vf_preds, axis=1) + sample_batch[SampleBatch.VALUES_BOOTSTRAPPED] = np.expand_dims( + sample_batch[SampleBatch.VALUES_BOOTSTRAPPED], axis=1 + ) + + return sample_batch + + +@OldAPIStack +def discount_cumsum(x: np.ndarray, gamma: float) -> np.ndarray: + """Calculates the discounted cumulative sum over a reward sequence `x`. 
+ + y[t] - discount*y[t+1] = x[t] + reversed(y)[t] - discount*reversed(y)[t-1] = reversed(x)[t] + + Args: + gamma: The discount factor gamma. + + Returns: + The sequence containing the discounted cumulative sums + for each individual reward in `x` till the end of the trajectory. + + .. testcode:: + :skipif: True + + x = np.array([0.0, 1.0, 2.0, 3.0]) + gamma = 0.9 + discount_cumsum(x, gamma) + + .. testoutput:: + + array([0.0 + 0.9*1.0 + 0.9^2*2.0 + 0.9^3*3.0, + 1.0 + 0.9*2.0 + 0.9^2*3.0, + 2.0 + 0.9*3.0, + 3.0]) + """ + return scipy.signal.lfilter([1], [1, float(-gamma)], x[::-1], axis=0)[::-1] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py new file mode 100644 index 0000000000000000000000000000000000000000..b7a5ee6b1d3b1cd048cb7aef099fbb99fd8ced1c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/rollout_worker.py @@ -0,0 +1,2004 @@ +import copy +import importlib.util +import logging +import os +import platform +import threading +from collections import defaultdict +from types import FunctionType +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Collection, + Dict, + List, + Optional, + Set, + Tuple, + Type, + Union, +) + +from gymnasium.spaces import Space + +import ray +from ray import ObjectRef +from ray import cloudpickle as pickle +from ray.rllib.connectors.util import ( + create_connectors_for_policy, + maybe_get_filters_for_syncing, +) +from ray.rllib.core.rl_module import validate_module_id +from ray.rllib.core.rl_module.rl_module import RLModuleSpec +from ray.rllib.env.base_env import BaseEnv, convert_to_base_env +from ray.rllib.env.env_context import EnvContext +from ray.rllib.env.env_runner import EnvRunner +from ray.rllib.env.external_multi_agent_env import ExternalMultiAgentEnv +from ray.rllib.env.multi_agent_env import MultiAgentEnv +from ray.rllib.env.wrappers.atari_wrappers import is_atari, 
wrap_deepmind +from ray.rllib.evaluation.metrics import RolloutMetrics +from ray.rllib.evaluation.sampler import SyncSampler +from ray.rllib.models import ModelCatalog +from ray.rllib.models.preprocessors import Preprocessor +from ray.rllib.offline import ( + D4RLReader, + DatasetReader, + DatasetWriter, + InputReader, + IOContext, + JsonReader, + JsonWriter, + MixedInput, + NoopOutput, + OutputWriter, + ShuffledInput, +) +from ray.rllib.policy.policy import Policy, PolicySpec +from ray.rllib.policy.policy_map import PolicyMap +from ray.rllib.policy.sample_batch import ( + DEFAULT_POLICY_ID, + MultiAgentBatch, + concat_samples, + convert_ma_batch_to_sample_batch, +) +from ray.rllib.policy.torch_policy import TorchPolicy +from ray.rllib.policy.torch_policy_v2 import TorchPolicyV2 +from ray.rllib.utils import force_list +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.debug import summarize, update_global_seed_if_necessary +from ray.rllib.utils.error import ERR_MSG_NO_GPUS, HOWTO_CHANGE_CONFIG +from ray.rllib.utils.filter import Filter, NoFilter +from ray.rllib.utils.framework import try_import_tf, try_import_torch +from ray.rllib.utils.from_config import from_config +from ray.rllib.utils.policy import create_policy_for_framework +from ray.rllib.utils.sgd import do_minibatch_sgd +from ray.rllib.utils.tf_run_builder import _TFRunBuilder +from ray.rllib.utils.tf_utils import get_gpu_devices as get_tf_gpu_devices +from ray.rllib.utils.tf_utils import get_tf_eager_cls_if_necessary +from ray.rllib.utils.typing import ( + AgentID, + EnvCreator, + EnvType, + ModelGradients, + ModelWeights, + MultiAgentPolicyConfigDict, + PartialAlgorithmConfigDict, + PolicyID, + PolicyState, + SampleBatchType, + T, +) +from ray.tune.registry import registry_contains_input, registry_get_input +from ray.util.annotations import PublicAPI +from ray.util.debug import disable_log_once_globally, enable_periodic_logging, log_once +from ray.util.iter import 
ParallelIteratorWorker + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + from ray.rllib.callbacks.callbacks import RLlibCallback + +tf1, tf, tfv = try_import_tf() +torch, _ = try_import_torch() + +logger = logging.getLogger(__name__) + +# Handle to the current rollout worker, which will be set to the most recently +# created RolloutWorker in this process. This can be helpful to access in +# custom env or policy classes for debugging or advanced use cases. +_global_worker: Optional["RolloutWorker"] = None + + +@OldAPIStack +def get_global_worker() -> "RolloutWorker": + """Returns a handle to the active rollout worker in this process.""" + + global _global_worker + return _global_worker + + +def _update_env_seed_if_necessary( + env: EnvType, seed: int, worker_idx: int, vector_idx: int +): + """Set a deterministic random seed on environment. + + NOTE: this may not work with remote environments (issue #18154). + """ + if seed is None: + return + + # A single RL job is unlikely to have more than 10K + # rollout workers. + max_num_envs_per_env_runner: int = 1000 + assert ( + worker_idx < max_num_envs_per_env_runner + ), "Too many envs per worker. Random seeds may collide." + computed_seed: int = worker_idx * max_num_envs_per_env_runner + vector_idx + seed + + # Gymnasium.env. + # This will silently fail for most Farama-foundation gymnasium environments. + # (they do nothing and return None per default) + if not hasattr(env, "reset"): + if log_once("env_has_no_reset_method"): + logger.info(f"Env {env} doesn't have a `reset()` method. Cannot seed.") + else: + try: + env.reset(seed=computed_seed) + except Exception: + logger.info( + f"Env {env} doesn't support setting a seed via its `reset()` " + "method! Implement this method as `reset(self, *, seed=None, " + "options=None)` for it to abide to the correct API. Cannot seed." 
+ ) + + +@OldAPIStack +class RolloutWorker(ParallelIteratorWorker, EnvRunner): + """Common experience collection class. + + This class wraps a policy instance and an environment class to + collect experiences from the environment. You can create many replicas of + this class as Ray actors to scale RL training. + + This class supports vectorized and multi-agent policy evaluation (e.g., + VectorEnv, MultiAgentEnv, etc.) + + .. testcode:: + :skipif: True + + # Create a rollout worker and using it to collect experiences. + import gymnasium as gym + from ray.rllib.evaluation.rollout_worker import RolloutWorker + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy + worker = RolloutWorker( + env_creator=lambda _: gym.make("CartPole-v1"), + default_policy_class=PPOTF1Policy) + print(worker.sample()) + + # Creating a multi-agent rollout worker + from gymnasium.spaces import Discrete, Box + import random + MultiAgentTrafficGrid = ... + worker = RolloutWorker( + env_creator=lambda _: MultiAgentTrafficGrid(num_cars=25), + config=AlgorithmConfig().multi_agent( + policies={ + # Use an ensemble of two policies for car agents + "car_policy1": + (PGTFPolicy, Box(...), Discrete(...), + AlgorithmConfig.overrides(gamma=0.99)), + "car_policy2": + (PGTFPolicy, Box(...), Discrete(...), + AlgorithmConfig.overrides(gamma=0.95)), + # Use a single shared policy for all traffic lights + "traffic_light_policy": + (PGTFPolicy, Box(...), Discrete(...), {}), + }, + policy_mapping_fn=( + lambda agent_id, episode, **kwargs: + random.choice(["car_policy1", "car_policy2"]) + if agent_id.startswith("car_") else "traffic_light_policy"), + ), + ) + print(worker.sample()) + + .. 
testoutput:: + + SampleBatch({ + "obs": [[...]], "actions": [[...]], "rewards": [[...]], + "terminateds": [[...]], "truncateds": [[...]], "new_obs": [[...]]} + ) + + MultiAgentBatch({ + "car_policy1": SampleBatch(...), + "car_policy2": SampleBatch(...), + "traffic_light_policy": SampleBatch(...)} + ) + + """ + + def __init__( + self, + *, + env_creator: EnvCreator, + validate_env: Optional[Callable[[EnvType, EnvContext], None]] = None, + config: Optional["AlgorithmConfig"] = None, + worker_index: int = 0, + num_workers: Optional[int] = None, + recreated_worker: bool = False, + log_dir: Optional[str] = None, + spaces: Optional[Dict[PolicyID, Tuple[Space, Space]]] = None, + default_policy_class: Optional[Type[Policy]] = None, + dataset_shards: Optional[List[ray.data.Dataset]] = None, + **kwargs, + ): + """Initializes a RolloutWorker instance. + + Args: + env_creator: Function that returns a gym.Env given an EnvContext + wrapped configuration. + validate_env: Optional callable to validate the generated + environment (only on worker=0). + worker_index: For remote workers, this should be set to a + non-zero and unique value. This index is passed to created envs + through EnvContext so that envs can be configured per worker. + recreated_worker: Whether this worker is a recreated one. Workers are + recreated by an Algorithm (via EnvRunnerGroup) in case + `restart_failed_env_runners=True` and one of the original workers (or + an already recreated one) has failed. They don't differ from original + workers other than the value of this flag (`self.recreated_worker`). + log_dir: Directory where logs can be placed. + spaces: An optional space dict mapping policy IDs + to (obs_space, action_space)-tuples. This is used in case no + Env is created on this RolloutWorker. 
+ """ + self._original_kwargs: dict = locals().copy() + del self._original_kwargs["self"] + + global _global_worker + _global_worker = self + + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + # Default config needed? + if config is None or isinstance(config, dict): + config = AlgorithmConfig().update_from_dict(config or {}) + # Freeze config, so no one else can alter it from here on. + config.freeze() + + # Set extra python env variables before calling super constructor. + if config.extra_python_environs_for_driver and worker_index == 0: + for key, value in config.extra_python_environs_for_driver.items(): + os.environ[key] = str(value) + elif config.extra_python_environs_for_worker and worker_index > 0: + for key, value in config.extra_python_environs_for_worker.items(): + os.environ[key] = str(value) + + def gen_rollouts(): + while True: + yield self.sample() + + ParallelIteratorWorker.__init__(self, gen_rollouts, False) + EnvRunner.__init__(self, config=config) + + self.num_workers = ( + num_workers if num_workers is not None else self.config.num_env_runners + ) + # In case we are reading from distributed datasets, store the shards here + # and pick our shard by our worker-index. + self._ds_shards = dataset_shards + self.worker_index: int = worker_index + + # Lock to be able to lock this entire worker + # (via `self.lock()` and `self.unlock()`). + # This might be crucial to prevent a race condition in case + # `config.policy_states_are_swappable=True` and you are using an Algorithm + # with a learner thread. In this case, the thread might update a policy + # that is being swapped (during the update) by the Algorithm's + # training_step's `RolloutWorker.get_weights()` call (to sync back the + # new weights to all remote workers). 
+ self._lock = threading.Lock() + + if ( + tf1 + and (config.framework_str == "tf2" or config.enable_tf1_exec_eagerly) + # This eager check is necessary for certain all-framework tests + # that use tf's eager_mode() context generator. + and not tf1.executing_eagerly() + ): + tf1.enable_eager_execution() + + if self.config.log_level: + logging.getLogger("ray.rllib").setLevel(self.config.log_level) + + if self.worker_index > 1: + disable_log_once_globally() # only need 1 worker to log + elif self.config.log_level == "DEBUG": + enable_periodic_logging() + + env_context = EnvContext( + self.config.env_config, + worker_index=self.worker_index, + vector_index=0, + num_workers=self.num_workers, + remote=self.config.remote_worker_envs, + recreated_worker=recreated_worker, + ) + self.env_context = env_context + self.config: AlgorithmConfig = config + self.callbacks: RLlibCallback = self.config.callbacks_class() + self.recreated_worker: bool = recreated_worker + + # Setup current policy_mapping_fn. Start with the one from the config, which + # might be None in older checkpoints (nowadays AlgorithmConfig has a proper + # default for this); Need to cover this situation via the backup lambda here. + self.policy_mapping_fn = ( + lambda agent_id, episode, worker, **kw: DEFAULT_POLICY_ID + ) + self.set_policy_mapping_fn(self.config.policy_mapping_fn) + + self.env_creator: EnvCreator = env_creator + # Resolve possible auto-fragment length. + configured_rollout_fragment_length = self.config.get_rollout_fragment_length( + worker_index=self.worker_index + ) + self.total_rollout_fragment_length: int = ( + configured_rollout_fragment_length * self.config.num_envs_per_env_runner + ) + self.preprocessing_enabled: bool = not config._disable_preprocessor_api + self.last_batch: Optional[SampleBatchType] = None + self.global_vars: dict = { + # TODO(sven): Make this per-policy! + "timestep": 0, + # Counter for performed gradient updates per policy in `self.policy_map`. 
+ # Allows for compiling metrics on the off-policy'ness of an update given + # that the number of gradient updates of the sampling policies are known + # to the learner (and can be compared to the learner version of the same + # policy). + "num_grad_updates_per_policy": defaultdict(int), + } + + # If seed is provided, add worker index to it and 10k iff evaluation worker. + self.seed = ( + None + if self.config.seed is None + else self.config.seed + + self.worker_index + + self.config.in_evaluation * 10000 + ) + + # Update the global seed for numpy/random/tf-eager/torch if we are not + # the local worker, otherwise, this was already done in the Algorithm + # object itself. + if self.worker_index > 0: + update_global_seed_if_necessary(self.config.framework_str, self.seed) + + # A single environment provided by the user (via config.env). This may + # also remain None. + # 1) Create the env using the user provided env_creator. This may + # return a gym.Env (incl. MultiAgentEnv), an already vectorized + # VectorEnv, BaseEnv, ExternalEnv, or an ActorHandle (remote env). + # 2) Wrap - if applicable - with Atari/rendering wrappers. + # 3) Seed the env, if necessary. + # 4) Vectorize the existing single env by creating more clones of + # this env and wrapping it with the RLlib BaseEnv class. + self.env = self.make_sub_env_fn = None + + # Create a (single) env for this worker. + if not ( + self.worker_index == 0 + and self.num_workers > 0 + and not self.config.create_env_on_local_worker + ): + # Run the `env_creator` function passing the EnvContext. + self.env = env_creator(copy.deepcopy(self.env_context)) + + clip_rewards = self.config.clip_rewards + + if self.env is not None: + # Custom validation function given, typically a function attribute of the + # Algorithm. + if validate_env is not None: + validate_env(self.env, self.env_context) + + # We can't auto-wrap a BaseEnv. 
+ if isinstance(self.env, (BaseEnv, ray.actor.ActorHandle)): + + def wrap(env): + return env + + # Atari type env and "deepmind" preprocessor pref. + elif is_atari(self.env) and self.config.preprocessor_pref == "deepmind": + # Deepmind wrappers already handle all preprocessing. + self.preprocessing_enabled = False + + # If clip_rewards not explicitly set to False, switch it + # on here (clip between -1.0 and 1.0). + if self.config.clip_rewards is None: + clip_rewards = True + + # Framestacking is used. + use_framestack = self.config.model.get("framestack") is True + + def wrap(env): + env = wrap_deepmind( + env, + dim=self.config.model.get("dim"), + framestack=use_framestack, + noframeskip=self.config.env_config.get("frameskip", 0) == 1, + ) + return env + + elif self.config.preprocessor_pref is None: + # Only turn off preprocessing + self.preprocessing_enabled = False + + def wrap(env): + return env + + else: + + def wrap(env): + return env + + # Wrap env through the correct wrapper. + self.env: EnvType = wrap(self.env) + # Ideally, we would use the same make_sub_env() function below + # to create self.env, but wrap(env) and self.env has a cyclic + # dependency on each other right now, so we would settle on + # duplicating the random seed setting logic for now. + _update_env_seed_if_necessary(self.env, self.seed, self.worker_index, 0) + # Call custom callback function `on_sub_environment_created`. 
+ self.callbacks.on_sub_environment_created( + worker=self, + sub_environment=self.env, + env_context=self.env_context, + ) + + self.make_sub_env_fn = self._get_make_sub_env_fn( + env_creator, env_context, validate_env, wrap, self.seed + ) + + self.spaces = spaces + self.default_policy_class = default_policy_class + self.policy_dict, self.is_policy_to_train = self.config.get_multi_agent_setup( + env=self.env, + spaces=self.spaces, + default_policy_class=self.default_policy_class, + ) + + self.policy_map: Optional[PolicyMap] = None + # TODO(jungong) : clean up after non-connector env_runner is fully deprecated. + self.preprocessors: Dict[PolicyID, Preprocessor] = None + + # Check available number of GPUs. + num_gpus = ( + self.config.num_gpus + if self.worker_index == 0 + else self.config.num_gpus_per_env_runner + ) + + # Error if we don't find enough GPUs. + if ( + ray.is_initialized() + and ray._private.worker._mode() != ray._private.worker.LOCAL_MODE + and not config._fake_gpus + ): + devices = [] + if self.config.framework_str in ["tf2", "tf"]: + devices = get_tf_gpu_devices() + elif self.config.framework_str == "torch": + devices = list(range(torch.cuda.device_count())) + + if len(devices) < num_gpus: + raise RuntimeError( + ERR_MSG_NO_GPUS.format(len(devices), devices) + HOWTO_CHANGE_CONFIG + ) + # Warn, if running in local-mode and actual GPUs (not faked) are + # requested. + elif ( + ray.is_initialized() + and ray._private.worker._mode() == ray._private.worker.LOCAL_MODE + and num_gpus > 0 + and not self.config._fake_gpus + ): + logger.warning( + "You are running ray with `local_mode=True`, but have " + f"configured {num_gpus} GPUs to be used! In local mode, " + f"Policies are placed on the CPU and the `num_gpus` setting " + f"is ignored." + ) + + self.filters: Dict[PolicyID, Filter] = defaultdict(NoFilter) + + # If RLModule API is enabled, multi_rl_module_spec holds the specs of the + # RLModules. 
+ self.multi_rl_module_spec = None + self._update_policy_map(policy_dict=self.policy_dict) + + # Update Policy's view requirements from Model, only if Policy directly + # inherited from base `Policy` class. At this point here, the Policy + # must have it's Model (if any) defined and ready to output an initial + # state. + for pol in self.policy_map.values(): + if not pol._model_init_state_automatically_added: + pol._update_model_view_requirements_from_init_state() + + if ( + self.config.is_multi_agent + and self.env is not None + and not isinstance( + self.env, + (BaseEnv, ExternalMultiAgentEnv, MultiAgentEnv, ray.actor.ActorHandle), + ) + ): + raise ValueError( + f"You are running a multi-agent setup, but the env {self.env} is not a " + f"subclass of BaseEnv, MultiAgentEnv, ActorHandle, or " + f"ExternalMultiAgentEnv!" + ) + + if self.worker_index == 0: + logger.info("Built filter map: {}".format(self.filters)) + + # This RolloutWorker has no env. + if self.env is None: + self.async_env = None + # Use a custom env-vectorizer and call it providing self.env. + elif "custom_vector_env" in self.config: + self.async_env = self.config.custom_vector_env(self.env) + # Default: Vectorize self.env via the make_sub_env function. This adds + # further clones of self.env and creates a RLlib BaseEnv (which is + # vectorized under the hood). + else: + # Always use vector env for consistency even if num_envs_per_env_runner=1. + self.async_env: BaseEnv = convert_to_base_env( + self.env, + make_env=self.make_sub_env_fn, + num_envs=self.config.num_envs_per_env_runner, + remote_envs=self.config.remote_worker_envs, + remote_env_batch_wait_ms=self.config.remote_env_batch_wait_ms, + worker=self, + restart_failed_sub_environments=( + self.config.restart_failed_sub_environments + ), + ) + + # `truncate_episodes`: Allow a batch to contain more than one episode + # (fragments) and always make the batch `rollout_fragment_length` + # long. 
+ rollout_fragment_length_for_sampler = configured_rollout_fragment_length + if self.config.batch_mode == "truncate_episodes": + pack = True + # `complete_episodes`: Never cut episodes and sampler will return + # exactly one (complete) episode per poll. + else: + assert self.config.batch_mode == "complete_episodes" + rollout_fragment_length_for_sampler = float("inf") + pack = False + + # Create the IOContext for this worker. + self.io_context: IOContext = IOContext( + log_dir, self.config, self.worker_index, self + ) + + render = False + if self.config.render_env is True and ( + self.num_workers == 0 or self.worker_index == 1 + ): + render = True + + if self.env is None: + self.sampler = None + else: + self.sampler = SyncSampler( + worker=self, + env=self.async_env, + clip_rewards=clip_rewards, + rollout_fragment_length=rollout_fragment_length_for_sampler, + count_steps_by=self.config.count_steps_by, + callbacks=self.callbacks, + multiple_episodes_in_batch=pack, + normalize_actions=self.config.normalize_actions, + clip_actions=self.config.clip_actions, + observation_fn=self.config.observation_fn, + sample_collector_class=self.config.sample_collector, + render=render, + ) + + self.input_reader: InputReader = self._get_input_creator_from_config()( + self.io_context + ) + self.output_writer: OutputWriter = self._get_output_creator_from_config()( + self.io_context + ) + + # The current weights sequence number (version). May remain None for when + # not tracking weights versions. + self.weights_seq_no: Optional[int] = None + + @override(EnvRunner) + def make_env(self): + # Override this method, b/c it's abstract and must be overridden. + # However, we see no point in implementing it for the old API stack any longer + # (the RolloutWorker class will be deprecated soon). 
+ raise NotImplementedError + + @override(EnvRunner) + def assert_healthy(self): + is_healthy = self.policy_map and self.input_reader and self.output_writer + assert is_healthy, ( + f"RolloutWorker {self} (idx={self.worker_index}; " + f"num_workers={self.num_workers}) not healthy!" + ) + + @override(EnvRunner) + def sample(self, **kwargs) -> SampleBatchType: + """Returns a batch of experience sampled from this worker. + + This method must be implemented by subclasses. + + Returns: + A columnar batch of experiences (e.g., tensors) or a MultiAgentBatch. + + .. testcode:: + :skipif: True + + import gymnasium as gym + from ray.rllib.evaluation.rollout_worker import RolloutWorker + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy + worker = RolloutWorker( + env_creator=lambda _: gym.make("CartPole-v1"), + default_policy_class=PPOTF1Policy, + config=AlgorithmConfig(), + ) + print(worker.sample()) + + .. testoutput:: + + SampleBatch({"obs": [...], "action": [...], ...}) + """ + if self.config.fake_sampler and self.last_batch is not None: + return self.last_batch + elif self.input_reader is None: + raise ValueError( + "RolloutWorker has no `input_reader` object! " + "Cannot call `sample()`. You can try setting " + "`create_env_on_driver` to True." + ) + + if log_once("sample_start"): + logger.info( + "Generating sample batch of size {}".format( + self.total_rollout_fragment_length + ) + ) + + batches = [self.input_reader.next()] + steps_so_far = ( + batches[0].count + if self.config.count_steps_by == "env_steps" + else batches[0].agent_steps() + ) + + # In truncate_episodes mode, never pull more than 1 batch per env. + # This avoids over-running the target batch size. 
+ if ( + self.config.batch_mode == "truncate_episodes" + and not self.config.offline_sampling + ): + max_batches = self.config.num_envs_per_env_runner + else: + max_batches = float("inf") + while steps_so_far < self.total_rollout_fragment_length and ( + len(batches) < max_batches + ): + batch = self.input_reader.next() + steps_so_far += ( + batch.count + if self.config.count_steps_by == "env_steps" + else batch.agent_steps() + ) + batches.append(batch) + + batch = concat_samples(batches) + + self.callbacks.on_sample_end(worker=self, samples=batch) + + # Always do writes prior to compression for consistency and to allow + # for better compression inside the writer. + self.output_writer.write(batch) + + if log_once("sample_end"): + logger.info("Completed sample batch:\n\n{}\n".format(summarize(batch))) + + if self.config.compress_observations: + batch.compress(bulk=self.config.compress_observations == "bulk") + + if self.config.fake_sampler: + self.last_batch = batch + + return batch + + @override(EnvRunner) + def get_spaces(self) -> Dict[str, Tuple[Space, Space]]: + spaces = self.foreach_policy( + lambda p, pid: (pid, p.observation_space, p.action_space) + ) + spaces = {e[0]: (getattr(e[1], "original_space", e[1]), e[2]) for e in spaces} + # Try to add the actual env's obs/action spaces. + env_spaces = self.foreach_env( + lambda env: (env.observation_space, env.action_space) + ) + if env_spaces: + from ray.rllib.env import INPUT_ENV_SPACES + + spaces[INPUT_ENV_SPACES] = env_spaces[0] + return spaces + + @ray.method(num_returns=2) + def sample_with_count(self) -> Tuple[SampleBatchType, int]: + """Same as sample() but returns the count as a separate value. + + Returns: + A columnar batch of experiences (e.g., tensors) and the + size of the collected batch. + + .. 
testcode:: + :skipif: True + + import gymnasium as gym + from ray.rllib.evaluation.rollout_worker import RolloutWorker + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy + worker = RolloutWorker( + env_creator=lambda _: gym.make("CartPole-v1"), + default_policy_class=PPOTFPolicy) + print(worker.sample_with_count()) + + .. testoutput:: + + (SampleBatch({"obs": [...], "action": [...], ...}), 3) + """ + batch = self.sample() + return batch, batch.count + + def learn_on_batch(self, samples: SampleBatchType) -> Dict: + """Update policies based on the given batch. + + This is the equivalent to apply_gradients(compute_gradients(samples)), + but can be optimized to avoid pulling gradients into CPU memory. + + Args: + samples: The SampleBatch or MultiAgentBatch to learn on. + + Returns: + Dictionary of extra metadata from compute_gradients(). + + .. testcode:: + :skipif: True + + import gymnasium as gym + from ray.rllib.evaluation.rollout_worker import RolloutWorker + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy + worker = RolloutWorker( + env_creator=lambda _: gym.make("CartPole-v1"), + default_policy_class=PPOTF1Policy) + batch = worker.sample() + info = worker.learn_on_batch(samples) + """ + if log_once("learn_on_batch"): + logger.info( + "Training on concatenated sample batches:\n\n{}\n".format( + summarize(samples) + ) + ) + + info_out = {} + if isinstance(samples, MultiAgentBatch): + builders = {} + to_fetch = {} + for pid, batch in samples.policy_batches.items(): + if self.is_policy_to_train is not None and not self.is_policy_to_train( + pid, samples + ): + continue + # Decompress SampleBatch, in case some columns are compressed. 
+ batch.decompress_if_needed() + + policy = self.policy_map[pid] + tf_session = policy.get_session() + if tf_session and hasattr(policy, "_build_learn_on_batch"): + builders[pid] = _TFRunBuilder(tf_session, "learn_on_batch") + to_fetch[pid] = policy._build_learn_on_batch(builders[pid], batch) + else: + info_out[pid] = policy.learn_on_batch(batch) + + info_out.update({pid: builders[pid].get(v) for pid, v in to_fetch.items()}) + else: + if self.is_policy_to_train is None or self.is_policy_to_train( + DEFAULT_POLICY_ID, samples + ): + info_out.update( + { + DEFAULT_POLICY_ID: self.policy_map[ + DEFAULT_POLICY_ID + ].learn_on_batch(samples) + } + ) + if log_once("learn_out"): + logger.debug("Training out:\n\n{}\n".format(summarize(info_out))) + return info_out + + def sample_and_learn( + self, + expected_batch_size: int, + num_sgd_iter: int, + sgd_minibatch_size: str, + standardize_fields: List[str], + ) -> Tuple[dict, int]: + """Sample and batch and learn on it. + + This is typically used in combination with distributed allreduce. + + Args: + expected_batch_size: Expected number of samples to learn on. + num_sgd_iter: Number of SGD iterations. + sgd_minibatch_size: SGD minibatch size. + standardize_fields: List of sample fields to normalize. + + Returns: + A tuple consisting of a dictionary of extra metadata returned from + the policies' `learn_on_batch()` and the number of samples + learned on. 
+ """ + batch = self.sample() + assert batch.count == expected_batch_size, ( + "Batch size possibly out of sync between workers, expected:", + expected_batch_size, + "got:", + batch.count, + ) + logger.info( + "Executing distributed minibatch SGD " + "with epoch size {}, minibatch size {}".format( + batch.count, sgd_minibatch_size + ) + ) + info = do_minibatch_sgd( + batch, + self.policy_map, + self, + num_sgd_iter, + sgd_minibatch_size, + standardize_fields, + ) + return info, batch.count + + def compute_gradients( + self, + samples: SampleBatchType, + single_agent: bool = None, + ) -> Tuple[ModelGradients, dict]: + """Returns a gradient computed w.r.t the specified samples. + + Uses the Policy's/ies' compute_gradients method(s) to perform the + calculations. Skips policies that are not trainable as per + `self.is_policy_to_train()`. + + Args: + samples: The SampleBatch or MultiAgentBatch to compute gradients + for using this worker's trainable policies. + + Returns: + In the single-agent case, a tuple consisting of ModelGradients and + info dict of the worker's policy. + In the multi-agent case, a tuple consisting of a dict mapping + PolicyID to ModelGradients and a dict mapping PolicyID to extra + metadata info. + Note that the first return value (grads) can be applied as is to a + compatible worker using the worker's `apply_gradients()` method. + + .. 
testcode:: + :skipif: True + + import gymnasium as gym + from ray.rllib.evaluation.rollout_worker import RolloutWorker + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy + worker = RolloutWorker( + env_creator=lambda _: gym.make("CartPole-v1"), + default_policy_class=PPOTF1Policy) + batch = worker.sample() + grads, info = worker.compute_gradients(samples) + """ + if log_once("compute_gradients"): + logger.info("Compute gradients on:\n\n{}\n".format(summarize(samples))) + + if single_agent is True: + samples = convert_ma_batch_to_sample_batch(samples) + grad_out, info_out = self.policy_map[DEFAULT_POLICY_ID].compute_gradients( + samples + ) + info_out["batch_count"] = samples.count + return grad_out, info_out + + # Treat everything as is multi-agent. + samples = samples.as_multi_agent() + + # Calculate gradients for all policies. + grad_out, info_out = {}, {} + if self.config.framework_str == "tf": + for pid, batch in samples.policy_batches.items(): + if self.is_policy_to_train is not None and not self.is_policy_to_train( + pid, samples + ): + continue + policy = self.policy_map[pid] + builder = _TFRunBuilder(policy.get_session(), "compute_gradients") + grad_out[pid], info_out[pid] = policy._build_compute_gradients( + builder, batch + ) + grad_out = {k: builder.get(v) for k, v in grad_out.items()} + info_out = {k: builder.get(v) for k, v in info_out.items()} + else: + for pid, batch in samples.policy_batches.items(): + if self.is_policy_to_train is not None and not self.is_policy_to_train( + pid, samples + ): + continue + grad_out[pid], info_out[pid] = self.policy_map[pid].compute_gradients( + batch + ) + + info_out["batch_count"] = samples.count + if log_once("grad_out"): + logger.info("Compute grad info:\n\n{}\n".format(summarize(info_out))) + + return grad_out, info_out + + def apply_gradients( + self, + grads: Union[ModelGradients, Dict[PolicyID, ModelGradients]], + ) -> None: + """Applies the given gradients to this worker's models. 
+ + Uses the Policy's/ies' apply_gradients method(s) to perform the + operations. + + Args: + grads: Single ModelGradients (single-agent case) or a dict + mapping PolicyIDs to the respective model gradients + structs. + + .. testcode:: + :skipif: True + + import gymnasium as gym + from ray.rllib.evaluation.rollout_worker import RolloutWorker + from ray.rllib.algorithms.ppo.ppo_tf_policy import PPOTF1Policy + worker = RolloutWorker( + env_creator=lambda _: gym.make("CartPole-v1"), + default_policy_class=PPOTF1Policy) + samples = worker.sample() + grads, info = worker.compute_gradients(samples) + worker.apply_gradients(grads) + """ + if log_once("apply_gradients"): + logger.info("Apply gradients:\n\n{}\n".format(summarize(grads))) + # Grads is a dict (mapping PolicyIDs to ModelGradients). + # Multi-agent case. + if isinstance(grads, dict): + for pid, g in grads.items(): + if self.is_policy_to_train is None or self.is_policy_to_train( + pid, None + ): + self.policy_map[pid].apply_gradients(g) + # Grads is a ModelGradients type. Single-agent case. + elif self.is_policy_to_train is None or self.is_policy_to_train( + DEFAULT_POLICY_ID, None + ): + self.policy_map[DEFAULT_POLICY_ID].apply_gradients(grads) + + @override(EnvRunner) + def get_metrics(self) -> List[RolloutMetrics]: + """Returns the thus-far collected metrics from this worker's rollouts. + + Returns: + List of RolloutMetrics collected thus-far. + """ + # Get metrics from sampler (if any). + if self.sampler is not None: + out = self.sampler.get_metrics() + else: + out = [] + + return out + + def foreach_env(self, func: Callable[[EnvType], T]) -> List[T]: + """Calls the given function with each sub-environment as arg. + + Args: + func: The function to call for each underlying + sub-environment (as only arg). + + Returns: + The list of return values of all calls to `func([env])`. 
+ """ + + if self.async_env is None: + return [] + + envs = self.async_env.get_sub_environments() + # Empty list (not implemented): Call function directly on the + # BaseEnv. + if not envs: + return [func(self.async_env)] + # Call function on all underlying (vectorized) sub environments. + else: + return [func(e) for e in envs] + + def foreach_env_with_context( + self, func: Callable[[EnvType, EnvContext], T] + ) -> List[T]: + """Calls given function with each sub-env plus env_ctx as args. + + Args: + func: The function to call for each underlying + sub-environment and its EnvContext (as the args). + + Returns: + The list of return values of all calls to `func([env, ctx])`. + """ + + if self.async_env is None: + return [] + + envs = self.async_env.get_sub_environments() + # Empty list (not implemented): Call function directly on the + # BaseEnv. + if not envs: + return [func(self.async_env, self.env_context)] + # Call function on all underlying (vectorized) sub environments. + else: + ret = [] + for i, e in enumerate(envs): + ctx = self.env_context.copy_with_overrides(vector_index=i) + ret.append(func(e, ctx)) + return ret + + def get_policy(self, policy_id: PolicyID = DEFAULT_POLICY_ID) -> Optional[Policy]: + """Return policy for the specified id, or None. + + Args: + policy_id: ID of the policy to return. None for DEFAULT_POLICY_ID + (in the single agent case). + + Returns: + The policy under the given ID (or None if not found). 
+ """ + return self.policy_map.get(policy_id) + + def add_policy( + self, + policy_id: PolicyID, + policy_cls: Optional[Type[Policy]] = None, + policy: Optional[Policy] = None, + *, + observation_space: Optional[Space] = None, + action_space: Optional[Space] = None, + config: Optional[PartialAlgorithmConfigDict] = None, + policy_state: Optional[PolicyState] = None, + policy_mapping_fn=None, + policies_to_train: Optional[ + Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]] + ] = None, + module_spec: Optional[RLModuleSpec] = None, + ) -> Policy: + """Adds a new policy to this RolloutWorker. + + Args: + policy_id: ID of the policy to add. + policy_cls: The Policy class to use for constructing the new Policy. + Note: Only one of `policy_cls` or `policy` must be provided. + policy: The Policy instance to add to this algorithm. + Note: Only one of `policy_cls` or `policy` must be provided. + observation_space: The observation space of the policy to add. + action_space: The action space of the policy to add. + config: The config overrides for the policy to add. + policy_state: Optional state dict to apply to the new + policy instance, right after its construction. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. Note that already ongoing episodes will + not change their mapping but will use the old mapping till + the end of the episode. + policies_to_train: An optional collection of policy IDs to be + trained or a callable taking PolicyID and - optionally - + SampleBatchType and returning a bool (trainable or not?). + If None, will keep the existing setup in place. + Policies, whose IDs are not in the list (or for which the + callable returns False) will not be updated. + module_spec: In the new RLModule API we need to pass in the module_spec for + the new module that is supposed to be added. Knowing the policy spec is + not sufficient. + + Returns: + The newly added policy. 
+ + Raises: + ValueError: If both `policy_cls` AND `policy` are provided. + KeyError: If the given `policy_id` already exists in this worker's + PolicyMap. + """ + validate_module_id(policy_id, error=False) + + if module_spec is not None: + raise ValueError( + "If you pass in module_spec to the policy, the RLModule API needs " + "to be enabled." + ) + + if policy_id in self.policy_map: + raise KeyError( + f"Policy ID '{policy_id}' already exists in policy map! " + "Make sure you use a Policy ID that has not been taken yet." + " Policy IDs that are already in your policy map: " + f"{list(self.policy_map.keys())}" + ) + if (policy_cls is None) == (policy is None): + raise ValueError( + "Only one of `policy_cls` or `policy` must be provided to " + "RolloutWorker.add_policy()!" + ) + + if policy is None: + policy_dict_to_add, _ = self.config.get_multi_agent_setup( + policies={ + policy_id: PolicySpec( + policy_cls, observation_space, action_space, config + ) + }, + env=self.env, + spaces=self.spaces, + default_policy_class=self.default_policy_class, + ) + else: + policy_dict_to_add = { + policy_id: PolicySpec( + type(policy), + policy.observation_space, + policy.action_space, + policy.config, + ) + } + + self.policy_dict.update(policy_dict_to_add) + self._update_policy_map( + policy_dict=policy_dict_to_add, + policy=policy, + policy_states={policy_id: policy_state}, + single_agent_rl_module_spec=module_spec, + ) + + self.set_policy_mapping_fn(policy_mapping_fn) + if policies_to_train is not None: + self.set_is_policy_to_train(policies_to_train) + + return self.policy_map[policy_id] + + def remove_policy( + self, + *, + policy_id: PolicyID = DEFAULT_POLICY_ID, + policy_mapping_fn: Optional[Callable[[AgentID], PolicyID]] = None, + policies_to_train: Optional[ + Union[Collection[PolicyID], Callable[[PolicyID, SampleBatchType], bool]] + ] = None, + ) -> None: + """Removes a policy from this RolloutWorker. + + Args: + policy_id: ID of the policy to be removed. 
None for + DEFAULT_POLICY_ID. + policy_mapping_fn: An optional (updated) policy mapping function + to use from here on. Note that already ongoing episodes will + not change their mapping but will use the old mapping till + the end of the episode. + policies_to_train: An optional collection of policy IDs to be + trained or a callable taking PolicyID and - optionally - + SampleBatchType and returning a bool (trainable or not?). + If None, will keep the existing setup in place. + Policies, whose IDs are not in the list (or for which the + callable returns False) will not be updated. + """ + if policy_id not in self.policy_map: + raise ValueError(f"Policy ID '{policy_id}' not in policy map!") + del self.policy_map[policy_id] + del self.preprocessors[policy_id] + self.set_policy_mapping_fn(policy_mapping_fn) + if policies_to_train is not None: + self.set_is_policy_to_train(policies_to_train) + + def set_policy_mapping_fn( + self, + policy_mapping_fn: Optional[Callable[[AgentID, Any], PolicyID]] = None, + ) -> None: + """Sets `self.policy_mapping_fn` to a new callable (if provided). + + Args: + policy_mapping_fn: The new mapping function to use. If None, + will keep the existing mapping function in place. + """ + if policy_mapping_fn is not None: + self.policy_mapping_fn = policy_mapping_fn + if not callable(self.policy_mapping_fn): + raise ValueError("`policy_mapping_fn` must be a callable!") + + def set_is_policy_to_train( + self, + is_policy_to_train: Union[ + Collection[PolicyID], Callable[[PolicyID, Optional[SampleBatchType]], bool] + ], + ) -> None: + """Sets `self.is_policy_to_train()` to a new callable. + + Args: + is_policy_to_train: A collection of policy IDs to be + trained or a callable taking PolicyID and - optionally - + SampleBatchType and returning a bool (trainable or not?). + If None, will keep the existing setup in place. + Policies, whose IDs are not in the list (or for which the + callable returns False) will not be updated. 
+ """ + # If collection given, construct a simple default callable returning True + # if the PolicyID is found in the list/set of IDs. + if not callable(is_policy_to_train): + assert isinstance(is_policy_to_train, (list, set, tuple)), ( + "ERROR: `is_policy_to_train`must be a [list|set|tuple] or a " + "callable taking PolicyID and SampleBatch and returning " + "True|False (trainable or not?)." + ) + pols = set(is_policy_to_train) + + def is_policy_to_train(pid, batch=None): + return pid in pols + + self.is_policy_to_train = is_policy_to_train + + @PublicAPI(stability="alpha") + def get_policies_to_train( + self, batch: Optional[SampleBatchType] = None + ) -> Set[PolicyID]: + """Returns all policies-to-train, given an optional batch. + + Loops through all policies currently in `self.policy_map` and checks + the return value of `self.is_policy_to_train(pid, batch)`. + + Args: + batch: An optional SampleBatchType for the + `self.is_policy_to_train(pid, [batch]?)` check. + + Returns: + The set of currently trainable policy IDs, given the optional + `batch`. + """ + return { + pid + for pid in self.policy_map.keys() + if self.is_policy_to_train is None or self.is_policy_to_train(pid, batch) + } + + def for_policy( + self, + func: Callable[[Policy, Optional[Any]], T], + policy_id: Optional[PolicyID] = DEFAULT_POLICY_ID, + **kwargs, + ) -> T: + """Calls the given function with the specified policy as first arg. + + Args: + func: The function to call with the policy as first arg. + policy_id: The PolicyID of the policy to call the function with. + + Keyword Args: + kwargs: Additional kwargs to be passed to the call. + + Returns: + The return value of the function call. + """ + + return func(self.policy_map[policy_id], **kwargs) + + def foreach_policy( + self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs + ) -> List[T]: + """Calls the given function with each (policy, policy_id) tuple. 
+ + Args: + func: The function to call with each (policy, policy ID) tuple. + + Keyword Args: + kwargs: Additional kwargs to be passed to the call. + + Returns: + The list of return values of all calls to + `func([policy, pid, **kwargs])`. + """ + return [func(policy, pid, **kwargs) for pid, policy in self.policy_map.items()] + + def foreach_policy_to_train( + self, func: Callable[[Policy, PolicyID, Optional[Any]], T], **kwargs + ) -> List[T]: + """ + Calls the given function with each (policy, policy_id) tuple. + + Only those policies/IDs will be called on, for which + `self.is_policy_to_train()` returns True. + + Args: + func: The function to call with each (policy, policy ID) tuple, + for only those policies that `self.is_policy_to_train` + returns True. + + Keyword Args: + kwargs: Additional kwargs to be passed to the call. + + Returns: + The list of return values of all calls to + `func([policy, pid, **kwargs])`. + """ + return [ + # Make sure to only iterate over keys() and not items(). Iterating over + # items will access policy_map elements even for pids that we do not need, + # i.e. those that are not in policy_to_train. Access to policy_map elements + # can cause disk access for policies that were offloaded to disk. Since + # these policies will be skipped in the for-loop accessing them is + # unnecessary, making subsequent disk access unnecessary. + func(self.policy_map[pid], pid, **kwargs) + for pid in self.policy_map.keys() + if self.is_policy_to_train is None or self.is_policy_to_train(pid, None) + ] + + def sync_filters(self, new_filters: dict) -> None: + """Changes self's filter to given and rebases any accumulated delta. + + Args: + new_filters: Filters with new state to update local copy. + """ + assert all(k in new_filters for k in self.filters) + for k in self.filters: + self.filters[k].sync(new_filters[k]) + + def get_filters(self, flush_after: bool = False) -> Dict: + """Returns a snapshot of filters. 
+ + Args: + flush_after: Clears the filter buffer state. + + Returns: + Dict for serializable filters + """ + return_filters = {} + for k, f in self.filters.items(): + return_filters[k] = f.as_serializable() + if flush_after: + f.reset_buffer() + return return_filters + + def get_state(self) -> dict: + filters = self.get_filters(flush_after=True) + policy_states = {} + for pid in self.policy_map.keys(): + # If required by the user, only capture policies that are actually + # trainable. Otherwise, capture all policies (for saving to disk). + if ( + not self.config.checkpoint_trainable_policies_only + or self.is_policy_to_train is None + or self.is_policy_to_train(pid) + ): + policy_states[pid] = self.policy_map[pid].get_state() + + return { + # List all known policy IDs here for convenience. When an Algorithm gets + # restored from a checkpoint, it will not have access to the list of + # possible IDs as each policy is stored in its own sub-dir + # (see "policy_states"). + "policy_ids": list(self.policy_map.keys()), + # Note that this field will not be stored in the algorithm checkpoint's + # state file, but each policy will get its own state file generated in + # a sub-dir within the algo's checkpoint dir. + "policy_states": policy_states, + # Also store current mapping fn and which policies to train. + "policy_mapping_fn": self.policy_mapping_fn, + "is_policy_to_train": self.is_policy_to_train, + # TODO: Filters will be replaced by connectors. + "filters": filters, + } + + def set_state(self, state: dict) -> None: + # Backward compatibility (old checkpoints' states would have the local + # worker state as a bytes object, not a dict). + if isinstance(state, bytes): + state = pickle.loads(state) + + # TODO: Once filters are handled by connectors, get rid of the "filters" + # key in `state` entirely (will be part of the policies then). 
+ self.sync_filters(state["filters"]) + + # Support older checkpoint versions (< 1.0), in which the policy_map + # was stored under the "state" key, not "policy_states". + policy_states = ( + state["policy_states"] if "policy_states" in state else state["state"] + ) + for pid, policy_state in policy_states.items(): + # If - for some reason - we have an invalid PolicyID in the state, + # this might be from an older checkpoint (pre v1.0). Just warn here. + validate_module_id(pid, error=False) + + if pid not in self.policy_map: + spec = policy_state.get("policy_spec", None) + if spec is None: + logger.warning( + f"PolicyID '{pid}' was probably added on-the-fly (not" + " part of the static `multagent.policies` config) and" + " no PolicySpec objects found in the pickled policy " + f"state. Will not add `{pid}`, but ignore it for now." + ) + else: + policy_spec = ( + PolicySpec.deserialize(spec) if isinstance(spec, dict) else spec + ) + self.add_policy( + policy_id=pid, + policy_cls=policy_spec.policy_class, + observation_space=policy_spec.observation_space, + action_space=policy_spec.action_space, + config=policy_spec.config, + ) + if pid in self.policy_map: + self.policy_map[pid].set_state(policy_state) + + # Also restore mapping fn and which policies to train. + if "policy_mapping_fn" in state: + self.set_policy_mapping_fn(state["policy_mapping_fn"]) + if state.get("is_policy_to_train") is not None: + self.set_is_policy_to_train(state["is_policy_to_train"]) + + def get_weights( + self, + policies: Optional[Collection[PolicyID]] = None, + inference_only: bool = False, + ) -> Dict[PolicyID, ModelWeights]: + """Returns each policies' model weights of this worker. + + Args: + policies: List of PolicyIDs to get the weights from. + Use None for all policies. + inference_only: This argument is only added for interface + consistency with the new api stack. + + Returns: + Dict mapping PolicyIDs to ModelWeights. + + .. 
testcode:: + :skipif: True + + from ray.rllib.evaluation.rollout_worker import RolloutWorker + # Create a RolloutWorker. + worker = ... + weights = worker.get_weights() + print(weights) + + .. testoutput:: + + {"default_policy": {"layer1": array(...), "layer2": ...}} + """ + if policies is None: + policies = list(self.policy_map.keys()) + policies = force_list(policies) + + return { + # Make sure to only iterate over keys() and not items(). Iterating over + # items will access policy_map elements even for pids that we do not need, + # i.e. those that are not in policies. Access to policy_map elements can + # cause disk access for policies that were offloaded to disk. Since these + # policies will be skipped in the for-loop accessing them is unnecessary, + # making subsequent disk access unnecessary. + pid: self.policy_map[pid].get_weights() + for pid in self.policy_map.keys() + if pid in policies + } + + def set_weights( + self, + weights: Dict[PolicyID, ModelWeights], + global_vars: Optional[Dict] = None, + weights_seq_no: Optional[int] = None, + ) -> None: + """Sets each policies' model weights of this worker. + + Args: + weights: Dict mapping PolicyIDs to the new weights to be used. + global_vars: An optional global vars dict to set this + worker to. If None, do not update the global_vars. + weights_seq_no: If needed, a sequence number for the weights version + can be passed into this method. If not None, will store this seq no + (in self.weights_seq_no) and in future calls - if the seq no did not + change wrt. the last call - will ignore the call to save on performance. + + .. testcode:: + :skipif: True + + from ray.rllib.evaluation.rollout_worker import RolloutWorker + # Create a RolloutWorker. + worker = ... + weights = worker.get_weights() + # Set `global_vars` (timestep) as well. + worker.set_weights(weights, {"timestep": 42}) + """ + # Only update our weights, if no seq no given OR given seq no is different + # from ours. 
+ if weights_seq_no is None or weights_seq_no != self.weights_seq_no: + # If per-policy weights are object refs, `ray.get()` them first. + if weights and isinstance(next(iter(weights.values())), ObjectRef): + actual_weights = ray.get(list(weights.values())) + weights = { + pid: actual_weights[i] for i, pid in enumerate(weights.keys()) + } + + for pid, w in weights.items(): + if pid in self.policy_map: + self.policy_map[pid].set_weights(w) + elif log_once("set_weights_on_non_existent_policy"): + logger.warning( + "`RolloutWorker.set_weights()` used with weights from " + f"policyID={pid}, but this policy cannot be found on this " + f"worker! Skipping ..." + ) + + self.weights_seq_no = weights_seq_no + + if global_vars: + self.set_global_vars(global_vars) + + def get_global_vars(self) -> dict: + """Returns the current `self.global_vars` dict of this RolloutWorker. + + Returns: + The current `self.global_vars` dict of this RolloutWorker. + + .. testcode:: + :skipif: True + + from ray.rllib.evaluation.rollout_worker import RolloutWorker + # Create a RolloutWorker. + worker = ... + global_vars = worker.get_global_vars() + print(global_vars) + + .. testoutput:: + + {"timestep": 424242} + """ + return self.global_vars + + def set_global_vars( + self, + global_vars: dict, + policy_ids: Optional[List[PolicyID]] = None, + ) -> None: + """Updates this worker's and all its policies' global vars. + + Updates are done using the dict's update method. + + Args: + global_vars: The global_vars dict to update the `self.global_vars` dict + from. + policy_ids: Optional list of Policy IDs to update. If None, will update all + policies on the to-be-updated workers. + + .. testcode:: + :skipif: True + + worker = ... + global_vars = worker.set_global_vars( + ... {"timestep": 4242}) + """ + # Handle per-policy values. 
+ global_vars_copy = global_vars.copy() + gradient_updates_per_policy = global_vars_copy.pop( + "num_grad_updates_per_policy", {} + ) + self.global_vars["num_grad_updates_per_policy"].update( + gradient_updates_per_policy + ) + # Only update explicitly provided policies or those that that are being + # trained, in order to avoid superfluous access of policies, which might have + # been offloaded to the object store. + # Important b/c global vars are constantly being updated. + for pid in policy_ids if policy_ids is not None else self.policy_map.keys(): + if self.is_policy_to_train is None or self.is_policy_to_train(pid, None): + self.policy_map[pid].on_global_var_update( + dict( + global_vars_copy, + # If count is None, Policy won't update the counter. + **{"num_grad_updates": gradient_updates_per_policy.get(pid)}, + ) + ) + + # Update all other global vars. + self.global_vars.update(global_vars_copy) + + @override(EnvRunner) + def stop(self) -> None: + """Releases all resources used by this RolloutWorker.""" + + # If we have an env -> Release its resources. + if self.env is not None: + self.async_env.stop() + + # Close all policies' sessions (if tf static graph). + for policy in self.policy_map.cache.values(): + sess = policy.get_session() + # Closes the tf session, if any. 
+ if sess is not None: + sess.close() + + def lock(self) -> None: + """Locks this RolloutWorker via its own threading.Lock.""" + self._lock.acquire() + + def unlock(self) -> None: + """Unlocks this RolloutWorker via its own threading.Lock.""" + self._lock.release() + + def setup_torch_data_parallel( + self, url: str, world_rank: int, world_size: int, backend: str + ) -> None: + """Join a torch process group for distributed SGD.""" + + logger.info( + "Joining process group, url={}, world_rank={}, " + "world_size={}, backend={}".format(url, world_rank, world_size, backend) + ) + torch.distributed.init_process_group( + backend=backend, init_method=url, rank=world_rank, world_size=world_size + ) + + for pid, policy in self.policy_map.items(): + if not isinstance(policy, (TorchPolicy, TorchPolicyV2)): + raise ValueError( + "This policy does not support torch distributed", policy + ) + policy.distributed_world_size = world_size + + def creation_args(self) -> dict: + """Returns the kwargs dict used to create this worker.""" + return self._original_kwargs + + def get_host(self) -> str: + """Returns the hostname of the process running this evaluator.""" + return platform.node() + + def get_node_ip(self) -> str: + """Returns the IP address of the node that this worker runs on.""" + return ray.util.get_node_ip_address() + + def find_free_port(self) -> int: + """Finds a free port on the node that this worker runs on.""" + from ray.air._internal.util import find_free_port + + return find_free_port() + + def _update_policy_map( + self, + *, + policy_dict: MultiAgentPolicyConfigDict, + policy: Optional[Policy] = None, + policy_states: Optional[Dict[PolicyID, PolicyState]] = None, + single_agent_rl_module_spec: Optional[RLModuleSpec] = None, + ) -> None: + """Updates the policy map (and other stuff) on this worker. + + It performs the following: + 1. It updates the observation preprocessors and updates the policy_specs + with the postprocessed observation_spaces. + 2. 
It updates the policy_specs with the complete algorithm_config (merged + with the policy_spec's config). + 3. If needed it will update the self.multi_rl_module_spec on this worker + 3. It updates the policy map with the new policies + 4. It updates the filter dict + 5. It calls the on_create_policy() hook of the callbacks on the newly added + policies. + + Args: + policy_dict: The policy dict to update the policy map with. + policy: The policy to update the policy map with. + policy_states: The policy states to update the policy map with. + single_agent_rl_module_spec: The RLModuleSpec to add to the + MultiRLModuleSpec. If None, the config's + `get_default_rl_module_spec` method's output will be used to create + the policy with. + """ + + # Update the input policy dict with the postprocessed observation spaces and + # merge configs. Also updates the preprocessor dict. + updated_policy_dict = self._get_complete_policy_specs_dict(policy_dict) + + # Builds the self.policy_map dict + self._build_policy_map( + policy_dict=updated_policy_dict, + policy=policy, + policy_states=policy_states, + ) + + # Initialize the filter dict + self._update_filter_dict(updated_policy_dict) + + # Call callback policy init hooks (only if the added policy did not exist + # before). + if policy is None: + self._call_callbacks_on_create_policy() + + if self.worker_index == 0: + logger.info(f"Built policy map: {self.policy_map}") + logger.info(f"Built preprocessor map: {self.preprocessors}") + + def _get_complete_policy_specs_dict( + self, policy_dict: MultiAgentPolicyConfigDict + ) -> MultiAgentPolicyConfigDict: + """Processes the policy dict and creates a new copy with the processed attrs. + + This processes the observation_space and prepares them for passing to rl module + construction. It also merges the policy configs with the algorithm config. + During this processing, we will also construct the preprocessors dict. 
+ """ + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + updated_policy_dict = copy.deepcopy(policy_dict) + # If our preprocessors dict does not exist yet, create it here. + self.preprocessors = self.preprocessors or {} + # Loop through given policy-dict and add each entry to our map. + for name, policy_spec in sorted(updated_policy_dict.items()): + logger.debug("Creating policy for {}".format(name)) + + # Policy brings its own complete AlgorithmConfig -> Use it for this policy. + if isinstance(policy_spec.config, AlgorithmConfig): + merged_conf = policy_spec.config + else: + # Update the general config with the specific config + # for this particular policy. + merged_conf: "AlgorithmConfig" = self.config.copy(copy_frozen=False) + merged_conf.update_from_dict(policy_spec.config or {}) + + # Update num_workers and worker_index. + merged_conf.worker_index = self.worker_index + + # Preprocessors. + obs_space = policy_spec.observation_space + # Initialize preprocessor for this policy to None. + self.preprocessors[name] = None + if self.preprocessing_enabled: + # Policies should deal with preprocessed (automatically flattened) + # observations if preprocessing is enabled. + preprocessor = ModelCatalog.get_preprocessor_for_space( + obs_space, + merged_conf.model, + include_multi_binary=False, + ) + # Original observation space should be accessible at + # obs_space.original_space after this step. 
+ if preprocessor is not None: + obs_space = preprocessor.observation_space + + policy_spec.config = merged_conf + policy_spec.observation_space = obs_space + + return updated_policy_dict + + def _update_policy_dict_with_multi_rl_module( + self, policy_dict: MultiAgentPolicyConfigDict + ) -> MultiAgentPolicyConfigDict: + for name, policy_spec in policy_dict.items(): + policy_spec.config["__multi_rl_module_spec"] = self.multi_rl_module_spec + return policy_dict + + def _build_policy_map( + self, + *, + policy_dict: MultiAgentPolicyConfigDict, + policy: Optional[Policy] = None, + policy_states: Optional[Dict[PolicyID, PolicyState]] = None, + ) -> None: + """Adds the given policy_dict to `self.policy_map`. + + Args: + policy_dict: The MultiAgentPolicyConfigDict to be added to this + worker's PolicyMap. + policy: If the policy to add already exists, user can provide it here. + policy_states: Optional dict from PolicyIDs to PolicyStates to + restore the states of the policies being built. + """ + + # If our policy_map does not exist yet, create it here. + self.policy_map = self.policy_map or PolicyMap( + capacity=self.config.policy_map_capacity, + policy_states_are_swappable=self.config.policy_states_are_swappable, + ) + + # Loop through given policy-dict and add each entry to our map. + for name, policy_spec in sorted(policy_dict.items()): + # Create the actual policy object. + if policy is None: + new_policy = create_policy_for_framework( + policy_id=name, + policy_class=get_tf_eager_cls_if_necessary( + policy_spec.policy_class, policy_spec.config + ), + merged_config=policy_spec.config, + observation_space=policy_spec.observation_space, + action_space=policy_spec.action_space, + worker_index=self.worker_index, + seed=self.seed, + ) + else: + new_policy = policy + + self.policy_map[name] = new_policy + + restore_states = (policy_states or {}).get(name, None) + # Set the state of the newly created policy before syncing filters, etc. 
+ if restore_states: + new_policy.set_state(restore_states) + + def _update_filter_dict(self, policy_dict: MultiAgentPolicyConfigDict) -> None: + """Updates the filter dict for the given policy_dict.""" + + for name, policy_spec in sorted(policy_dict.items()): + new_policy = self.policy_map[name] + # Note(jungong) : We should only create new connectors for the + # policy iff we are creating a new policy from scratch. i.e, + # we should NOT create new connectors when we already have the + # policy object created before this function call or have the + # restoring states from the caller. + # Also note that we cannot just check the existence of connectors + # to decide whether we should create connectors because we may be + # restoring a policy that has 0 connectors configured. + if ( + new_policy.agent_connectors is None + or new_policy.action_connectors is None + ): + # TODO(jungong) : revisit this. It will be nicer to create + # connectors as the last step of Policy.__init__(). + create_connectors_for_policy(new_policy, policy_spec.config) + maybe_get_filters_for_syncing(self, name) + + def _call_callbacks_on_create_policy(self): + """Calls the on_create_policy callback for each policy in the policy map.""" + for name, policy in self.policy_map.items(): + self.callbacks.on_create_policy(policy_id=name, policy=policy) + + def _get_input_creator_from_config(self): + def valid_module(class_path): + if ( + isinstance(class_path, str) + and not os.path.isfile(class_path) + and "." in class_path + ): + module_path, class_name = class_path.rsplit(".", 1) + try: + spec = importlib.util.find_spec(module_path) + if spec is not None: + return True + except (ModuleNotFoundError, ValueError): + print( + f"module {module_path} not found while trying to get " + f"input {class_path}" + ) + return False + + # A callable returning an InputReader object to use. 
+ if isinstance(self.config.input_, FunctionType): + return self.config.input_ + # Use RLlib's Sampler classes (SyncSampler). + elif self.config.input_ == "sampler": + return lambda ioctx: ioctx.default_sampler_input() + # Ray Dataset input -> Use `config.input_config` to construct DatasetReader. + elif self.config.input_ == "dataset": + assert self._ds_shards is not None + # Input dataset shards should have already been prepared. + # We just need to take the proper shard here. + return lambda ioctx: DatasetReader( + self._ds_shards[self.worker_index], ioctx + ) + # Dict: Mix of different input methods with different ratios. + elif isinstance(self.config.input_, dict): + return lambda ioctx: ShuffledInput( + MixedInput(self.config.input_, ioctx), self.config.shuffle_buffer_size + ) + # A pre-registered input descriptor (str). + elif isinstance(self.config.input_, str) and registry_contains_input( + self.config.input_ + ): + return registry_get_input(self.config.input_) + # D4RL input. + elif "d4rl" in self.config.input_: + env_name = self.config.input_.split(".")[-1] + return lambda ioctx: D4RLReader(env_name, ioctx) + # Valid python module (class path) -> Create using `from_config`. + elif valid_module(self.config.input_): + return lambda ioctx: ShuffledInput( + from_config(self.config.input_, ioctx=ioctx) + ) + # JSON file or list of JSON files -> Use JsonReader (shuffled). 
+ else: + return lambda ioctx: ShuffledInput( + JsonReader(self.config.input_, ioctx), self.config.shuffle_buffer_size + ) + + def _get_output_creator_from_config(self): + if isinstance(self.config.output, FunctionType): + return self.config.output + elif self.config.output is None: + return lambda ioctx: NoopOutput() + elif self.config.output == "dataset": + return lambda ioctx: DatasetWriter( + ioctx, compress_columns=self.config.output_compress_columns + ) + elif self.config.output == "logdir": + return lambda ioctx: JsonWriter( + ioctx.log_dir, + ioctx, + max_file_size=self.config.output_max_file_size, + compress_columns=self.config.output_compress_columns, + ) + else: + return lambda ioctx: JsonWriter( + self.config.output, + ioctx, + max_file_size=self.config.output_max_file_size, + compress_columns=self.config.output_compress_columns, + ) + + def _get_make_sub_env_fn( + self, env_creator, env_context, validate_env, env_wrapper, seed + ): + def _make_sub_env_local(vector_index): + # Used to created additional environments during environment + # vectorization. + + # Create the env context (config dict + meta-data) for + # this particular sub-env within the vectorized one. + env_ctx = env_context.copy_with_overrides(vector_index=vector_index) + # Create the sub-env. + env = env_creator(env_ctx) + # Custom validation function given by user. + if validate_env is not None: + validate_env(env, env_ctx) + # Use our wrapper, defined above. + env = env_wrapper(env) + + # Make sure a deterministic random seed is set on + # all the sub-environments if specified. 
+ _update_env_seed_if_necessary( + env, seed, env_context.worker_index, vector_index + ) + return env + + if not env_context.remote: + + def _make_sub_env_remote(vector_index): + sub_env = _make_sub_env_local(vector_index) + self.callbacks.on_sub_environment_created( + worker=self, + sub_environment=sub_env, + env_context=env_context.copy_with_overrides( + worker_index=env_context.worker_index, + vector_index=vector_index, + remote=False, + ), + ) + return sub_env + + return _make_sub_env_remote + + else: + return _make_sub_env_local diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sample_batch_builder.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sample_batch_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c748fe3bce75308b9f99e3af4fd1644c4cf525 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sample_batch_builder.py @@ -0,0 +1,264 @@ +import collections +import logging +import numpy as np +from typing import List, Any, Dict, TYPE_CHECKING + +from ray.rllib.env.base_env import _DUMMY_AGENT_ID +from ray.rllib.policy.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.debug import summarize +from ray.rllib.utils.deprecation import deprecation_warning +from ray.rllib.utils.typing import PolicyID, AgentID +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.rllib.callbacks.callbacks import RLlibCallback + +logger = logging.getLogger(__name__) + + +def _to_float_array(v: List[Any]) -> np.ndarray: + arr = np.array(v) + if arr.dtype == np.float64: + return arr.astype(np.float32) # save some memory + return arr + + +@OldAPIStack +class SampleBatchBuilder: + """Util to build a SampleBatch incrementally. + + For efficiency, SampleBatches hold values in column form (as arrays). + However, it is useful to add data one row (dict) at a time. 
+ """ + + _next_unroll_id = 0 # disambiguates unrolls within a single episode + + def __init__(self): + self.buffers: Dict[str, List] = collections.defaultdict(list) + self.count = 0 + + def add_values(self, **values: Any) -> None: + """Add the given dictionary (row) of values to this batch.""" + + for k, v in values.items(): + self.buffers[k].append(v) + self.count += 1 + + def add_batch(self, batch: SampleBatch) -> None: + """Add the given batch of values to this batch.""" + + for k, column in batch.items(): + self.buffers[k].extend(column) + self.count += batch.count + + def build_and_reset(self) -> SampleBatch: + """Returns a sample batch including all previously added values.""" + + batch = SampleBatch({k: _to_float_array(v) for k, v in self.buffers.items()}) + if SampleBatch.UNROLL_ID not in batch: + batch[SampleBatch.UNROLL_ID] = np.repeat( + SampleBatchBuilder._next_unroll_id, batch.count + ) + SampleBatchBuilder._next_unroll_id += 1 + self.buffers.clear() + self.count = 0 + return batch + + +@OldAPIStack +class MultiAgentSampleBatchBuilder: + """Util to build SampleBatches for each policy in a multi-agent env. + + Input data is per-agent, while output data is per-policy. There is an M:N + mapping between agents and policies. We retain one local batch builder + per agent. When an agent is done, then its local batch is appended into the + corresponding policy batch for the agent's policy. + """ + + def __init__( + self, + policy_map: Dict[PolicyID, Policy], + clip_rewards: bool, + callbacks: "RLlibCallback", + ): + """Initialize a MultiAgentSampleBatchBuilder. + + Args: + policy_map (Dict[str,Policy]): Maps policy ids to policy instances. + clip_rewards (Union[bool,float]): Whether to clip rewards before + postprocessing (at +/-1.0) or the actual value to +/- clip. + callbacks: RLlib callbacks. 
+ """ + if log_once("MultiAgentSampleBatchBuilder"): + deprecation_warning(old="MultiAgentSampleBatchBuilder", error=False) + self.policy_map = policy_map + self.clip_rewards = clip_rewards + # Build the Policies' SampleBatchBuilders. + self.policy_builders = {k: SampleBatchBuilder() for k in policy_map.keys()} + # Whenever we observe a new agent, add a new SampleBatchBuilder for + # this agent. + self.agent_builders = {} + # Internal agent-to-policy map. + self.agent_to_policy = {} + self.callbacks = callbacks + # Number of "inference" steps taken in the environment. + # Regardless of the number of agents involved in each of these steps. + self.count = 0 + + def total(self) -> int: + """Returns the total number of steps taken in the env (all agents). + + Returns: + int: The number of steps taken in total in the environment over all + agents. + """ + + return sum(a.count for a in self.agent_builders.values()) + + def has_pending_agent_data(self) -> bool: + """Returns whether there is pending unprocessed data. + + Returns: + bool: True if there is at least one per-agent builder (with data + in it). + """ + + return len(self.agent_builders) > 0 + + def add_values(self, agent_id: AgentID, policy_id: AgentID, **values: Any) -> None: + """Add the given dictionary (row) of values to this batch. + + Args: + agent_id: Unique id for the agent we are adding values for. + policy_id: Unique id for policy controlling the agent. + values: Row of values to add for this agent. + """ + + if agent_id not in self.agent_builders: + self.agent_builders[agent_id] = SampleBatchBuilder() + self.agent_to_policy[agent_id] = policy_id + + # Include the current agent id for multi-agent algorithms. + if agent_id != _DUMMY_AGENT_ID: + values["agent_id"] = agent_id + + self.agent_builders[agent_id].add_values(**values) + + def postprocess_batch_so_far(self, episode=None) -> None: + """Apply policy postprocessors to any unprocessed rows. 
+ + This pushes the postprocessed per-agent batches onto the per-policy + builders, clearing per-agent state. + + Args: + episode (Optional[Episode]): The Episode object that + holds this MultiAgentBatchBuilder object. + """ + + # Materialize the batches so far. + pre_batches = {} + for agent_id, builder in self.agent_builders.items(): + pre_batches[agent_id] = ( + self.policy_map[self.agent_to_policy[agent_id]], + builder.build_and_reset(), + ) + + # Apply postprocessor. + post_batches = {} + if self.clip_rewards is True: + for _, (_, pre_batch) in pre_batches.items(): + pre_batch["rewards"] = np.sign(pre_batch["rewards"]) + elif self.clip_rewards: + for _, (_, pre_batch) in pre_batches.items(): + pre_batch["rewards"] = np.clip( + pre_batch["rewards"], + a_min=-self.clip_rewards, + a_max=self.clip_rewards, + ) + for agent_id, (_, pre_batch) in pre_batches.items(): + other_batches = pre_batches.copy() + del other_batches[agent_id] + policy = self.policy_map[self.agent_to_policy[agent_id]] + if ( + not pre_batch.is_single_trajectory() + or len(set(pre_batch[SampleBatch.EPS_ID])) > 1 + ): + raise ValueError( + "Batches sent to postprocessing must only contain steps " + "from a single trajectory.", + pre_batch, + ) + # Call the Policy's Exploration's postprocess method. 
+ post_batches[agent_id] = pre_batch + if getattr(policy, "exploration", None) is not None: + policy.exploration.postprocess_trajectory( + policy, post_batches[agent_id], policy.get_session() + ) + post_batches[agent_id] = policy.postprocess_trajectory( + post_batches[agent_id], other_batches, episode + ) + + if log_once("after_post"): + logger.info( + "Trajectory fragment after postprocess_trajectory():\n\n{}\n".format( + summarize(post_batches) + ) + ) + + # Append into policy batches and reset + from ray.rllib.evaluation.rollout_worker import get_global_worker + + for agent_id, post_batch in sorted(post_batches.items()): + self.callbacks.on_postprocess_trajectory( + worker=get_global_worker(), + episode=episode, + agent_id=agent_id, + policy_id=self.agent_to_policy[agent_id], + policies=self.policy_map, + postprocessed_batch=post_batch, + original_batches=pre_batches, + ) + self.policy_builders[self.agent_to_policy[agent_id]].add_batch(post_batch) + + self.agent_builders.clear() + self.agent_to_policy.clear() + + def check_missing_dones(self) -> None: + for agent_id, builder in self.agent_builders.items(): + if not builder.buffers.is_terminated_or_truncated(): + raise ValueError( + "The environment terminated for all agents, but we still " + "don't have a last observation for " + "agent {} (policy {}). ".format( + agent_id, self.agent_to_policy[agent_id] + ) + + "Please ensure that you include the last observations " + "of all live agents when setting '__all__' terminated|truncated " + "to True. " + ) + + def build_and_reset(self, episode=None) -> MultiAgentBatch: + """Returns the accumulated sample batches for each policy. + + Any unprocessed rows will be first postprocessed with a policy + postprocessor. The internal state of this builder will be reset. + + Args: + episode (Optional[Episode]): The Episode object that + holds this MultiAgentBatchBuilder object or None. + + Returns: + MultiAgentBatch: Returns the accumulated sample batches for each + policy. 
+ """ + + self.postprocess_batch_so_far(episode) + policy_batches = {} + for policy_id, builder in self.policy_builders.items(): + if builder.count > 0: + policy_batches[policy_id] = builder.build_and_reset() + old_count = self.count + self.count = 0 + return MultiAgentBatch.wrap_as_needed(policy_batches, old_count) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..c6b4ce937e6babc33d2de1282ed631e373fc68f7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/sampler.py @@ -0,0 +1,253 @@ +import logging +import queue +from abc import ABCMeta, abstractmethod +from collections import defaultdict, namedtuple +from typing import ( + TYPE_CHECKING, + Any, + List, + Optional, + Type, + Union, +) + +from ray.rllib.env.base_env import BaseEnv, convert_to_base_env +from ray.rllib.evaluation.collectors.sample_collector import SampleCollector +from ray.rllib.evaluation.collectors.simple_list_collector import SimpleListCollector +from ray.rllib.evaluation.env_runner_v2 import EnvRunnerV2, _PerfStats +from ray.rllib.evaluation.metrics import RolloutMetrics +from ray.rllib.offline import InputReader +from ray.rllib.policy.sample_batch import concat_samples +from ray.rllib.utils.annotations import OldAPIStack, override +from ray.rllib.utils.deprecation import deprecation_warning, DEPRECATED_VALUE +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import SampleBatchType +from ray.util.debug import log_once + +if TYPE_CHECKING: + from ray.rllib.callbacks.callbacks import RLlibCallback + from ray.rllib.evaluation.observation_function import ObservationFunction + from ray.rllib.evaluation.rollout_worker import RolloutWorker + +tf1, tf, _ = try_import_tf() +logger = logging.getLogger(__name__) + +_PolicyEvalData = namedtuple( + "_PolicyEvalData", + ["env_id", 
"agent_id", "obs", "info", "rnn_state", "prev_action", "prev_reward"], +) + +# A batch of RNN states with dimensions [state_index, batch, state_object]. +StateBatch = List[List[Any]] + + +class _NewEpisodeDefaultDict(defaultdict): + def __missing__(self, env_id): + if self.default_factory is None: + raise KeyError(env_id) + else: + ret = self[env_id] = self.default_factory(env_id) + return ret + + +@OldAPIStack +class SamplerInput(InputReader, metaclass=ABCMeta): + """Reads input experiences from an existing sampler.""" + + @override(InputReader) + def next(self) -> SampleBatchType: + batches = [self.get_data()] + batches.extend(self.get_extra_batches()) + if len(batches) == 0: + raise RuntimeError("No data available from sampler.") + return concat_samples(batches) + + @abstractmethod + def get_data(self) -> SampleBatchType: + """Called by `self.next()` to return the next batch of data. + + Override this in child classes. + + Returns: + The next batch of data. + """ + raise NotImplementedError + + @abstractmethod + def get_metrics(self) -> List[RolloutMetrics]: + """Returns list of episode metrics since the last call to this method. + + The list will contain one RolloutMetrics object per completed episode. + + Returns: + List of RolloutMetrics objects, one per completed episode since + the last call to this method. + """ + raise NotImplementedError + + @abstractmethod + def get_extra_batches(self) -> List[SampleBatchType]: + """Returns list of extra batches since the last call to this method. + + The list will contain all SampleBatches or + MultiAgentBatches that the user has provided thus-far. Users can + add these "extra batches" to an episode by calling the episode's + `add_extra_batch([SampleBatchType])` method. This can be done from + inside an overridden `Policy.compute_actions_from_input_dict(..., + episodes)` or from a custom callback's `on_episode_[start|step|end]()` + methods. 
+ + Returns: + List of SamplesBatches or MultiAgentBatches provided thus-far by + the user since the last call to this method. + """ + raise NotImplementedError + + +@OldAPIStack +class SyncSampler(SamplerInput): + """Sync SamplerInput that collects experiences when `get_data()` is called.""" + + def __init__( + self, + *, + worker: "RolloutWorker", + env: BaseEnv, + clip_rewards: Union[bool, float], + rollout_fragment_length: int, + count_steps_by: str = "env_steps", + callbacks: "RLlibCallback", + multiple_episodes_in_batch: bool = False, + normalize_actions: bool = True, + clip_actions: bool = False, + observation_fn: Optional["ObservationFunction"] = None, + sample_collector_class: Optional[Type[SampleCollector]] = None, + render: bool = False, + # Obsolete. + policies=None, + policy_mapping_fn=None, + preprocessors=None, + obs_filters=None, + tf_sess=None, + horizon=DEPRECATED_VALUE, + soft_horizon=DEPRECATED_VALUE, + no_done_at_end=DEPRECATED_VALUE, + ): + """Initializes a SyncSampler instance. + + Args: + worker: The RolloutWorker that will use this Sampler for sampling. + env: Any Env object. Will be converted into an RLlib BaseEnv. + clip_rewards: True for +/-1.0 clipping, + actual float value for +/- value clipping. False for no + clipping. + rollout_fragment_length: The length of a fragment to collect + before building a SampleBatch from the data and resetting + the SampleBatchBuilder object. + count_steps_by: One of "env_steps" (default) or "agent_steps". + Use "agent_steps", if you want rollout lengths to be counted + by individual agent steps. In a multi-agent env, + a single env_step contains one or more agent_steps, depending + on how many agents are present at any given time in the + ongoing episode. + callbacks: The RLlibCallback object to use when episode + events happen during rollout. + multiple_episodes_in_batch: Whether to pack multiple + episodes into each batch. This guarantees batches will be + exactly `rollout_fragment_length` in size. 
+ normalize_actions: Whether to normalize actions to the + action space's bounds. + clip_actions: Whether to clip actions according to the + given action_space's bounds. + observation_fn: Optional multi-agent observation func to use for + preprocessing observations. + sample_collector_class: An optional SampleCollector sub-class to + use to collect, store, and retrieve environment-, model-, + and sampler data. + render: Whether to try to render the environment after each step. + """ + # All of the following arguments are deprecated. They will instead be + # provided via the passed in `worker` arg, e.g. `worker.policy_map`. + if log_once("deprecated_sync_sampler_args"): + if policies is not None: + deprecation_warning(old="policies") + if policy_mapping_fn is not None: + deprecation_warning(old="policy_mapping_fn") + if preprocessors is not None: + deprecation_warning(old="preprocessors") + if obs_filters is not None: + deprecation_warning(old="obs_filters") + if tf_sess is not None: + deprecation_warning(old="tf_sess") + if horizon != DEPRECATED_VALUE: + deprecation_warning(old="horizon", error=True) + if soft_horizon != DEPRECATED_VALUE: + deprecation_warning(old="soft_horizon", error=True) + if no_done_at_end != DEPRECATED_VALUE: + deprecation_warning(old="no_done_at_end", error=True) + + self.base_env = convert_to_base_env(env) + self.rollout_fragment_length = rollout_fragment_length + self.extra_batches = queue.Queue() + self.perf_stats = _PerfStats( + ema_coef=worker.config.sampler_perf_stats_ema_coef, + ) + if not sample_collector_class: + sample_collector_class = SimpleListCollector + self.sample_collector = sample_collector_class( + worker.policy_map, + clip_rewards, + callbacks, + multiple_episodes_in_batch, + rollout_fragment_length, + count_steps_by=count_steps_by, + ) + self.render = render + + # Keep a reference to the underlying EnvRunnerV2 instance for + # unit testing purpose. 
+ self._env_runner_obj = EnvRunnerV2( + worker=worker, + base_env=self.base_env, + multiple_episodes_in_batch=multiple_episodes_in_batch, + callbacks=callbacks, + perf_stats=self.perf_stats, + rollout_fragment_length=rollout_fragment_length, + count_steps_by=count_steps_by, + render=self.render, + ) + self._env_runner = self._env_runner_obj.run() + self.metrics_queue = queue.Queue() + + @override(SamplerInput) + def get_data(self) -> SampleBatchType: + while True: + item = next(self._env_runner) + if isinstance(item, RolloutMetrics): + self.metrics_queue.put(item) + else: + return item + + @override(SamplerInput) + def get_metrics(self) -> List[RolloutMetrics]: + completed = [] + while True: + try: + completed.append( + self.metrics_queue.get_nowait()._replace( + perf_stats=self.perf_stats.get() + ) + ) + except queue.Empty: + break + return completed + + @override(SamplerInput) + def get_extra_batches(self) -> List[SampleBatchType]: + extra = [] + while True: + try: + extra.append(self.extra_batches.get_nowait()) + except queue.Empty: + break + return extra diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py new file mode 100644 index 0000000000000000000000000000000000000000..0eeea1ea2c8f00cd0cc3075af18cb0f1b7228126 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/evaluation/worker_set.py @@ -0,0 +1,10 @@ +from ray.rllib.utils.deprecation import Deprecated + + +@Deprecated( + new="ray.rllib.env.env_runner_group.EnvRunnerGroup", + help="The class has only be renamed w/o any changes in functionality.", + error=True, +) +class WorkerSet: + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..f579a77701ddb7f94aeccf58f3fc66e44c58fa83 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/attention_net.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/attention_net.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..acb1d64e13bfb7d4e5df258a0b0b50384331d4aa Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/attention_net.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/fcnet.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/fcnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3494b273b51dc9ebf19a21b49c13d1b6f9f4c836 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/fcnet.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/mingpt.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/mingpt.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c82ff0739f51e097c75bb17f5f65f7d1eb5ac22e Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/mingpt.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/recurrent_net.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/recurrent_net.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80b2e3ac56c1a2b76722455bfc64acf5cf3f9d04 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/recurrent_net.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_action_dist.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_action_dist.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5de4dedf5c4422f7b22c5379f3c5451bdc517f24 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_action_dist.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_distributions.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_distributions.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..36effd5c6094d3f900df715abd2875ce324c80cd Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/models/torch/__pycache__/torch_distributions.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__init__.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..cc4a0d9bb05dc8ec1e2a904ba361f6be13a99120 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__init__.py @@ -0,0 +1,30 @@ +from ray.rllib.offline.d4rl_reader import D4RLReader +from ray.rllib.offline.dataset_reader import DatasetReader, get_dataset_and_shards +from ray.rllib.offline.dataset_writer import DatasetWriter +from ray.rllib.offline.io_context import IOContext +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.offline.mixed_input import MixedInput +from ray.rllib.offline.json_reader import JsonReader +from ray.rllib.offline.json_writer import JsonWriter +from ray.rllib.offline.output_writer import OutputWriter, NoopOutput +from ray.rllib.offline.resource import 
get_offline_io_resource_bundles +from ray.rllib.offline.shuffled_input import ShuffledInput +from ray.rllib.offline.feature_importance import FeatureImportance + + +__all__ = [ + "IOContext", + "JsonReader", + "JsonWriter", + "NoopOutput", + "OutputWriter", + "InputReader", + "MixedInput", + "ShuffledInput", + "D4RLReader", + "DatasetReader", + "DatasetWriter", + "get_dataset_and_shards", + "get_offline_io_resource_bundles", + "FeatureImportance", +] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_reader.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_reader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..659a6f222ff39181bcbd99e39d56fd29572d22ee Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_reader.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_writer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_writer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2fba94e648401d369f53628c3dc5b132d816dee6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/dataset_writer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/feature_importance.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/feature_importance.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6ecb6cd10479b2639f02b56466dd84866dfb730b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/feature_importance.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/io_context.cpython-311.pyc 
b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/io_context.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a0f3fc768db38fe1ffa9c08672aefc4202f04016 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/io_context.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/is_estimator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/is_estimator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..95415b929a582b043862e0041ba3cdc5a39b58a3 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/is_estimator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_reader.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_reader.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8a440310744e944e395b59bdebfafb44190fc729 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_reader.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_writer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_writer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eda9d627c7a47b649a6ab8713a812d91cf6b9b68 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/json_writer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/mixed_input.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/mixed_input.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4164dd64cb9fe3fd2445975ea3194ba882804b76 Binary files /dev/null and 
b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/mixed_input.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/off_policy_estimator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/off_policy_estimator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..892f523f3f55e23bb4d22b76a4ebc6b3115baf5f Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/off_policy_estimator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_data.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_data.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..584a3cce8f76a273ff480136302ce262d84db685 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_data.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_env_runner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_env_runner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..60ec8d777702030b128bfb1b8bbc2bf8b0564189 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_env_runner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluation_utils.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluation_utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..552e346672f4e771649410aed3361be5a0570106 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluation_utils.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..94f76ecea97d89a12e89925ec0e6e7943f733509 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_evaluator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_prelearner.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_prelearner.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cf74159a87e7f6e2af7a6172160791f6fea7d268 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/offline_prelearner.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/output_writer.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/output_writer.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6969f054348c00e6503ce953d5336304d91150fb Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/output_writer.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/wis_estimator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/wis_estimator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cc4c4ccfbcfee1dcc4b7b73fd5bcc7fef1d1d51 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/__pycache__/wis_estimator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/d4rl_reader.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/d4rl_reader.py new file mode 100644 index 
0000000000000000000000000000000000000000..b9f18634b3d1c3eae21b570e11d1905509db948e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/d4rl_reader.py @@ -0,0 +1,51 @@ +import logging +import gymnasium as gym + +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.offline.io_context import IOContext +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.typing import SampleBatchType +from typing import Dict + +logger = logging.getLogger(__name__) + + +@PublicAPI +class D4RLReader(InputReader): + """Reader object that loads the dataset from the D4RL dataset.""" + + @PublicAPI + def __init__(self, inputs: str, ioctx: IOContext = None): + """Initializes a D4RLReader instance. + + Args: + inputs: String corresponding to the D4RL environment name. + ioctx: Current IO context object. + """ + import d4rl + + self.env = gym.make(inputs) + self.dataset = _convert_to_batch(d4rl.qlearning_dataset(self.env)) + assert self.dataset.count >= 1 + self.counter = 0 + + @override(InputReader) + def next(self) -> SampleBatchType: + if self.counter >= self.dataset.count: + self.counter = 0 + + self.counter += 1 + return self.dataset.slice(start=self.counter, end=self.counter + 1) + + +def _convert_to_batch(dataset: Dict) -> SampleBatchType: + # Converts D4RL dataset to SampleBatch + d = {} + d[SampleBatch.OBS] = dataset["observations"] + d[SampleBatch.ACTIONS] = dataset["actions"] + d[SampleBatch.NEXT_OBS] = dataset["next_observations"] + d[SampleBatch.REWARDS] = dataset["rewards"] + d[SampleBatch.TERMINATEDS] = dataset["terminals"] + + return SampleBatch(d) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_reader.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..1172aa7f5d0d57644640250d3767fab700e6766a --- /dev/null +++ 
b/.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_reader.py @@ -0,0 +1,289 @@ +import logging +import math +from pathlib import Path +import re +import numpy as np +from typing import List, Tuple, TYPE_CHECKING, Optional +import zipfile + +import ray.data +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.offline.io_context import IOContext +from ray.rllib.offline.json_reader import from_json_data, postprocess_actions +from ray.rllib.policy.sample_batch import concat_samples, SampleBatch, DEFAULT_POLICY_ID +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.typing import SampleBatchType + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + +DEFAULT_NUM_CPUS_PER_TASK = 0.5 + +logger = logging.getLogger(__name__) + + +def _unzip_this_path(fpath: Path, extract_path: str): + with zipfile.ZipFile(str(fpath), "r") as zip_ref: + zip_ref.extractall(extract_path) + + +def _unzip_if_needed(paths: List[str], format: str): + """If a path in paths is a zip file, unzip it and use path of the unzipped file""" + ret_paths = [] + for path in paths: + if re.search("\\.zip$", str(path)): + # TODO: We need to add unzip support for s3 + if str(path).startswith("s3://"): + raise ValueError( + "unzip_if_needed currently does not support remote paths from s3" + ) + extract_path = "./" + try: + _unzip_this_path(str(path), extract_path) + except FileNotFoundError: + # intrepreted as a relative path to rllib folder + try: + # TODO: remove this later when we replace all tests with s3 paths + _unzip_this_path(Path(__file__).parent.parent / path, extract_path) + except FileNotFoundError: + raise FileNotFoundError(f"File not found: {path}") + + unzipped_path = str( + Path(extract_path).absolute() / f"{Path(path).stem}.{format}" + ) + ret_paths.append(unzipped_path) + else: + # TODO: We can get rid of this logic when we replace all tests with s3 paths + if str(path).startswith("s3://"): + 
ret_paths.append(path) + else: + if not Path(path).exists(): + relative_path = str(Path(__file__).parent.parent / path) + if not Path(relative_path).exists(): + raise FileNotFoundError(f"File not found: {path}") + path = relative_path + ret_paths.append(path) + return ret_paths + + +@PublicAPI +def get_dataset_and_shards( + config: "AlgorithmConfig", num_workers: int = 0 +) -> Tuple[ray.data.Dataset, List[ray.data.Dataset]]: + """Returns a dataset and a list of shards. + + This function uses algorithm configs to create a dataset and a list of shards. + The following config keys are used to create the dataset: + input: The input type should be "dataset". + input_config: A dict containing the following key and values: + `format`: str, speciifies the format of the input data. This will be the + format that ray dataset supports. See ray.data.Dataset for + supported formats. Only "parquet" or "json" are supported for now. + `paths`: str, a single string or a list of strings. Each string is a path + to a file or a directory holding the dataset. It can be either a local path + or a remote path (e.g. to an s3 bucket). + `loader_fn`: Callable[None, ray.data.Dataset], Instead of + specifying paths and format, you can specify a function to load the dataset. + `parallelism`: int, The number of tasks to use for loading the dataset. + If not specified, it will be set to the number of workers. + `num_cpus_per_read_task`: float, The number of CPUs to use for each read + task. If not specified, it will be set to 0.5. + + Args: + config: The config dict for the algorithm. + num_workers: The number of shards to create for remote workers. + + Returns: + dataset: The dataset object. + shards: A list of dataset shards. For num_workers > 0 the first returned + shared would be a dummy None shard for local_worker. + """ + # check input and input config keys + assert config.input_ == "dataset", ( + f"Must specify config.input_ as 'dataset' if" + f" calling `get_dataset_and_shards`. 
Got {config.input_}" + ) + + # check input config format + input_config = config.input_config + format = input_config.get("format") + + supported_fmts = ["json", "parquet"] + if format is not None and format not in supported_fmts: + raise ValueError( + f"Unsupported format {format}. Supported formats are {supported_fmts}" + ) + + # check paths and loader_fn since only one of them is required. + paths = input_config.get("paths") + loader_fn = input_config.get("loader_fn") + if loader_fn and (format or paths): + raise ValueError( + "When using a `loader_fn`, you cannot specify a `format` or `path`." + ) + + # check if at least loader_fn or format + path is specified. + if not (format and paths) and not loader_fn: + raise ValueError( + "Must specify either a `loader_fn` or a `format` and `path` in " + "`input_config`." + ) + + # check paths to be a str or list[str] if not None + if paths is not None: + if isinstance(paths, str): + paths = [paths] + elif isinstance(paths, list): + assert isinstance(paths[0], str), "Paths must be a list of path strings." + else: + raise ValueError("Paths must be a path string or a list of path strings.") + paths = _unzip_if_needed(paths, format) + + # TODO (Kourosh): num_workers is not necessary since we can use parallelism for + # everything. Having two parameters is confusing here. Remove num_workers later. + parallelism = input_config.get("parallelism", num_workers or 1) + cpus_per_task = input_config.get( + "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK + ) + + if loader_fn: + dataset = loader_fn() + elif format == "json": + dataset = ray.data.read_json( + paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task} + ) + elif format == "parquet": + dataset = ray.data.read_parquet( + paths, parallelism=parallelism, ray_remote_args={"num_cpus": cpus_per_task} + ) + else: + raise ValueError("Un-supported Ray dataset format: ", format) + + # Local worker will be responsible for sampling. 
+ if num_workers == 0: + # Dataset is the only shard we need. + return dataset, [dataset] + # Remote workers are responsible for sampling: + else: + # Each remote worker gets 1 shard. + remote_shards = dataset.repartition( + num_blocks=num_workers, shuffle=False + ).split(num_workers) + + # The first None shard is for the local worker, which + # shouldn't be doing rollout work anyways. + return dataset, [None] + remote_shards + + +@PublicAPI +class DatasetReader(InputReader): + """Reader object that loads data from Ray Dataset. + + Examples: + config = { + "input": "dataset", + "input_config": { + "format": "json", + # A single data file, a directory, or anything + # that ray.data.dataset recognizes. + "paths": "/tmp/sample_batches/", + # By default, parallelism=num_workers. + "parallelism": 3, + # Dataset allocates 0.5 CPU for each reader by default. + # Adjust this value based on the size of your offline dataset. + "num_cpus_per_read_task": 0.5, + } + } + """ + + @PublicAPI + def __init__(self, ds: ray.data.Dataset, ioctx: Optional[IOContext] = None): + """Initializes a DatasetReader instance. + + Args: + ds: Ray dataset to sample from. + """ + self._ioctx = ioctx or IOContext() + self._default_policy = self.policy_map = None + self.preprocessor = None + self._dataset = ds + self.count = None if not self._dataset else self._dataset.count() + # do this to disable the ray data stdout logging + ray.data.DataContext.get_current().enable_progress_bars = False + + # the number of steps to return per call to next() + self.batch_size = self._ioctx.config.get("train_batch_size", 1) + num_workers = self._ioctx.config.get("num_env_runners", 0) + seed = self._ioctx.config.get("seed", None) + if num_workers: + self.batch_size = max(math.ceil(self.batch_size / num_workers), 1) + # We allow the creation of a non-functioning None DatasetReader. + # It's useful for example for a non-rollout local worker. 
+ if ds: + if self._ioctx.worker is not None: + self._policy_map = self._ioctx.worker.policy_map + self._default_policy = self._policy_map.get(DEFAULT_POLICY_ID) + self.preprocessor = ( + self._ioctx.worker.preprocessors.get(DEFAULT_POLICY_ID) + if not self._ioctx.config.get("_disable_preprocessors", False) + else None + ) + print( + f"DatasetReader {self._ioctx.worker_index} has {ds.count()}, samples." + ) + + def iterator(): + while True: + ds = self._dataset.random_shuffle(seed=seed) + yield from ds.iter_rows() + + self._iter = iterator() + else: + self._iter = None + + @override(InputReader) + def next(self) -> SampleBatchType: + # next() should not get called on None DatasetReader. + assert self._iter is not None + ret = [] + count = 0 + while count < self.batch_size: + d = next(self._iter) + # Columns like obs are compressed when written by DatasetWriter. + d = from_json_data(d, self._ioctx.worker) + count += d.count + d = self._preprocess_if_needed(d) + d = postprocess_actions(d, self._ioctx) + d = self._postprocess_if_needed(d) + ret.append(d) + ret = concat_samples(ret) + return ret + + def _preprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType: + # TODO: @kourosh, preprocessor is only supported for single agent case. 
+ if self.preprocessor: + for key in (SampleBatch.CUR_OBS, SampleBatch.NEXT_OBS): + if key in batch: + batch[key] = np.stack( + [self.preprocessor.transform(s) for s in batch[key]] + ) + return batch + + def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType: + if not self._ioctx.config.get("postprocess_inputs"): + return batch + + if isinstance(batch, SampleBatch): + out = [] + for sub_batch in batch.split_by_episode(): + if self._default_policy is not None: + out.append(self._default_policy.postprocess_trajectory(sub_batch)) + else: + out.append(sub_batch) + return concat_samples(out) + else: + # TODO(ekl) this is trickier since the alignments between agent + # trajectories in the episode are not available any more. + raise NotImplementedError( + "Postprocessing of multi-agent data not implemented yet." + ) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_writer.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..b517933ce985091f2192dd9b03e41e9b2d304cfe --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/dataset_writer.py @@ -0,0 +1,82 @@ +import logging +import os +import time + +from ray import data +from ray.rllib.offline.io_context import IOContext +from ray.rllib.offline.json_writer import _to_json_dict +from ray.rllib.offline.output_writer import OutputWriter +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.typing import SampleBatchType +from typing import Dict, List + +logger = logging.getLogger(__name__) + + +@PublicAPI +class DatasetWriter(OutputWriter): + """Writer object that saves experiences using Datasets.""" + + @PublicAPI + def __init__( + self, + ioctx: IOContext = None, + compress_columns: List[str] = frozenset(["obs", "new_obs"]), + ): + """Initializes a DatasetWriter instance. 
+ + Examples: + config = { + "output": "dataset", + "output_config": { + "format": "json", + "path": "/tmp/test_samples/", + "max_num_samples_per_file": 100000, + } + } + + Args: + ioctx: current IO context object. + compress_columns: list of sample batch columns to compress. + """ + self.ioctx = ioctx or IOContext() + + output_config: Dict = ioctx.output_config + assert ( + "format" in output_config + ), "output_config.format must be specified when using Dataset output." + assert ( + "path" in output_config + ), "output_config.path must be specified when using Dataset output." + + self.format = output_config["format"] + self.path = os.path.abspath(os.path.expanduser(output_config["path"])) + self.max_num_samples_per_file = ( + output_config["max_num_samples_per_file"] + if "max_num_samples_per_file" in output_config + else 100000 + ) + self.compress_columns = compress_columns + + self.samples = [] + + @override(OutputWriter) + def write(self, sample_batch: SampleBatchType): + start = time.time() + + # Make sure columns like obs are compressed and writable. + d = _to_json_dict(sample_batch, self.compress_columns) + self.samples.append(d) + + # Todo: We should flush at the end of sampling even if this + # condition was not reached. 
+ if len(self.samples) >= self.max_num_samples_per_file: + ds = data.from_items(self.samples).repartition(num_blocks=1, shuffle=False) + if self.format == "json": + ds.write_json(self.path, try_create_dir=True) + elif self.format == "parquet": + ds.write_parquet(self.path, try_create_dir=True) + else: + raise ValueError("Unknown output type: ", self.format) + self.samples = [] + logger.debug("Wrote dataset in {}s".format(time.time() - start)) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/__init__.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6d92c40d157ff49069843b7c34e8f18e35fe2c32 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/__init__.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/direct_method.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/direct_method.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fdcdd9976c423ff8a0577cd6021bb4cbb5b1349b Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/direct_method.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/fqe_torch_model.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/fqe_torch_model.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..24069affcf12234b6f19d73da3c41e29cc740e21 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/fqe_torch_model.cpython-311.pyc differ diff --git 
a/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/off_policy_estimator.cpython-311.pyc b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/off_policy_estimator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..edcb83a7fccfff4878d8942f0ed30b73becf96d6 Binary files /dev/null and b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/__pycache__/off_policy_estimator.cpython-311.pyc differ diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/feature_importance.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/feature_importance.py new file mode 100644 index 0000000000000000000000000000000000000000..a5d4d171893239f51936dc9f1bd33803a79b1f55 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/feature_importance.py @@ -0,0 +1,11 @@ +from ray.rllib.offline.feature_importance import FeatureImportance + +__all__ = ["FeatureImportance"] + +from ray.rllib.utils.deprecation import deprecation_warning + +deprecation_warning( + "ray.rllib.offline.estimators.feature_importance.FeatureImportance", + "ray.rllib.offline.feature_importance.FeatureImportance", + error=True, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/weighted_importance_sampling.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/weighted_importance_sampling.py new file mode 100644 index 0000000000000000000000000000000000000000..cfca393a021253ccdd346ea529c3f0d54690067d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/estimators/weighted_importance_sampling.py @@ -0,0 +1,185 @@ +from typing import Dict, Any, List +import numpy as np +import math + +from ray.data import Dataset + +from ray.rllib.offline.offline_evaluator import OfflineEvaluator +from ray.rllib.offline.estimators.off_policy_estimator import OffPolicyEstimator +from ray.rllib.offline.offline_evaluation_utils 
import ( + remove_time_dim, + compute_is_weights, +) +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.policy import Policy +from ray.rllib.utils.annotations import override, DeveloperAPI + + +@DeveloperAPI +class WeightedImportanceSampling(OffPolicyEstimator): + r"""The step-wise WIS estimator. + + Let s_t, a_t, and r_t be the state, action, and reward at timestep t. + + For behavior policy \pi_b and evaluation policy \pi_e, define the + cumulative importance ratio at timestep t as: + p_t = \sum_{t'=0}^t (\pi_e(a_{t'} | s_{t'}) / \pi_b(a_{t'} | s_{t'})). + + Define the average importance ratio over episodes i in the dataset D as: + w_t = \sum_{i \in D} p^(i)_t / |D| + + This estimator computes the expected return for \pi_e for an episode as: + V^{\pi_e}(s_0) = \E[\sum_t \gamma ^ {t} * (p_t / w_t) * r_t] + and returns the mean and standard deviation over episodes. + + For more information refer to https://arxiv.org/pdf/1911.06854.pdf""" + + @override(OffPolicyEstimator) + def __init__(self, policy: Policy, gamma: float, epsilon_greedy: float = 0.0): + super().__init__(policy, gamma, epsilon_greedy) + # map from time to cummulative propensity values + self.cummulative_ips_values = [] + # map from time to number of episodes that reached this time + self.episode_timestep_count = [] + # map from eps id to mapping from time to propensity values + self.p = {} + + @override(OffPolicyEstimator) + def estimate_on_single_episode(self, episode: SampleBatch) -> Dict[str, Any]: + estimates_per_epsiode = {} + rewards = episode["rewards"] + + eps_id = episode[SampleBatch.EPS_ID][0] + if eps_id not in self.p: + raise ValueError( + f"Cannot find target weight for episode {eps_id}. " + f"Did it go though the peek_on_single_episode() function?" 
+ ) + + # calculate stepwise weighted IS estimate + v_behavior = 0.0 + v_target = 0.0 + episode_p = self.p[eps_id] + for t in range(episode.count): + v_behavior += rewards[t] * self.gamma**t + w_t = self.cummulative_ips_values[t] / self.episode_timestep_count[t] + v_target += episode_p[t] / w_t * rewards[t] * self.gamma**t + + estimates_per_epsiode["v_behavior"] = v_behavior + estimates_per_epsiode["v_target"] = v_target + + return estimates_per_epsiode + + @override(OffPolicyEstimator) + def estimate_on_single_step_samples( + self, batch: SampleBatch + ) -> Dict[str, List[float]]: + estimates_per_epsiode = {} + rewards, old_prob = batch["rewards"], batch["action_prob"] + new_prob = self.compute_action_probs(batch) + + weights = new_prob / old_prob + v_behavior = rewards + v_target = weights * rewards / np.mean(weights) + + estimates_per_epsiode["v_behavior"] = v_behavior + estimates_per_epsiode["v_target"] = v_target + estimates_per_epsiode["weights"] = weights + estimates_per_epsiode["new_prob"] = new_prob + estimates_per_epsiode["old_prob"] = old_prob + + return estimates_per_epsiode + + @override(OffPolicyEstimator) + def on_before_split_batch_by_episode( + self, sample_batch: SampleBatch + ) -> SampleBatch: + self.cummulative_ips_values = [] + self.episode_timestep_count = [] + self.p = {} + + return sample_batch + + @override(OffPolicyEstimator) + def peek_on_single_episode(self, episode: SampleBatch) -> None: + old_prob = episode["action_prob"] + new_prob = self.compute_action_probs(episode) + + # calculate importance ratios + episode_p = [] + for t in range(episode.count): + if t == 0: + pt_prev = 1.0 + else: + pt_prev = episode_p[t - 1] + episode_p.append(pt_prev * new_prob[t] / old_prob[t]) + + for t, p_t in enumerate(episode_p): + if t >= len(self.cummulative_ips_values): + self.cummulative_ips_values.append(p_t) + self.episode_timestep_count.append(1.0) + else: + self.cummulative_ips_values[t] += p_t + self.episode_timestep_count[t] += 1.0 + + eps_id = 
episode[SampleBatch.EPS_ID][0] + if eps_id in self.p: + raise ValueError( + f"eps_id {eps_id} was already passed to the peek function. " + f"Make sure dataset contains only unique episodes with unique ids." + ) + self.p[eps_id] = episode_p + + @override(OfflineEvaluator) + def estimate_on_dataset( + self, dataset: Dataset, *, n_parallelism: int = ... + ) -> Dict[str, Any]: + """Computes the weighted importance sampling estimate on a dataset. + + Note: This estimate works for both continuous and discrete action spaces. + + Args: + dataset: Dataset to compute the estimate on. Each record in dataset should + include the following columns: `obs`, `actions`, `action_prob` and + `rewards`. The `obs` on each row shoud be a vector of D dimensions. + n_parallelism: Number of parallel workers to use for the computation. + + Returns: + Dictionary with the following keys: + v_target: The weighted importance sampling estimate. + v_behavior: The behavior policy estimate. + v_gain_mean: The mean of the gain of the target policy over the + behavior policy. + v_gain_ste: The standard error of the gain of the target policy over + the behavior policy. 
+ """ + # compute the weights and weighted rewards + batch_size = max(dataset.count() // n_parallelism, 1) + dataset = dataset.map_batches( + remove_time_dim, batch_size=batch_size, batch_format="pandas" + ) + updated_ds = dataset.map_batches( + compute_is_weights, + batch_size=batch_size, + batch_format="pandas", + fn_kwargs={ + "policy_state": self.policy.get_state(), + "estimator_class": self.__class__, + }, + ) + v_target = updated_ds.mean("weighted_rewards") / updated_ds.mean("weights") + v_behavior = updated_ds.mean("rewards") + v_gain_mean = v_target / v_behavior + v_gain_ste = ( + updated_ds.std("weighted_rewards") + / updated_ds.mean("weights") + / v_behavior + / math.sqrt(dataset.count()) + ) + + return { + "v_target": v_target, + "v_behavior": v_behavior, + "v_gain_mean": v_gain_mean, + "v_gain_ste": v_gain_ste, + } diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/feature_importance.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/feature_importance.py new file mode 100644 index 0000000000000000000000000000000000000000..2efe17790a79e5ba830b3444f31f8f6b096a5e6e --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/feature_importance.py @@ -0,0 +1,283 @@ +import copy +import numpy as np +import pandas as pd +from typing import Callable, Dict, Any + +import ray +from ray.data import Dataset + +from ray.rllib.policy import Policy +from ray.rllib.policy.sample_batch import SampleBatch, convert_ma_batch_to_sample_batch +from ray.rllib.utils.annotations import override, DeveloperAPI, ExperimentalAPI +from ray.rllib.utils.typing import SampleBatchType +from ray.rllib.offline.offline_evaluator import OfflineEvaluator + + +@DeveloperAPI +def _perturb_fn(batch: np.ndarray, index: int): + # shuffle the indexth column features + random_inds = np.random.permutation(batch.shape[0]) + batch[:, index] = batch[random_inds, index] + + +@ExperimentalAPI +def _perturb_df(batch: pd.DataFrame, index: int): + obs_batch = 
np.vstack(batch["obs"].values) + _perturb_fn(obs_batch, index) + batch["perturbed_obs"] = list(obs_batch) + return batch + + +def _compute_actions( + batch: pd.DataFrame, + policy_state: Dict[str, Any], + input_key: str = "", + output_key: str = "", +): + """A custom local function to do batch prediction of a policy. + + Given the policy state the action predictions are computed as a function of + `input_key` and stored in the `output_key` column. + + Args: + batch: A sub-batch from the dataset. + policy_state: The state of the policy to use for the prediction. + input_key: The key to use for the input to the policy. If not given, the + default is SampleBatch.OBS. + output_key: The key to use for the output of the policy. If not given, the + default is "predicted_actions". + + Returns: + The modified batch with the predicted actions added as a column. + """ + if not input_key: + input_key = SampleBatch.OBS + + policy = Policy.from_state(policy_state) + sample_batch = SampleBatch( + { + SampleBatch.OBS: np.vstack(batch[input_key].values), + } + ) + actions, _, _ = policy.compute_actions_from_input_dict(sample_batch, explore=False) + + if not output_key: + output_key = "predicted_actions" + batch[output_key] = actions + + return batch + + +@ray.remote +def get_feature_importance_on_index( + dataset: ray.data.Dataset, + *, + index: int, + perturb_fn: Callable[[pd.DataFrame, int], None], + batch_size: int, + policy_state: Dict[str, Any], +): + """A remote function to compute the feature importance of a given index. + + Args: + dataset: The dataset to use for the computation. The dataset should have `obs` + and `actions` columns. Each record should be flat d-dimensional array. + index: The index of the feature to compute the importance for. + perturb_fn: The function to use for perturbing the dataset at the given index. + batch_size: The batch size to use for the computation. + policy_state: The state of the policy to use for the computation. 
@ray.remote
def get_feature_importance_on_index(
    dataset: ray.data.Dataset,
    *,
    index: int,
    perturb_fn: Callable[[pd.DataFrame, int], None],
    batch_size: int,
    policy_state: Dict[str, Any],
):
    """A remote function to compute the feature importance of a given index.

    Args:
        dataset: The dataset to use for the computation. The dataset should
            have `obs` and `actions` columns. Each record should be a flat
            d-dimensional array.
        index: The index of the feature to compute the importance for.
        perturb_fn: The function to use for perturbing the dataset at the
            given index.
        batch_size: The batch size to use for the computation.
        policy_state: The state of the policy to use for the computation.

    Returns:
        The modified dataset that contains a `delta` column which is the
        absolute difference between the expected output and the output due to
        the perturbation.
    """
    # All three map_batches stages below share these settings.
    shared_kwargs = dict(batch_size=batch_size, batch_format="pandas")

    perturbed = dataset.map_batches(
        perturb_fn, fn_kwargs={"index": index}, **shared_kwargs
    )
    with_actions = perturbed.map_batches(
        _compute_actions,
        fn_kwargs={
            "input_key": "perturbed_obs",
            "output_key": "perturbed_actions",
            "policy_state": policy_state,
        },
        **shared_kwargs,
    )

    def _abs_diff(frame):
        # Per-row absolute difference between reference and perturbed actions.
        frame["delta"] = np.abs(frame["ref_actions"] - frame["perturbed_actions"])
        return frame

    return with_actions.map_batches(_abs_diff, **shared_kwargs)
+ + ```python + config = ( + AlgorithmConfig() + .offline_data( + off_policy_estimation_methods= + { + "feature_importance": { + "type": FeatureImportance, + "repeat": 10, + "limit_fraction": 0.1, + } + } + ) + ) + + algorithm = DQN(config=config) + results = algorithm.train() + ``` + + Args: + policy: the policy to use for feature importance. + repeat: number of times to repeat the perturbation. + perturb_fn: function to perturb the features. By default reshuffle the + features within the batch. + limit_fraction: fraction of the dataset to use for feature importance + This is only used in estimate_on_dataset when the dataset is too large + to compute feature importance on. + """ + super().__init__(policy) + self.repeat = repeat + self.perturb_fn = perturb_fn + self.limit_fraction = limit_fraction + + def estimate(self, batch: SampleBatchType) -> Dict[str, Any]: + """Estimate the feature importance of the policy. + + Given a batch of tabular observations, the importance of each feature is + computed by perturbing each feature and computing the difference between the + perturbed policy and the reference policy. The importance is computed for each + feature and each perturbation is repeated `self.repeat` times. + + Args: + batch: the batch of data to use for feature importance. + + Returns: + A dict mapping each feature index string to its importance. 
+ """ + batch = convert_ma_batch_to_sample_batch(batch) + obs_batch = batch["obs"] + n_features = obs_batch.shape[-1] + importance = np.zeros((self.repeat, n_features)) + + ref_actions, _, _ = self.policy.compute_actions(obs_batch, explore=False) + for r in range(self.repeat): + for i in range(n_features): + copy_obs_batch = copy.deepcopy(obs_batch) + _perturb_fn(copy_obs_batch, index=i) + perturbed_actions, _, _ = self.policy.compute_actions( + copy_obs_batch, explore=False + ) + importance[r, i] = np.mean(np.abs(perturbed_actions - ref_actions)) + + # take an average across repeats + importance = importance.mean(0) + metrics = {f"feature_{i}": importance[i] for i in range(len(importance))} + + return metrics + + @override(OfflineEvaluator) + def estimate_on_dataset( + self, dataset: Dataset, *, n_parallelism: int = ... + ) -> Dict[str, Any]: + """Estimate the feature importance of the policy given a dataset. + + For each feature in the dataset, the importance is computed by applying + perturbations to each feature and computing the difference between the + perturbed prediction and the reference prediction. The importance + computation for each feature and each perturbation is repeated `self.repeat` + times. If dataset is large the user can initialize the estimator with a + `limit_fraction` to limit the dataset to a fraction of the original dataset. + + The dataset should include a column named `obs` where each row is a vector of D + dimensions. The importance is computed for each dimension of the vector. + + Note (Implementation detail): The computation across features are distributed + with ray workers since each feature is independent of each other. + + Args: + dataset: the dataset to use for feature importance. + n_parallelism: number of parallel workers to use for feature importance. + + Returns: + A dict mapping each feature index string to its importance. 
+ """ + + policy_state = self.policy.get_state() + # step 1: limit the dataset to a few first rows + ds = dataset.limit(int(self.limit_fraction * dataset.count())) + + # step 2: compute the reference actions + bsize = max(1, ds.count() // n_parallelism) + actions_ds = ds.map_batches( + _compute_actions, + batch_size=bsize, + fn_kwargs={ + "output_key": "ref_actions", + "policy_state": policy_state, + }, + ) + + # step 3: compute the feature importance + n_features = ds.take(1)[0][SampleBatch.OBS].shape[-1] + importance = np.zeros((self.repeat, n_features)) + for r in range(self.repeat): + # shuffle the entire dataset + shuffled_ds = actions_ds.random_shuffle() + bsize_per_task = max(1, (shuffled_ds.count() * n_features) // n_parallelism) + + # for each index perturb the dataset and compute the feat importance score + remote_fns = [ + get_feature_importance_on_index.remote( + dataset=shuffled_ds, + index=i, + perturb_fn=self.perturb_fn, + bsize=bsize_per_task, + policy_state=policy_state, + ) + for i in range(n_features) + ] + ds_w_fi_scores = ray.get(remote_fns) + importance[r] = np.array([d.mean("delta") for d in ds_w_fi_scores]) + + importance = importance.mean(0) + metrics = {f"feature_{i}": importance[i] for i in range(len(importance))} + + return metrics diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/input_reader.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/input_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..042e3783c39d25c54286d36c93131b608f5e3a0f --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/input_reader.py @@ -0,0 +1,132 @@ +from abc import ABCMeta, abstractmethod +import logging +import numpy as np +import threading + +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.utils.annotations import PublicAPI +from ray.rllib.utils.framework import try_import_tf +from typing import Dict, List +from ray.rllib.utils.typing import TensorType, 
@PublicAPI
class InputReader(metaclass=ABCMeta):
    """API for collecting and returning experiences during policy evaluation."""

    @abstractmethod
    @PublicAPI
    def next(self) -> SampleBatchType:
        """Returns the next batch of read experiences.

        Returns:
            The experience read (SampleBatch or MultiAgentBatch).
        """
        raise NotImplementedError

    @PublicAPI
    def tf_input_ops(self, queue_size: int = 1) -> Dict[str, TensorType]:
        """Returns TensorFlow queue ops for reading inputs from this reader.

        The main use of these ops is for integration into custom model losses.
        For example, you can use tf_input_ops() to read from files of external
        experiences to add an imitation learning loss to your model.

        This method creates a queue runner thread that will call next() on this
        reader repeatedly to feed the TensorFlow queue.

        Args:
            queue_size: Max elements to allow in the TF queue.

        .. testcode::
            :skipif: True

            from ray.rllib.models.modelv2 import ModelV2
            from ray.rllib.offline.json_reader import JsonReader
            imitation_loss = ...
            class MyModel(ModelV2):
                def custom_loss(self, policy_loss, loss_inputs):
                    reader = JsonReader(...)
                    input_ops = reader.tf_input_ops()
                    logits, _ = self._build_layers_v2(
                        {"obs": input_ops["obs"]},
                        self.num_outputs, self.options)
                    il_loss = imitation_loss(logits, input_ops["action"])
                    return policy_loss + il_loss

        You can find a runnable version of this in examples/custom_loss.py.

        Returns:
            Dict of Tensors, one for each column of the read SampleBatch.
        """

        # Only one queue runner (and thus one set of queue ops) may be
        # attached per reader instance.
        if hasattr(self, "_queue_runner"):
            raise ValueError(
                "A queue runner already exists for this input reader. "
                "You can only call tf_input_ops() once per reader."
            )

        # Read one batch up front to discover column names, dtypes and shapes.
        logger.info("Reading initial batch of data from input reader.")
        batch = self.next()
        if isinstance(batch, MultiAgentBatch):
            raise NotImplementedError(
                "tf_input_ops() is not implemented for multi agent batches"
            )

        # Note on casting to `np.array(batch[k])`: In order to get all keys that
        # are numbers, we need to convert to numpy everything that is not a numpy array.
        # This is because SampleBatches used to only hold numpy arrays, but since our
        # RNN efforts under RLModules, we also allow lists.
        keys = [
            k
            for k in sorted(batch.keys())
            if np.issubdtype(np.array(batch[k]).dtype, np.number)
        ]
        dtypes = [batch[k].dtype for k in keys]
        # Leading dim is made dynamic (-1) so batches of any size can be fed.
        shapes = {k: (-1,) + s[1:] for (k, s) in [(k, batch[k].shape) for k in keys]}
        queue = tf1.FIFOQueue(capacity=queue_size, dtypes=dtypes, names=keys)
        tensors = queue.dequeue()

        logger.info("Creating TF queue runner for {}".format(self))
        self._queue_runner = _QueueRunner(self, queue, keys, dtypes)
        # Seed the queue with the batch we already read above, then keep it
        # fed from a background thread.
        self._queue_runner.enqueue(batch)
        self._queue_runner.start()

        out = {k: tf.reshape(t, shapes[k]) for k, t in tensors.items()}
        return out
class _QueueRunner(threading.Thread):
    """Thread that feeds a TF queue from a InputReader."""

    def __init__(
        self,
        input_reader: InputReader,
        queue: "tf1.FIFOQueue",
        keys: List[str],
        dtypes: "tf.dtypes.DType",
    ):
        threading.Thread.__init__(self)
        # Capture the session at construction time; enqueue() runs on it.
        # NOTE(review): assumes a default session exists here — TODO confirm.
        self.sess = tf1.get_default_session()
        # Daemon thread: does not keep the process alive at interpreter exit.
        self.daemon = True
        self.input_reader = input_reader
        self.keys = keys
        self.queue = queue
        # One placeholder per column; the enqueue op binds them by key.
        self.placeholders = [tf1.placeholder(dtype) for dtype in dtypes]
        self.enqueue_op = queue.enqueue(dict(zip(keys, self.placeholders)))

    def enqueue(self, batch: SampleBatchType):
        """Feeds one batch's columns into the TF queue via the enqueue op."""
        data = {self.placeholders[i]: batch[key] for i, key in enumerate(self.keys)}
        self.sess.run(self.enqueue_op, feed_dict=data)

    def run(self):
        # Loop forever; read errors are logged and the loop continues, so a
        # single corrupt batch does not kill the feeder thread.
        while True:
            try:
                batch = self.input_reader.next()
                self.enqueue(batch)
            except Exception:
                logger.exception("Error reading from input")
input") diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/io_context.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/io_context.py new file mode 100644 index 0000000000000000000000000000000000000000..1d0ec1683b935ff4712d0873ac567d935a2b7642 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/io_context.py @@ -0,0 +1,72 @@ +import os +from typing import Optional, TYPE_CHECKING + +from ray.rllib.utils.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + from ray.rllib.evaluation.sampler import SamplerInput + from ray.rllib.evaluation.rollout_worker import RolloutWorker + + +@PublicAPI +class IOContext: + """Class containing attributes to pass to input/output class constructors. + + RLlib auto-sets these attributes when constructing input/output classes, + such as InputReaders and OutputWriters. + """ + + @PublicAPI + def __init__( + self, + log_dir: Optional[str] = None, + config: Optional["AlgorithmConfig"] = None, + worker_index: int = 0, + worker: Optional["RolloutWorker"] = None, + ): + """Initializes a IOContext object. + + Args: + log_dir: The logging directory to read from/write to. + config: The (main) AlgorithmConfig object. + worker_index: When there are multiple workers created, this + uniquely identifies the current worker. 0 for the local + worker, >0 for any of the remote workers. + worker: The RolloutWorker object reference. + """ + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + self.log_dir = log_dir or os.getcwd() + # In case no config is provided, use the default one, but set + # `actions_in_input_normalized=True` if we don't have a worker. + # Not having a worker and/or a config should only be the case in some test + # cases, though. 
+ self.config = config or AlgorithmConfig().offline_data( + actions_in_input_normalized=worker is None + ).training(train_batch_size=1) + self.worker_index = worker_index + self.worker = worker + + @PublicAPI + def default_sampler_input(self) -> Optional["SamplerInput"]: + """Returns the RolloutWorker's SamplerInput object, if any. + + Returns None if the RolloutWorker has no SamplerInput. Note that local + workers in case there are also one or more remote workers by default + do not create a SamplerInput object. + + Returns: + The RolloutWorkers' SamplerInput object or None if none exists. + """ + return self.worker.sampler + + @property + @PublicAPI + def input_config(self): + return self.config.get("input_config", {}) + + @property + @PublicAPI + def output_config(self): + return self.config.get("output_config", {}) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/is_estimator.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/is_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..58c8da3e0c72358046a9192fae8461780eaf3bbf --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/is_estimator.py @@ -0,0 +1,10 @@ +from ray.rllib.offline.estimators.importance_sampling import ImportanceSampling +from ray.rllib.utils.deprecation import Deprecated + + +@Deprecated( + new="ray.rllib.offline.estimators.importance_sampling::ImportanceSampling", + error=True, +) +class ImportanceSamplingEstimator(ImportanceSampling): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/json_reader.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/json_reader.py new file mode 100644 index 0000000000000000000000000000000000000000..30562b515aac2562b33fbcb847720ee1f17c0131 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/json_reader.py @@ -0,0 +1,438 @@ +import glob +import json +import logging +import math + +import numpy as np +import os +from pathlib import Path +import 
def _adjust_obs_actions_for_policy(json_data: dict, policy: Policy) -> dict:
    """Handle nested action/observation spaces for policies.

    Translates nested lists/dicts from the json into proper
    np.ndarrays, according to the (nested) observation- and action-
    spaces of the given policy.

    Providing nested lists w/o this preprocessing step would
    confuse a SampleBatch constructor.

    Args:
        json_data: Columns of one decoded JSON record, keyed by SampleBatch
            column / view-requirement name.
        policy: The Policy whose (possibly nested) spaces determine how the
            nested lists are converted.

    Returns:
        The same dict with the affected columns converted in place.
    """
    for k, v in json_data.items():
        # Resolve the underlying data column for view-requirement keys
        # (e.g. a shifted view still maps back to its base column).
        data_col = (
            policy.view_requirements[k].data_col
            if k in policy.view_requirements
            else ""
        )
        # No action flattening -> Process nested (leaf) action(s).
        if policy.config.get("_disable_action_flattening") and (
            k == SampleBatch.ACTIONS
            or data_col == SampleBatch.ACTIONS
            or k == SampleBatch.PREV_ACTIONS
            or data_col == SampleBatch.PREV_ACTIONS
        ):
            json_data[k] = tree.map_structure_up_to(
                policy.action_space_struct,
                lambda comp: np.array(comp),
                json_data[k],
                check_types=False,
            )
        # No preprocessing -> Process nested (leaf) observation(s).
        elif policy.config.get("_disable_preprocessor_api") and (
            k == SampleBatch.OBS
            or data_col == SampleBatch.OBS
            or k == SampleBatch.NEXT_OBS
            or data_col == SampleBatch.NEXT_OBS
        ):
            json_data[k] = tree.map_structure_up_to(
                policy.observation_space_struct,
                lambda comp: np.array(comp),
                json_data[k],
                check_types=False,
            )
    return json_data
@DeveloperAPI
def _adjust_dones(json_data: dict) -> dict:
    """Return a copy of `json_data` with legacy DONES renamed to TERMINATEDS.

    All other keys and values are carried over unchanged, preserving their
    original order.
    """
    return {
        (SampleBatch.TERMINATEDS if key == SampleBatch.DONES else key): value
        for key, value in json_data.items()
    }
+ if cfg.get("clip_actions"): + if ioctx.worker is None: + raise ValueError( + "clip_actions is True but cannot clip actions since no workers exist" + ) + + if isinstance(batch, SampleBatch): + policy = ioctx.worker.policy_map.get(DEFAULT_POLICY_ID) + if policy is None: + assert len(ioctx.worker.policy_map) == 1 + policy = next(iter(ioctx.worker.policy_map.values())) + batch[SampleBatch.ACTIONS] = clip_action( + batch[SampleBatch.ACTIONS], policy.action_space_struct + ) + else: + for pid, b in batch.policy_batches.items(): + b[SampleBatch.ACTIONS] = clip_action( + b[SampleBatch.ACTIONS], + ioctx.worker.policy_map[pid].action_space_struct, + ) + # Re-normalize actions (from env's bounds to zero-centered), if + # necessary. + if ( + cfg.get("actions_in_input_normalized") is False + and cfg.get("normalize_actions") is True + ): + if ioctx.worker is None: + raise ValueError( + "actions_in_input_normalized is False but" + "cannot normalize actions since no workers exist" + ) + + # If we have a complex action space and actions were flattened + # and we have to normalize -> Error. + error_msg = ( + "Normalization of offline actions that are flattened is not " + "supported! Make sure that you record actions into offline " + "file with the `_disable_action_flattening=True` flag OR " + "as already normalized (between -1.0 and 1.0) values. " + "Also, when reading already normalized action values from " + "offline files, make sure to set " + "`actions_in_input_normalized=True` so that RLlib will not " + "perform normalization on top." 
@DeveloperAPI
def from_json_data(json_data: Any, worker: Optional["RolloutWorker"]):
    """Builds a SampleBatch or MultiAgentBatch from one decoded JSON record.

    Args:
        json_data: The decoded JSON dict. Must contain a "type" field that is
            either "SampleBatch" or "MultiAgentBatch".
        worker: If given, its policy map is used to adjust nested obs/action
            columns for the target policies.

    Returns:
        A SampleBatch or MultiAgentBatch with compressed columns unpacked and
        legacy DONES translated to TERMINATEDS.

    Raises:
        ValueError: If the "type" field is missing or unknown, or if a
            single-agent batch is read while multiple policies exist.
    """
    # Try to infer the SampleBatchType (SampleBatch or MultiAgentBatch).
    if "type" in json_data:
        data_type = json_data.pop("type")
    else:
        raise ValueError("JSON record missing 'type' field")

    if data_type == "SampleBatch":
        if worker is not None and len(worker.policy_map) != 1:
            raise ValueError(
                "Found single-agent SampleBatch in input file, but our "
                "PolicyMap contains more than 1 policy!"
            )
        for k, v in json_data.items():
            # Decompress lz4/base64-packed columns, if present.
            json_data[k] = unpack_if_needed(v)
        if worker is not None:
            policy = next(iter(worker.policy_map.values()))
            json_data = _adjust_obs_actions_for_policy(json_data, policy)
        json_data = _adjust_dones(json_data)
        return SampleBatch(json_data)
    elif data_type == "MultiAgentBatch":
        policy_batches = {}
        for policy_id, policy_batch in json_data["policy_batches"].items():
            inner = {}
            for k, v in policy_batch.items():
                # Translate DONES into TERMINATEDS.
                if k == SampleBatch.DONES:
                    k = SampleBatch.TERMINATEDS
                inner[k] = unpack_if_needed(v)
            if worker is not None:
                policy = worker.policy_map[policy_id]
                inner = _adjust_obs_actions_for_policy(inner, policy)
            inner = _adjust_dones(inner)
            policy_batches[policy_id] = SampleBatch(inner)
        return MultiAgentBatch(policy_batches, json_data["count"])
    else:
        raise ValueError(
            "Type field must be one of ['SampleBatch', 'MultiAgentBatch']", data_type
        )
    @PublicAPI
    def __init__(
        self, inputs: Union[str, List[str]], ioctx: Optional[IOContext] = None
    ):
        """Initializes a JsonReader instance.

        Args:
            inputs: Either a glob expression for files, e.g. `/tmp/**/*.json`,
                or a list of single file paths or URIs, e.g.,
                ["s3://bucket/file.json", "s3://bucket/file2.json"].
            ioctx: Current IO context object or None.

        Raises:
            ValueError: If `inputs` is neither str nor list/tuple, if a URI
                would need to be globbed, or if no matching files are found.
        """
        logger.info(
            "You are using JSONReader. It is recommended to use "
            + "DatasetReader instead for better sharding support."
        )

        self.ioctx = ioctx or IOContext()
        self.default_policy = self.policy_map = None
        self.batch_size = 1
        if self.ioctx:
            self.batch_size = self.ioctx.config.get("train_batch_size", 1)
            num_workers = self.ioctx.config.get("num_env_runners", 0)
            if num_workers:
                # Split the configured train batch size across env runners.
                self.batch_size = max(math.ceil(self.batch_size / num_workers), 1)

        if self.ioctx.worker is not None:
            self.policy_map = self.ioctx.worker.policy_map
            self.default_policy = self.policy_map.get(DEFAULT_POLICY_ID)
            if self.default_policy is None:
                # Single (non-default-named) policy setup.
                assert len(self.policy_map) == 1
                self.default_policy = next(iter(self.policy_map.values()))

        if isinstance(inputs, str):
            inputs = os.path.abspath(os.path.expanduser(inputs))
            if os.path.isdir(inputs):
                # A directory: read all .json and .zip files inside it.
                inputs = [os.path.join(inputs, "*.json"), os.path.join(inputs, "*.zip")]
                logger.warning(f"Treating input directory as glob patterns: {inputs}")
            else:
                inputs = [inputs]

            # Remote URIs (e.g. s3://) cannot be globbed locally.
            if any(urlparse(i).scheme not in [""] + WINDOWS_DRIVES for i in inputs):
                raise ValueError(
                    "Don't know how to glob over `{}`, ".format(inputs)
                    + "please specify a list of files to read instead."
                )
            else:
                self.files = []
                for i in inputs:
                    self.files.extend(glob.glob(i))
        elif isinstance(inputs, (list, tuple)):
            self.files = list(inputs)
        else:
            raise ValueError(
                "type of inputs must be list or str, not {}".format(inputs)
            )
        if self.files:
            logger.info("Found {} input files.".format(len(self.files)))
        else:
            raise ValueError("No files found matching {}".format(inputs))
        # Lazily opened by _next_line()/_next_file().
        self.cur_file = None
    @override(InputReader)
    def next(self) -> SampleBatchType:
        """Reads batches until at least `self.batch_size` timesteps are collected."""
        ret = []
        count = 0
        while count < self.batch_size:
            batch = self._try_parse(self._next_line())
            tries = 0
            # Tolerate up to 100 consecutive empty/corrupt lines before
            # giving up on the current file.
            while not batch and tries < 100:
                tries += 1
                logger.debug("Skipping empty line in {}".format(self.cur_file))
                batch = self._try_parse(self._next_line())
            if not batch:
                raise ValueError(
                    "Failed to read valid experience batch from file: {}".format(
                        self.cur_file
                    )
                )
            batch = self._postprocess_if_needed(batch)
            count += batch.count
            ret.append(batch)
        ret = concat_samples(ret)
        return ret
+ """ + for path in self.files: + file = self._try_open_file(path) + while True: + line = file.readline() + if not line: + break + batch = self._try_parse(line) + if batch is None: + break + yield batch + + def _postprocess_if_needed(self, batch: SampleBatchType) -> SampleBatchType: + if not self.ioctx.config.get("postprocess_inputs"): + return batch + + batch = convert_ma_batch_to_sample_batch(batch) + + if isinstance(batch, SampleBatch): + out = [] + for sub_batch in batch.split_by_episode(): + out.append(self.default_policy.postprocess_trajectory(sub_batch)) + return concat_samples(out) + else: + # TODO(ekl) this is trickier since the alignments between agent + # trajectories in the episode are not available any more. + raise NotImplementedError( + "Postprocessing of multi-agent data not implemented yet." + ) + + def _try_open_file(self, path): + if urlparse(path).scheme not in [""] + WINDOWS_DRIVES: + if smart_open is None: + raise ValueError( + "You must install the `smart_open` module to read " + "from URIs like {}".format(path) + ) + ctx = smart_open + else: + # Allow shortcut for home directory ("~/" -> env[HOME]). + if path.startswith("~/"): + path = os.path.join(os.environ.get("HOME", ""), path[2:]) + + # If path doesn't exist, try to interpret is as relative to the + # rllib directory (located ../../ from this very module). + path_orig = path + if not os.path.exists(path): + path = os.path.join(Path(__file__).parent.parent, path) + if not os.path.exists(path): + raise FileNotFoundError(f"Offline file {path_orig} not found!") + + # Unzip files, if necessary and re-point to extracted json file. 
    def _next_line(self) -> str:
        """Returns the next line, rotating through input files as needed."""
        if not self.cur_file:
            self.cur_file = self._next_file()
        line = self.cur_file.readline()
        tries = 0
        # On EOF (or an empty file), move on to another file; give up after
        # 100 attempts.
        while not line and tries < 100:
            tries += 1
            if hasattr(self.cur_file, "close"):  # legacy smart_open impls
                self.cur_file.close()
            self.cur_file = self._next_file()
            line = self.cur_file.readline()
            if not line:
                logger.debug("Ignoring empty file {}".format(self.cur_file))
        if not line:
            raise ValueError(
                "Failed to read next line from files: {}".format(self.files)
            )
        return line
+ else: + path = random.choice(self.files) + return self._try_open_file(path) + + def _from_json(self, data: str) -> SampleBatchType: + if isinstance(data, bytes): # smart_open S3 doesn't respect "r" + data = data.decode("utf-8") + json_data = json.loads(data) + return from_json_data(json_data, self.ioctx.worker) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/json_writer.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/json_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..4e15bfb2e550fa0013571a488c9e5c47044ed67c --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/json_writer.py @@ -0,0 +1,142 @@ +from datetime import datetime +import json +import logging +import numpy as np +import os +from urllib.parse import urlparse +import time + +try: + from smart_open import smart_open +except ImportError: + smart_open = None + +from ray.air._internal.json import SafeFallbackEncoder +from ray.rllib.policy.sample_batch import MultiAgentBatch +from ray.rllib.offline.io_context import IOContext +from ray.rllib.offline.output_writer import OutputWriter +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.compression import pack, compression_supported +from ray.rllib.utils.typing import FileType, SampleBatchType +from typing import Any, Dict, List + +logger = logging.getLogger(__name__) + +WINDOWS_DRIVES = [chr(i) for i in range(ord("c"), ord("z") + 1)] + + +# TODO(jungong): use DatasetWriter to back JsonWriter, so we reduce codebase complexity +# without losing existing functionality. +@PublicAPI +class JsonWriter(OutputWriter): + """Writer object that saves experiences in JSON file chunks.""" + + @PublicAPI + def __init__( + self, + path: str, + ioctx: IOContext = None, + max_file_size: int = 64 * 1024 * 1024, + compress_columns: List[str] = frozenset(["obs", "new_obs"]), + ): + """Initializes a JsonWriter instance. 
    def _get_file(self) -> FileType:
        """Returns the current output file, rolling over to a new one if needed."""
        # Roll over on first write or once max_file_size has been reached.
        if not self.cur_file or self.bytes_written >= self.max_file_size:
            if self.cur_file:
                self.cur_file.close()
            # File name encodes timestamp, worker index, and a running index.
            timestr = datetime.today().strftime("%Y-%m-%d_%H-%M-%S")
            path = os.path.join(
                self.path,
                "output-{}_worker-{}_{}.json".format(
                    timestr, self.ioctx.worker_index, self.file_index
                ),
            )
            if self.path_is_uri:
                if smart_open is None:
                    raise ValueError(
                        "You must install the `smart_open` module to write "
                        "to URIs like {}".format(path)
                    )
                self.cur_file = smart_open(path, "w")
            else:
                self.cur_file = open(path, "w")
            self.file_index += 1
            self.bytes_written = 0
            logger.info("Writing to new output file {}".format(self.cur_file))
        return self.cur_file
def _to_jsonable(v, compress: bool) -> Any:
    """Converts a single value into a JSON-serializable form.

    Compresses the value when requested (and compression is available);
    otherwise converts numpy arrays to plain lists and passes everything
    else through unchanged.
    """
    if compress and compression_supported():
        return str(pack(v))
    if isinstance(v, np.ndarray):
        return v.tolist()
    return v


def _to_json_dict(batch: SampleBatchType, compress_columns: List[str]) -> Dict:
    """Renders a (possibly multi-agent) sample batch as a JSON-compatible dict."""
    if isinstance(batch, MultiAgentBatch):
        policy_batches = {
            policy_id: {
                k: _to_jsonable(v, compress=k in compress_columns)
                for k, v in sub_batch.items()
            }
            for policy_id, sub_batch in batch.policy_batches.items()
        }
        return {
            "type": "MultiAgentBatch",
            "count": batch.count,
            "policy_batches": policy_batches,
        }
    out = {"type": "SampleBatch"}
    for k, v in batch.items():
        out[k] = _to_jsonable(v, compress=k in compress_columns)
    return out


def _to_json(batch: SampleBatchType, compress_columns: List[str]) -> str:
    """Serializes a sample batch to a single JSON string."""
    return json.dumps(_to_json_dict(batch, compress_columns), cls=SafeFallbackEncoder)
testcode:: + :skipif: True + + from ray.rllib.offline.io_context import IOContext + from ray.rllib.offline.mixed_input import MixedInput + ioctx = IOContext(...) + MixedInput({ + "sampler": 0.4, + "/tmp/experiences/*.json": 0.4, + "s3://bucket/expert.json": 0.2, + }, ioctx) + """ + + @DeveloperAPI + def __init__(self, dist: Dict[JsonReader, float], ioctx: IOContext): + """Initialize a MixedInput. + + Args: + dist: dict mapping JSONReader paths or "sampler" to + probabilities. The probabilities must sum to 1.0. + ioctx: current IO context object. + """ + if sum(dist.values()) != 1.0: + raise ValueError("Values must sum to 1.0: {}".format(dist)) + self.choices = [] + self.p = [] + for k, v in dist.items(): + if k == "sampler": + self.choices.append(ioctx.default_sampler_input()) + elif isinstance(k, FunctionType): + self.choices.append(k(ioctx)) + elif isinstance(k, str) and registry_contains_input(k): + input_creator = registry_get_input(k) + self.choices.append(input_creator(ioctx)) + else: + self.choices.append(JsonReader(k, ioctx)) + self.p.append(v) + + @override(InputReader) + def next(self) -> SampleBatchType: + source = np.random.choice(self.choices, p=self.p) + return source.next() diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/off_policy_estimator.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/off_policy_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..c8a08fb4a1dfa5c31a5fac94ded7e874504507aa --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/off_policy_estimator.py @@ -0,0 +1,10 @@ +from ray.rllib.offline.estimators.off_policy_estimator import ( # noqa: F401 + OffPolicyEstimator, +) +from ray.rllib.utils.deprecation import deprecation_warning + +deprecation_warning( + old="ray.rllib.offline.off_policy_estimator", + new="ray.rllib.offline.estimators.off_policy_estimator", + error=True, +) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/offline_data.py 
@PublicAPI(stability="alpha")
class OfflineData:
    """Loads and transforms offline datasets for RLlib's new Offline RL API.

    Reads data via `ray.data`, optionally materializes it, maps batches
    through an `OfflinePreLearner` to make them `Learner`-ready, and serves
    them via `sample()` either as single batches or as batch iterators
    (one per learner shard).

    NOTE(review): `self.spaces`, `self.learner_handles`, `self.module_spec`,
    and `self.locality_hints` are read in `sample()` but only initialized to
    `None` here — presumably set by the algorithm after construction; confirm
    against the caller.
    """

    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def __init__(self, config: AlgorithmConfig):
        """Initializes the OfflineData instance from an algorithm config.

        Args:
            config: The `AlgorithmConfig` holding all `input_*` and
                offline-data related settings.
        """
        self.config = config
        self.is_multi_agent = self.config.is_multi_agent
        # A list input is passed through as-is; a single string is wrapped
        # into a `Path`.
        self.path = (
            self.config.input_
            if isinstance(config.input_, list)
            else Path(config.input_)
        )
        # Use `read_parquet` as default data read method.
        self.data_read_method = self.config.input_read_method
        # Override default arguments for the data read method.
        self.data_read_method_kwargs = self.config.input_read_method_kwargs
        # In case `EpisodeType` or `BatchType` batches are read the size
        # could differ from the final `train_batch_size_per_learner`.
        self.data_read_batch_size = self.config.input_read_batch_size

        # If data should be materialized.
        self.materialize_data = config.materialize_data
        # If mapped data should be materialized.
        self.materialize_mapped_data = config.materialize_mapped_data
        # Flag to identify, if data has already been mapped with the
        # `OfflinePreLearner`.
        self.data_is_mapped = False

        # Set the filesystem.
        self.filesystem = self.config.input_filesystem
        self.filesystem_kwargs = self.config.input_filesystem_kwargs
        self.filesystem_object = None

        # If a specific filesystem is given, set it up. Note, this could
        # be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob Storage.
        # this filesystem is specifically needed, if a session has to be created
        # with the cloud provider.
        if self.filesystem == "gcs":
            import gcsfs

            self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "s3":
            self.filesystem_object = pyarrow.fs.S3FileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "abs":
            import adlfs

            self.filesystem_object = adlfs.AzureBlobFileSystem(**self.filesystem_kwargs)
        elif isinstance(self.filesystem, pyarrow.fs.FileSystem):
            # A pre-built pyarrow filesystem instance is used directly.
            self.filesystem_object = self.filesystem
        elif self.filesystem is not None:
            raise ValueError(
                f"Unknown `config.input_filesystem` {self.filesystem}! Filesystems "
                "can be None for local, any instance of `pyarrow.fs.FileSystem`, "
                "'gcs' for GCS, 's3' for S3, or 'abs' for adlfs.AzureBlobFileSystem."
            )
        # Add the filesystem object to the write method kwargs.
        if self.filesystem_object:
            self.data_read_method_kwargs.update(
                {
                    "filesystem": self.filesystem_object,
                }
            )

        try:
            # Load the dataset.
            start_time = time.perf_counter()
            self.data = getattr(ray.data, self.data_read_method)(
                self.path, **self.data_read_method_kwargs
            )
            if self.materialize_data:
                self.data = self.data.materialize()
            stop_time = time.perf_counter()
            logger.debug(
                "===> [OfflineData] - Time for loading dataset: "
                f"{stop_time - start_time}s."
            )
            logger.info("Reading data from {}".format(self.path))
        except Exception as e:
            # NOTE(review): a failed read is only logged — `self.data` is
            # then never assigned and any later `sample()` call fails with
            # an `AttributeError`. Presumably intentional best-effort; confirm.
            logger.error(e)
        # Avoids reinstantiating the batch iterator each time we sample.
        self.batch_iterators = None
        # User-provided kwargs override the defaults (dict-union, right wins).
        self.map_batches_kwargs = (
            self.default_map_batches_kwargs | self.config.map_batches_kwargs
        )
        self.iter_batches_kwargs = (
            self.default_iter_batches_kwargs | self.config.iter_batches_kwargs
        )
        self.returned_streaming_split = False
        # Defines the prelearner class. Note, this could be user-defined.
        self.prelearner_class = self.config.prelearner_class or OfflinePreLearner
        # For remote learner setups.
        self.locality_hints = None
        self.learner_handles = None
        self.module_spec = None

    @OverrideToImplementCustomLogic
    def sample(
        self,
        num_samples: int,
        return_iterator: bool = False,
        num_shards: int = 1,
    ):
        """Returns a training batch, or batch iterator(s), from the dataset.

        Args:
            num_samples: Number of experiences per returned batch.
            return_iterator: If True, return the batch iterator(s) instead
                of a single batch.
            num_shards: Number of learner shards; if > 1, the data stream is
                split into this many iterators.
        """
        # Materialize the mapped data, if necessary. This runs for all the
        # data the `OfflinePreLearner` logic and maps them to `MultiAgentBatch`es.
        # TODO (simon, sven): This would never update the module nor the
        #  the connectors. If this is needed we have to check, if we give
        #  (a) only an iterator and let the learner and OfflinePreLearner
        #  communicate through the object storage. This only works when
        #  not materializing.
        #  (b) Rematerialize the data every couple of iterations. This is
        #  is costly.
        if not self.data_is_mapped:
            # Constructor `kwargs` for the `OfflinePreLearner`.
            fn_constructor_kwargs = {
                "config": self.config,
                "learner": self.learner_handles[0],
                "spaces": self.spaces[INPUT_ENV_SPACES],
            }
            # If we have multiple learners, add to the constructor `kwargs`.
            if num_shards > 1:
                # Call here the learner to get an up-to-date module state.
                # TODO (simon): This is a workaround as along as learners cannot
                #  receive any calls from another actor.
                module_state = ray.get(
                    self.learner_handles[0].get_state.remote(
                        component=COMPONENT_RL_MODULE
                    )
                )
                # Add constructor `kwargs` when using remote learners.
                fn_constructor_kwargs.update(
                    {
                        "learner": None,
                        "module_spec": self.module_spec,
                        "module_state": module_state,
                    }
                )

            self.data = self.data.map_batches(
                self.prelearner_class,
                fn_constructor_kwargs=fn_constructor_kwargs,
                batch_size=self.data_read_batch_size or num_samples,
                **self.map_batches_kwargs,
            )
            # Set the flag to `True`.
            self.data_is_mapped = True
            # If the user wants to materialize the data in memory.
            if self.materialize_mapped_data:
                self.data = self.data.materialize()
        # Build an iterator, if necessary. Note, in case that an iterator should be
        # returned now and we have already generated from the iterator, i.e.
        # `isinstance(self.batch_iterators, types.GeneratorType) == True`, we need
        # to create here a new iterator.
        if not self.batch_iterators or (
            return_iterator and isinstance(self.batch_iterators, types.GeneratorType)
        ):
            # If we have more than one learner create an iterator for each of them
            # by splitting the data stream.
            if num_shards > 1:
                logger.debug("===> [OfflineData]: Return streaming_split ... ")
                # In case of multiple shards, we return multiple
                # `StreamingSplitIterator` instances.
                self.batch_iterators = self.data.streaming_split(
                    n=num_shards,
                    # Note, `equal` must be `True`, i.e. the batch size must
                    # be the same for all batches b/c otherwise remote learners
                    # could block each others.
                    equal=True,
                    locality_hints=self.locality_hints,
                )
            # Otherwise we create a simple iterator and - if necessary - initialize
            # it here.
            else:
                # If no iterator should be returned, or if we want to return a single
                # batch iterator, we instantiate the batch iterator once, here.
                self.batch_iterators = self.data.iter_batches(
                    # This is important. The batch size is now 1, because the data
                    # is already run through the `OfflinePreLearner` and a single
                    # instance is a single `MultiAgentBatch` of size `num_samples`.
                    batch_size=1,
                    **self.iter_batches_kwargs,
                )

            # If there should be batches
            if not return_iterator:
                self.batch_iterators = iter(self.batch_iterators)

        # Do we want to return an iterator or a single batch?
        if return_iterator:
            return self.batch_iterators
        else:
            # Return a single batch from the iterator.
            try:
                return next(self.batch_iterators)["batch"][0]
            except StopIteration:
                # If the batch iterator is exhausted, reinitiate a new one.
                logger.debug(
                    "===> [OfflineData]: Batch iterator exhausted. Reinitiating ..."
                )
                self.batch_iterators = None
                return self.sample(
                    num_samples=num_samples,
                    return_iterator=return_iterator,
                    num_shards=num_shards,
                )

    @property
    def default_map_batches_kwargs(self):
        # Defaults for `Dataset.map_batches`; merged with (and overridden by)
        # `config.map_batches_kwargs` in `__init__`.
        return {
            "concurrency": max(2, self.config.num_learners),
            "zero_copy_batch": True,
        }

    @property
    def default_iter_batches_kwargs(self):
        # Defaults for `Dataset.iter_batches`; merged with (and overridden by)
        # `config.iter_batches_kwargs` in `__init__`.
        return {
            "prefetch_batches": 2,
        }
@PublicAPI(stability="alpha")
class OfflineSingleAgentEnvRunner(SingleAgentEnvRunner):
    """The environment runner to record the single agent case.

    Buffers sampled experiences (either serialized `SingleAgentEpisode`s or
    flat per-timestep rows) and writes them to disk/cloud storage via
    `ray.data`, either every iteration or whenever
    `output_max_rows_per_file` rows have accumulated.
    """

    @override(SingleAgentEnvRunner)
    @OverrideToImplementCustomLogic_CallToSuperRecommended
    def __init__(self, *, config: AlgorithmConfig, **kwargs):
        """Initializes the recording env runner from the algorithm config.

        Args:
            config: The `AlgorithmConfig` holding all `output_*` settings.
            **kwargs: Forwarded to `SingleAgentEnvRunner.__init__`.
        """
        # Initialize the parent.
        super().__init__(config=config, **kwargs)

        # Get the data context for this `EnvRunner`.
        data_context = ray.data.DataContext.get_current()
        # Limit the resources for Ray Data to the CPUs given to this `EnvRunner`.
        data_context.execution_options.resource_limits.cpu = (
            config.num_cpus_per_env_runner
        )

        # Set the output write method.
        self.output_write_method = self.config.output_write_method
        self.output_write_method_kwargs = self.config.output_write_method_kwargs

        # Set the filesystem.
        self.filesystem = self.config.output_filesystem
        self.filesystem_kwargs = self.config.output_filesystem_kwargs
        self.filesystem_object = None

        # Set the output base path.
        self.output_path = self.config.output
        # Set the subdir (environment specific).
        self.subdir_path = self.config.env.lower()
        # Set the worker-specific path name. Note, this is
        # specifically to enable multi-threaded writing into
        # the same directory.
        self.worker_path = "run-" + f"{self.worker_index}".zfill(6)

        # If a specific filesystem is given, set it up. Note, this could
        # be `gcsfs` for GCS, `pyarrow` for S3 or `adlfs` for Azure Blob Storage.
        # this filesystem is specifically needed, if a session has to be created
        # with the cloud provider.
        if self.filesystem == "gcs":
            import gcsfs

            self.filesystem_object = gcsfs.GCSFileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "s3":
            from pyarrow import fs

            self.filesystem_object = fs.S3FileSystem(**self.filesystem_kwargs)
        elif self.filesystem == "abs":
            import adlfs

            self.filesystem_object = adlfs.AzureBlobFileSystem(**self.filesystem_kwargs)
        elif self.filesystem is not None:
            raise ValueError(
                f"Unknown filesystem: {self.filesystem}. Filesystems can be "
                "'gcs' for GCS, 's3' for S3, or 'abs'"
            )
        # Add the filesystem object to the write method kwargs.
        # NOTE(review): unlike `OfflineData.__init__`, this update runs even
        # when `filesystem_object is None` — presumably harmless for local
        # writes, but confirm the write method accepts `filesystem=None`.
        self.output_write_method_kwargs.update(
            {
                "filesystem": self.filesystem_object,
            }
        )

        # If we should store `SingleAgentEpisodes` or column data.
        self.output_write_episodes = self.config.output_write_episodes
        # Which columns should be compressed in the output data.
        self.output_compress_columns = self.config.output_compress_columns

        # Buffer these many rows before writing to file.
        self.output_max_rows_per_file = self.config.output_max_rows_per_file
        # If the user defines a maximum number of rows per file, set the
        # event to `False` and check during sampling.
        if self.output_max_rows_per_file:
            self.write_data_this_iter = False
        # Otherwise the event is always `True` and we write always sampled
        # data immediately to disk.
        else:
            self.write_data_this_iter = True

        # If the remaining data should be stored. Note, this is only
        # relevant in case `output_max_rows_per_file` is defined.
        self.write_remaining_data = self.config.output_write_remaining_data

        # Counts how often `sample` is called to define the output path for
        # each file.
        self._sample_counter = 0

        # Define the buffer for experiences stored until written to disk.
        self._samples = []

    @override(SingleAgentEnvRunner)
    @OverrideToImplementCustomLogic
    def sample(
        self,
        *,
        num_timesteps: int = None,
        num_episodes: int = None,
        explore: bool = None,
        random_actions: bool = False,
        force_reset: bool = False,
    ) -> List[SingleAgentEpisode]:
        """Samples from environments and writes data to disk."""

        # Call the super sample method.
        samples = super().sample(
            num_timesteps=num_timesteps,
            num_episodes=num_episodes,
            explore=explore,
            random_actions=random_actions,
            force_reset=force_reset,
        )

        self._sample_counter += 1

        # Add data to the buffers.
        if self.output_write_episodes:

            import msgpack
            import msgpack_numpy as mnp

            if log_once("msgpack"):
                logger.info(
                    "Packing episodes with `msgpack` and encode array with "
                    "`msgpack_numpy` for serialization. This is needed for "
                    "recording episodes."
                )
            # Note, we serialize episodes with `msgpack` and `msgpack_numpy` to
            # ensure version compatibility.
            self._samples.extend(
                [msgpack.packb(eps.get_state(), default=mnp.encode) for eps in samples]
            )
        else:
            # Flatten episodes into per-timestep row dicts (appends to buffer).
            self._map_episodes_to_data(samples)

        # If the user defined the maximum number of rows to write.
        if self.output_max_rows_per_file:
            # Check, if this number is reached.
            if len(self._samples) >= self.output_max_rows_per_file:
                # Start the recording of data.
                self.write_data_this_iter = True

        if self.write_data_this_iter:
            # If the user wants a maximum number of experiences per file,
            # cut the samples to write to disk from the buffer.
            if self.output_max_rows_per_file:
                # Reset the event.
                self.write_data_this_iter = False
                # Ensure that all data ready to be written is released from
                # the buffer. Note, this is important in case we have many
                # episodes sampled and a relatively small `output_max_rows_per_file`.
                # NOTE(review): `samples_ds` is rebuilt each loop iteration but
                # only written once after the loop — if more than one full chunk
                # is buffered, earlier chunks are dropped from this write (they
                # were already removed from `self._samples`). Confirm intent.
                while len(self._samples) >= self.output_max_rows_per_file:
                    # Extract the number of samples to be written to disk this
                    # iteration.
                    samples_to_write = self._samples[: self.output_max_rows_per_file]
                    # Reset the buffer to the remaining data. This only makes sense, if
                    # `rollout_fragment_length` is smaller `output_max_rows_per_file` or
                    # a 2 x `output_max_rows_per_file`.
                    self._samples = self._samples[self.output_max_rows_per_file :]
                    samples_ds = ray.data.from_items(samples_to_write)
            # Otherwise, write the complete data.
            else:
                samples_ds = ray.data.from_items(self._samples)
            try:
                # Setup the path for writing data. Each run will be written to
                # its own file. A run is a writing event. The path will look
                # like. 'base_path/env-name/00000-00000'.
                path = (
                    Path(self.output_path)
                    .joinpath(self.subdir_path)
                    .joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
                )
                getattr(samples_ds, self.output_write_method)(
                    path.as_posix(), **self.output_write_method_kwargs
                )
                logger.info(f"Wrote samples to storage at {path}.")
            except Exception as e:
                # Best-effort: a failed write is logged, sampling continues.
                logger.error(e)

        self.metrics.log_value(
            key="recording_buffer_size",
            value=len(self._samples),
        )

        # Finally return the samples as usual.
        return samples

    @override(EnvRunner)
    @OverrideToImplementCustomLogic
    def stop(self) -> None:
        """Writes the remaining samples to disk.

        Note, if the user defined `max_rows_per_file` the
        number of rows for the remaining samples could be
        less than the defined maximum row number by the user.
        """
        # If there are samples left over we have to write them to disk, i.e.
        # to a dataset.
        if self._samples and self.write_remaining_data:
            # Convert them to a `ray.data.Dataset`.
            samples_ds = ray.data.from_items(self._samples)
            # Increase the sample counter for the folder/file name.
            self._sample_counter += 1
            # Try to write the dataset to disk/cloud storage.
            try:
                # Setup the path for writing data. Each run will be written to
                # its own file. A run is a writing event. The path will look
                # like. 'base_path/env-name/00000-00000'.
                path = (
                    Path(self.output_path)
                    .joinpath(self.subdir_path)
                    .joinpath(self.worker_path + f"-{self._sample_counter}".zfill(6))
                )
                getattr(samples_ds, self.output_write_method)(
                    path.as_posix(), **self.output_write_method_kwargs
                )
                logger.info(
                    f"Wrote final samples to storage at {path}. Note "
                    "Note, final samples could be smaller in size than "
                    f"`max_rows_per_file`, if defined."
                )
            except Exception as e:
                logger.error(e)

        logger.debug(f"Experience buffer length: {len(self._samples)}")

    @OverrideToImplementCustomLogic
    def _map_episodes_to_data(self, samples: List[EpisodeType]) -> None:
        """Converts list of episodes to list of single dict experiences.

        Note, this method also appends all sampled experiences to the
        buffer.

        Args:
            samples: List of episodes to be converted.
        """
        # Loop through all sampled episodes.
        obs_space = self.env.observation_space
        action_space = self.env.action_space
        for sample in samples:
            # Loop through all items of the episode.
            for i in range(len(sample)):
                sample_data = {
                    Columns.EPS_ID: sample.id_,
                    Columns.AGENT_ID: sample.agent_id,
                    Columns.MODULE_ID: sample.module_id,
                    # Compress observations, if requested.
                    Columns.OBS: pack_if_needed(
                        to_jsonable_if_needed(sample.get_observations(i), obs_space)
                    )
                    if Columns.OBS in self.output_compress_columns
                    else to_jsonable_if_needed(sample.get_observations(i), obs_space),
                    # Compress actions, if requested.
                    Columns.ACTIONS: pack_if_needed(
                        to_jsonable_if_needed(sample.get_actions(i), action_space)
                    )
                    if Columns.ACTIONS in self.output_compress_columns
                    else to_jsonable_if_needed(sample.get_actions(i), action_space),
                    Columns.REWARDS: sample.get_rewards(i),
                    # Compress next observations, if requested.
                    # NOTE(review): the check is keyed on `Columns.OBS`, not
                    # `Columns.NEXT_OBS` — presumably so obs and next_obs are
                    # always compressed together; confirm this is intended.
                    Columns.NEXT_OBS: pack_if_needed(
                        to_jsonable_if_needed(sample.get_observations(i + 1), obs_space)
                    )
                    if Columns.OBS in self.output_compress_columns
                    else to_jsonable_if_needed(
                        sample.get_observations(i + 1), obs_space
                    ),
                    # Only the episode's final step carries its real
                    # terminated/truncated flags.
                    Columns.TERMINATEDS: False
                    if i < len(sample) - 1
                    else sample.is_terminated,
                    Columns.TRUNCATEDS: False
                    if i < len(sample) - 1
                    else sample.is_truncated,
                    **{
                        # Compress any extra model output, if requested.
                        k: pack_if_needed(sample.get_extra_model_outputs(k, i))
                        if k in self.output_compress_columns
                        else sample.get_extra_model_outputs(k, i)
                        for k in sample.extra_model_outputs.keys()
                    },
                }
                # Finally append to the data buffer.
                self._samples.append(sample_data)
@DeveloperAPI
def compute_q_and_v_values(
    batch: pd.DataFrame,
    model_class: Type["FQETorchModel"],
    model_state: Dict[str, Any],
    compute_q_values: bool = True,
) -> pd.DataFrame:
    """Computes the Q and V values for the given batch of samples.

    This function is to be used with map_batches() to perform a batch prediction on a
    dataset of records with `obs` and `actions` columns.

    Args:
        batch: A sub-batch from the dataset.
        model_class: The model class to use for the prediction. This class should be a
            sub-class of FQEModel that implements the estimate_q() and estimate_v()
            methods.
        model_state: The state of the model to use for the prediction.
        compute_q_values: Whether to compute the Q values or not. If False, only the V
            is computed and returned.

    Returns:
        The modified batch with the Q and V values added as columns.
    """
    # Rebuild the model from its serialized state (picklable across workers).
    model = model_class.from_state(model_state)

    sample_batch = SampleBatch(
        {
            SampleBatch.OBS: np.vstack(batch[SampleBatch.OBS]),
            SampleBatch.ACTIONS: np.vstack(batch[SampleBatch.ACTIONS]).squeeze(-1),
        }
    )

    v_values = model.estimate_v(sample_batch)
    v_values = convert_to_numpy(v_values)
    batch["v_values"] = v_values

    if compute_q_values:
        q_values = model.estimate_q(sample_batch)
        q_values = convert_to_numpy(q_values)
        batch["q_values"] = q_values

    return batch


@DeveloperAPI
def compute_is_weights(
    batch: pd.DataFrame,
    policy_state: Dict[str, Any],
    estimator_class: Type["OffPolicyEstimator"],
) -> pd.DataFrame:
    """Computes the importance sampling weights for the given batch of samples.

    For a lot of off-policy estimators, the importance sampling weights are computed as
    the propensity score ratio between the new and old policies
    (i.e. new_pi(act|obs) / old_pi(act|obs)). This function is to be used with
    map_batches() to perform a batch prediction on a dataset of records with `obs`,
    `actions`, `action_prob` and `rewards` columns.

    Args:
        batch: A sub-batch from the dataset.
        policy_state: The state of the policy to use for the prediction.
        estimator_class: The estimator class to use for the prediction. This class
            must provide a `compute_action_probs()` method (see usage below).

    Returns:
        The modified batch with the importance sampling weights, weighted rewards, new
        and old propensities added as columns.
    """
    # Rebuild the policy from its serialized state.
    policy = Policy.from_state(policy_state)
    # NOTE(review): gamma=0 and epsilon_greedy=0 appear to be placeholders —
    # only `compute_action_probs()` is used here; confirm against the
    # estimator's constructor contract.
    estimator = estimator_class(policy=policy, gamma=0, epsilon_greedy=0)
    sample_batch = SampleBatch(
        {
            SampleBatch.OBS: np.vstack(batch["obs"].values),
            SampleBatch.ACTIONS: np.vstack(batch["actions"].values).squeeze(-1),
            SampleBatch.ACTION_PROB: np.vstack(batch["action_prob"].values).squeeze(-1),
            SampleBatch.REWARDS: np.vstack(batch["rewards"].values).squeeze(-1),
        }
    )
    new_prob = estimator.compute_action_probs(sample_batch)
    old_prob = sample_batch[SampleBatch.ACTION_PROB]
    rewards = sample_batch[SampleBatch.REWARDS]
    # Propensity ratio new_pi / old_pi per row.
    weights = new_prob / old_prob
    weighted_rewards = weights * rewards

    batch["weights"] = weights
    batch["weighted_rewards"] = weighted_rewards
    batch["new_prob"] = new_prob
    batch["old_prob"] = old_prob

    return batch
@DeveloperAPI
def remove_time_dim(batch: pd.DataFrame) -> pd.DataFrame:
    """Removes the time dimension from the given sub-batch of the dataset.

    If each row in a dataset has a time dimension ([T, D]), and T=1, this function will
    remove the T dimension to convert each row to of shape [D]. If T > 1, the row is
    left unchanged. This function is to be used with map_batches().

    Args:
        batch: The batch to remove the time dimension from.

    Returns:
        The modified batch with the time dimension removed (when applicable)
    """
    # Only these columns carry a leading time axis.
    BATCHED_KEYS = {
        SampleBatch.OBS,
        SampleBatch.ACTIONS,
        SampleBatch.ACTION_PROB,
        SampleBatch.REWARDS,
        SampleBatch.NEXT_OBS,
        SampleBatch.DONES,
    }
    for k in batch.columns:
        if k in BATCHED_KEYS:
            # Unwrap single-step sequences; leave longer sequences intact.
            batch[k] = batch[k].apply(lambda x: x[0] if len(x) == 1 else x)
    return batch


@DeveloperAPI
class OfflineEvaluator(abc.ABC):
    """Interface for an offline evaluator of a policy."""

    @DeveloperAPI
    def __init__(self, policy: Policy, **kwargs):
        """Initializes an OffPolicyEstimator instance.

        Args:
            policy: Policy to evaluate.
            kwargs: forward compatibility placeholder.
        """
        self.policy = policy

    @abc.abstractmethod
    @DeveloperAPI
    def estimate(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]:
        """Returns the evaluation results for the given batch of episodes.

        Args:
            batch: The batch to evaluate.
            kwargs: forward compatibility placeholder.

        Returns:
            The evaluation done on the given batch. The returned
            dict can be any arbitrary mapping of strings to metrics.
        """
        raise NotImplementedError

    @DeveloperAPI
    def train(self, batch: SampleBatchType, **kwargs) -> Dict[str, Any]:
        """Sometimes you need to train a model inside an evaluator. This method
        abstracts the training process.

        Args:
            batch: SampleBatch to train on
            kwargs: forward compatibility placeholder.

        Returns:
            Any optional metrics to return from the evaluator
        """
        # Default: no training needed; subclasses may override.
        return {}

    @ExperimentalAPI
    def estimate_on_dataset(
        self,
        dataset: Dataset,
        *,
        n_parallelism: int = os.cpu_count(),
    ) -> Dict[str, Any]:

        """Calculates the estimate of the metrics based on the given offline dataset.

        Typically, the dataset is passed through only once via n_parallel tasks in
        mini-batches to improve the run-time of metric estimation.

        Note, this base implementation is a stub (returns None implicitly);
        subclasses are expected to override it.

        Args:
            dataset: The ray dataset object to do offline evaluation on.
            n_parallelism: The number of parallelism to use for the computation.

        Returns:
            Dict[str, Any]: A dictionary of the estimated values.
        """
from_jsonable_if_needed +from ray.rllib.utils.typing import EpisodeType, ModuleID +from ray.util.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + +#: This is the default schema used if no `input_read_schema` is set in +#: the config. If a user passes in a schema into `input_read_schema` +#: this user-defined schema has to comply with the keys of `SCHEMA`, +#: while values correspond to the columns in the user's dataset. Note +#: that only the user-defined values will be overridden while all +#: other values from SCHEMA remain as defined here. +SCHEMA = { + Columns.EPS_ID: Columns.EPS_ID, + Columns.AGENT_ID: Columns.AGENT_ID, + Columns.MODULE_ID: Columns.MODULE_ID, + Columns.OBS: Columns.OBS, + Columns.ACTIONS: Columns.ACTIONS, + Columns.REWARDS: Columns.REWARDS, + Columns.INFOS: Columns.INFOS, + Columns.NEXT_OBS: Columns.NEXT_OBS, + Columns.TERMINATEDS: Columns.TERMINATEDS, + Columns.TRUNCATEDS: Columns.TRUNCATEDS, + Columns.T: Columns.T, + # TODO (simon): Add remove as soon as we are new stack only. + "agent_index": "agent_index", + "dones": "dones", + "unroll_id": "unroll_id", +} + +logger = logging.getLogger(__name__) + + +@PublicAPI(stability="alpha") +class OfflinePreLearner: + """Class that coordinates data transformation from dataset to learner. + + This class is an essential part of the new `Offline RL API` of `RLlib`. + It is a callable class that is run in `ray.data.Dataset.map_batches` + when iterating over batches for training. It's basic function is to + convert data in batch from rows to episodes (`SingleAGentEpisode`s + for now) and to then run the learner connector pipeline to convert + further to trainable batches. These batches are used directly in the + `Learner`'s `update` method. + + The main reason to run these transformations inside of `map_batches` + is for better performance. 
Batches can be pre-fetched in `ray.data` + and therefore batch trransformation can be run highly parallelized to + the `Learner''s `update`. + + This class can be overridden to implement custom logic for transforming + batches and make them 'Learner'-ready. When deriving from this class + the `__call__` method and `_map_to_episodes` can be overridden to induce + custom logic for the complete transformation pipeline (`__call__`) or + for converting to episodes only ('_map_to_episodes`). For an example + how this class can be used to also compute values and advantages see + `rllib.algorithm.marwil.marwil_prelearner.MAWRILOfflinePreLearner`. + + Custom `OfflinePreLearner` classes can be passed into + `AlgorithmConfig.offline`'s `prelearner_class`. The `OfflineData` class + will then use the custom class in its data pipeline. + """ + + @OverrideToImplementCustomLogic_CallToSuperRecommended + def __init__( + self, + *, + config: "AlgorithmConfig", + learner: Union[Learner, list[ActorHandle]], + spaces: Optional[Tuple[gym.Space, gym.Space]] = None, + module_spec: Optional[MultiRLModuleSpec] = None, + module_state: Optional[Dict[ModuleID, Any]] = None, + **kwargs: Dict[str, Any], + ): + + self.config = config + self.input_read_episodes = self.config.input_read_episodes + self.input_read_sample_batches = self.config.input_read_sample_batches + # We need this learner to run the learner connector pipeline. + # If it is a `Learner` instance, the `Learner` is local. + if isinstance(learner, Learner): + self._learner = learner + self.learner_is_remote = False + self._module = self._learner._module + # Otherwise we have remote `Learner`s. + else: + self.learner_is_remote = True + # Build the module from spec. Note, this will be a MultiRLModule. + self._module = module_spec.build() + self._module.set_state(module_state) + + # Store the observation and action space if defined, otherwise we + # set them to `None`. 
Note, if `None` the `convert_from_jsonable` + # will not convert the input space samples. + self.observation_space, self.action_space = spaces or (None, None) + + # Build the learner connector pipeline. + self._learner_connector = self.config.build_learner_connector( + input_observation_space=self.observation_space, + input_action_space=self.action_space, + ) + # Cache the policies to be trained to update weights only for these. + self._policies_to_train = self.config.policies_to_train + self._is_multi_agent = config.is_multi_agent + # Set the counter to zero. + self.iter_since_last_module_update = 0 + # self._future = None + + # Set up an episode buffer, if the module is stateful or we sample from + # `SampleBatch` types. + if ( + self.input_read_sample_batches + or self._module.is_stateful() + or self.input_read_episodes + ): + # Either the user defined a buffer class or we fall back to the default. + prelearner_buffer_class = ( + self.config.prelearner_buffer_class + or self.default_prelearner_buffer_class + ) + prelearner_buffer_kwargs = ( + self.default_prelearner_buffer_kwargs + | self.config.prelearner_buffer_kwargs + ) + # Initialize the buffer. + self.episode_buffer = prelearner_buffer_class( + **prelearner_buffer_kwargs, + ) + + @OverrideToImplementCustomLogic + def __call__(self, batch: Dict[str, np.ndarray]) -> Dict[str, List[EpisodeType]]: + """Prepares plain data batches for training with `Learner`'s. + + Args: + batch: A dictionary of numpy arrays containing either column data + with `self.config.input_read_schema`, `EpisodeType` data, or + `BatchType` data. + + Returns: + A `MultiAgentBatch` that can be passed to `Learner.update` methods. + """ + # If we directly read in episodes we just convert to list. + if self.input_read_episodes: + # Import `msgpack` for decoding. + import msgpack + import msgpack_numpy as mnp + + # Read the episodes and decode them. 
+ episodes = [ + SingleAgentEpisode.from_state( + msgpack.unpackb(state, object_hook=mnp.decode) + ) + for state in batch["item"] + ] + # Ensure that all episodes are done and no duplicates are in the batch. + episodes = self._validate_episodes(episodes) + # Add the episodes to the buffer. + self.episode_buffer.add(episodes) + # TODO (simon): Refactor into a single code block for both cases. + episodes = self.episode_buffer.sample( + num_items=self.config.train_batch_size_per_learner, + batch_length_T=self.config.model_config.get("max_seq_len", 0) + if self._module.is_stateful() + else None, + n_step=self.config.get("n_step", 1) or 1, + # TODO (simon): This can be removed as soon as DreamerV3 has been + # cleaned up, i.e. can use episode samples for training. + sample_episodes=True, + to_numpy=True, + ) + # Else, if we have old stack `SampleBatch`es. + elif self.input_read_sample_batches: + episodes = OfflinePreLearner._map_sample_batch_to_episode( + self._is_multi_agent, + batch, + to_numpy=True, + schema=SCHEMA | self.config.input_read_schema, + input_compress_columns=self.config.input_compress_columns, + )["episodes"] + # Ensure that all episodes are done and no duplicates are in the batch. + episodes = self._validate_episodes(episodes) + # Add the episodes to the buffer. + self.episode_buffer.add(episodes) + # Sample steps from the buffer. + episodes = self.episode_buffer.sample( + num_items=self.config.train_batch_size_per_learner, + batch_length_T=self.config.model_config.get("max_seq_len", 0) + if self._module.is_stateful() + else None, + n_step=self.config.get("n_step", 1) or 1, + # TODO (simon): This can be removed as soon as DreamerV3 has been + # cleaned up, i.e. can use episode samples for training. + sample_episodes=True, + to_numpy=True, + ) + # Otherwise we map the batch to episodes. 
+ else: + episodes = self._map_to_episodes( + self._is_multi_agent, + batch, + schema=SCHEMA | self.config.input_read_schema, + to_numpy=True, + input_compress_columns=self.config.input_compress_columns, + observation_space=self.observation_space, + action_space=self.action_space, + )["episodes"] + + # TODO (simon): Make synching work. Right now this becomes blocking or never + # receives weights. Learners appear to be non accessi ble via other actors. + # Increase the counter for updating the module. + # self.iter_since_last_module_update += 1 + + # if self._future: + # refs, _ = ray.wait([self._future], timeout=0) + # print(f"refs: {refs}") + # if refs: + # module_state = ray.get(self._future) + # + # self._module.set_state(module_state) + # self._future = None + + # # Synch the learner module, if necessary. Note, in case of a local learner + # # we have a reference to the module and therefore an up-to-date module. + # if self.learner_is_remote and self.iter_since_last_module_update + # > self.config.prelearner_module_synch_period: + # # Reset the iteration counter. + # self.iter_since_last_module_update = 0 + # # Request the module weights from the remote learner. + # self._future = + # self._learner.get_module_state.remote(inference_only=False) + # # module_state = + # ray.get(self._learner.get_module_state.remote(inference_only=False)) + # # self._module.set_state(module_state) + + # Run the `Learner`'s connector pipeline. + batch = self._learner_connector( + rl_module=self._module, + batch={}, + episodes=episodes, + shared_data={}, + # TODO (sven): Add MetricsLogger to non-Learner components that have a + # LearnerConnector pipeline. + metrics=None, + ) + # Convert to `MultiAgentBatch`. + batch = MultiAgentBatch( + { + module_id: SampleBatch(module_data) + for module_id, module_data in batch.items() + }, + # TODO (simon): This can be run once for the batch and the + # metrics, but we run it twice: here and later in the learner. 
+ env_steps=sum(e.env_steps() for e in episodes), + ) + # Remove all data from modules that should not be trained. We do + # not want to pass around more data than necessaty. + for module_id in list(batch.policy_batches.keys()): + if not self._should_module_be_updated(module_id, batch): + del batch.policy_batches[module_id] + + # TODO (simon): Log steps trained for metrics (how?). At best in learner + # and not here. But we could precompute metrics here and pass it to the learner + # for logging. Like this we do not have to pass around episode lists. + + # TODO (simon): episodes are only needed for logging here. + return {"batch": [batch]} + + @property + def default_prelearner_buffer_class(self) -> ReplayBuffer: + """Sets the default replay buffer.""" + from ray.rllib.utils.replay_buffers.episode_replay_buffer import ( + EpisodeReplayBuffer, + ) + + # Return the buffer. + return EpisodeReplayBuffer + + @property + def default_prelearner_buffer_kwargs(self) -> Dict[str, Any]: + """Sets the default arguments for the replay buffer. + + Note, the `capacity` might vary with the size of the episodes or + sample batches in the offline dataset. + """ + return { + "capacity": self.config.train_batch_size_per_learner * 10, + "batch_size_B": self.config.train_batch_size_per_learner, + } + + def _validate_episodes( + self, episodes: List[SingleAgentEpisode] + ) -> Set[SingleAgentEpisode]: + """Validate episodes sampled from the dataset. + + Note, our episode buffers cannot handle either duplicates nor + non-ordered fragmentations, i.e. fragments from episodes that do + not arrive in timestep order. + + Args: + episodes: A list of `SingleAgentEpisode` instances sampled + from a dataset. + + Returns: + A set of `SingleAgentEpisode` instances. + + Raises: + ValueError: If not all episodes are `done`. + """ + # Ensure that episodes are all done. 
+ if not all(eps.is_done for eps in episodes): + raise ValueError( + "When sampling from episodes (`input_read_episodes=True`) all " + "recorded episodes must be done (i.e. either `terminated=True`) " + "or `truncated=True`)." + ) + # Ensure that episodes do not contain duplicates. Note, this can happen + # if the dataset is small and pulled batches contain multiple episodes. + unique_episode_ids = set() + episodes = { + eps + for eps in episodes + if eps.id_ not in unique_episode_ids + and not unique_episode_ids.add(eps.id_) + and eps.id_ not in self.episode_buffer.episode_id_to_index.keys() + } + return episodes + + def _should_module_be_updated(self, module_id, multi_agent_batch=None) -> bool: + """Checks which modules in a MultiRLModule should be updated.""" + if not self._policies_to_train: + # In case of no update information, the module is updated. + return True + elif not callable(self._policies_to_train): + return module_id in set(self._policies_to_train) + else: + return self._policies_to_train(module_id, multi_agent_batch) + + @OverrideToImplementCustomLogic + @staticmethod + def _map_to_episodes( + is_multi_agent: bool, + batch: Dict[str, Union[list, np.ndarray]], + schema: Dict[str, str] = SCHEMA, + to_numpy: bool = False, + input_compress_columns: Optional[List[str]] = None, + observation_space: gym.Space = None, + action_space: gym.Space = None, + **kwargs: Dict[str, Any], + ) -> Dict[str, List[EpisodeType]]: + """Maps a batch of data to episodes.""" + + # Set to empty list, if `None`. + input_compress_columns = input_compress_columns or [] + + # If spaces are given, we can use the space-specific + # conversion method to convert space samples. + if observation_space and action_space: + convert = from_jsonable_if_needed + # Otherwise we use an identity function. + else: + + def convert(sample, space): + return sample + + episodes = [] + for i, obs in enumerate(batch[schema[Columns.OBS]]): + + # If multi-agent we need to extract the agent ID. 
+ # TODO (simon): Check, what happens with the module ID. + if is_multi_agent: + agent_id = ( + batch[schema[Columns.AGENT_ID]][i] + if Columns.AGENT_ID in batch + # The old stack uses "agent_index" instead of "agent_id". + # TODO (simon): Remove this as soon as we are new stack only. + else ( + batch[schema["agent_index"]][i] + if schema["agent_index"] in batch + else None + ) + ) + else: + agent_id = None + + if is_multi_agent: + # TODO (simon): Add support for multi-agent episodes. + NotImplementedError + else: + # Build a single-agent episode with a single row of the batch. + episode = SingleAgentEpisode( + id_=str(batch[schema[Columns.EPS_ID]][i]), + agent_id=agent_id, + # Observations might be (a) serialized and/or (b) converted + # to a JSONable (when a composite space was used). We unserialize + # and then reconvert from JSONable to space sample. + observations=[ + convert(unpack_if_needed(obs), observation_space) + if Columns.OBS in input_compress_columns + else convert(obs, observation_space), + convert( + unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i]), + observation_space, + ) + if Columns.OBS in input_compress_columns + else convert( + batch[schema[Columns.NEXT_OBS]][i], observation_space + ), + ], + infos=[ + {}, + batch[schema[Columns.INFOS]][i] + if schema[Columns.INFOS] in batch + else {}, + ], + # Actions might be (a) serialized and/or (b) converted to a JSONable + # (when a composite space was used). We unserializer and then + # reconvert from JSONable to space sample. 
+ actions=[ + convert( + unpack_if_needed(batch[schema[Columns.ACTIONS]][i]), + action_space, + ) + if Columns.ACTIONS in input_compress_columns + else convert(batch[schema[Columns.ACTIONS]][i], action_space) + ], + rewards=[batch[schema[Columns.REWARDS]][i]], + terminated=batch[ + schema[Columns.TERMINATEDS] + if schema[Columns.TERMINATEDS] in batch + else "dones" + ][i], + truncated=batch[schema[Columns.TRUNCATEDS]][i] + if schema[Columns.TRUNCATEDS] in batch + else False, + # TODO (simon): Results in zero-length episodes in connector. + # t_started=batch[Columns.T if Columns.T in batch else + # "unroll_id"][i][0], + # TODO (simon): Single-dimensional columns are not supported. + # Extra model outputs might be serialized. We unserialize them here + # if needed. + # TODO (simon): Check, if we need here also reconversion from + # JSONable in case of composite spaces. + extra_model_outputs={ + k: [ + unpack_if_needed(v[i]) + if k in input_compress_columns + else v[i] + ] + for k, v in batch.items() + if ( + k not in schema + and k not in schema.values() + and k not in ["dones", "agent_index", "type"] + ) + }, + len_lookback_buffer=0, + ) + + if to_numpy: + episode.to_numpy() + episodes.append(episode) + # Note, `map_batches` expects a `Dict` as return value. + return {"episodes": episodes} + + @OverrideToImplementCustomLogic + @staticmethod + def _map_sample_batch_to_episode( + is_multi_agent: bool, + batch: Dict[str, Union[list, np.ndarray]], + schema: Dict[str, str] = SCHEMA, + to_numpy: bool = False, + input_compress_columns: Optional[List[str]] = None, + ) -> Dict[str, List[EpisodeType]]: + """Maps an old stack `SampleBatch` to new stack episodes.""" + + # Set `input_compress_columns` to an empty `list` if `None`. + input_compress_columns = input_compress_columns or [] + + # TODO (simon): CHeck, if needed. It could possibly happen that a batch contains + # data from different episodes. Merging and resplitting the batch would then + # be the solution. 
+ # Check, if batch comes actually from multiple episodes. + # episode_begin_indices = np.where(np.diff(np.hstack(batch["eps_id"])) != 0) + 1 + + # Define a container to collect episodes. + episodes = [] + # Loop over `SampleBatch`es in the `ray.data` batch (a dict). + for i, obs in enumerate(batch[schema[Columns.OBS]]): + + # If multi-agent we need to extract the agent ID. + # TODO (simon): Check, what happens with the module ID. + if is_multi_agent: + agent_id = ( + # The old stack uses "agent_index" instead of "agent_id". + batch[schema["agent_index"]][i][0] + if schema["agent_index"] in batch + else None + ) + else: + agent_id = None + + if is_multi_agent: + # TODO (simon): Add support for multi-agent episodes. + NotImplementedError + else: + # Unpack observations, if needed. Note, observations could + # be either compressed by their entirety (the complete batch + # column) or individually (each column entry). + if isinstance(obs, str): + # Decompress the observations if we have a string, i.e. + # observations are compressed in their entirety. + obs = unpack_if_needed(obs) + # Convert to a list of arrays. This is needed as input by + # the `SingleAgentEpisode`. + obs = [obs[i, ...] for i in range(obs.shape[0])] + # Otherwise observations are only compressed inside of the + # batch column (if at all). + elif isinstance(obs, np.ndarray): + # Unpack observations, if they are compressed otherwise we + # simply convert to a list, which is needed by the + # `SingleAgentEpisode`. + obs = ( + unpack_if_needed(obs.tolist()) + if schema[Columns.OBS] in input_compress_columns + else obs.tolist() + ) + else: + raise TypeError( + f"Unknown observation type: {type(obs)}. When mapping " + "from old recorded `SampleBatches` batched " + "observations should be either of type `np.array` " + "or - if the column is compressed - of `str` type." + ) + + if schema[Columns.NEXT_OBS] in batch: + # Append the last `new_obs` to get the correct length of + # observations. 
+ obs.append( + unpack_if_needed(batch[schema[Columns.NEXT_OBS]][i][-1]) + if schema[Columns.OBS] in input_compress_columns + else batch[schema[Columns.NEXT_OBS]][i][-1] + ) + else: + # Otherwise we duplicate the last observation. + obs.append(obs[-1]) + + # Check, if we have `done`, `truncated`, or `terminated`s in + # the batch. + if ( + schema[Columns.TRUNCATEDS] in batch + and schema[Columns.TERMINATEDS] in batch + ): + truncated = batch[schema[Columns.TRUNCATEDS]][i][-1] + terminated = batch[schema[Columns.TERMINATEDS]][i][-1] + elif ( + schema[Columns.TRUNCATEDS] in batch + and schema[Columns.TERMINATEDS] not in batch + ): + truncated = batch[schema[Columns.TRUNCATEDS]][i][-1] + terminated = False + elif ( + schema[Columns.TRUNCATEDS] not in batch + and schema[Columns.TERMINATEDS] in batch + ): + terminated = batch[schema[Columns.TERMINATEDS]][i][-1] + truncated = False + elif "done" in batch: + terminated = batch["done"][i][-1] + truncated = False + # Otherwise, if no `terminated`, nor `truncated` nor `done` + # is given, we consider the episode as terminated. + else: + terminated = True + truncated = False + + # Create a `SingleAgentEpisode`. + episode = SingleAgentEpisode( + # If the recorded episode has an ID we use this ID, + # otherwise we generate a new one. + id_=str(batch[schema[Columns.EPS_ID]][i][0]) + if schema[Columns.EPS_ID] in batch + else uuid.uuid4().hex, + agent_id=agent_id, + observations=obs, + infos=( + batch[schema[Columns.INFOS]][i] + if schema[Columns.INFOS] in batch + else [{}] * len(obs) + ), + # Actions might be (a) serialized. We unserialize them here. + actions=( + unpack_if_needed(batch[schema[Columns.ACTIONS]][i]) + if Columns.ACTIONS in input_compress_columns + else batch[schema[Columns.ACTIONS]][i] + ), + rewards=batch[schema[Columns.REWARDS]][i], + terminated=terminated, + truncated=truncated, + # TODO (simon): Results in zero-length episodes in connector. 
+ # t_started=batch[Columns.T if Columns.T in batch else + # "unroll_id"][i][0], + # TODO (simon): Single-dimensional columns are not supported. + # Extra model outputs might be serialized. We unserialize them here + # if needed. + # TODO (simon): Check, if we need here also reconversion from + # JSONable in case of composite spaces. + extra_model_outputs={ + k: unpack_if_needed(v[i]) + if k in input_compress_columns + else v[i] + for k, v in batch.items() + if ( + k not in schema + and k not in schema.values() + and k not in ["dones", "agent_index", "type"] + ) + }, + len_lookback_buffer=0, + ) + # Numpy'ized, if necessary. + # TODO (simon, sven): Check, if we should convert all data to lists + # before. Right now only obs are lists. + if to_numpy: + episode.to_numpy() + episodes.append(episode) + # Note, `map_batches` expects a `Dict` as return value. + return {"episodes": episodes} diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/output_writer.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/output_writer.py new file mode 100644 index 0000000000000000000000000000000000000000..ca26c5a538face19f1b552fa843ae558675093f8 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/output_writer.py @@ -0,0 +1,26 @@ +from ray.rllib.utils.annotations import override, PublicAPI +from ray.rllib.utils.typing import SampleBatchType + + +@PublicAPI +class OutputWriter: + """Writer API for saving experiences from policy evaluation.""" + + @PublicAPI + def write(self, sample_batch: SampleBatchType): + """Saves a batch of experiences. + + Args: + sample_batch: SampleBatch or MultiAgentBatch to save. + """ + raise NotImplementedError + + +@PublicAPI +class NoopOutput(OutputWriter): + """Output writer that discards its outputs.""" + + @override(OutputWriter) + def write(self, sample_batch: SampleBatchType): + # Do nothing. 
+ pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/resource.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/resource.py new file mode 100644 index 0000000000000000000000000000000000000000..e658b9b682bc77110f2c6a23061c1f0e1ef4ca05 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/resource.py @@ -0,0 +1,30 @@ +from typing import Dict, List, TYPE_CHECKING +from ray.rllib.utils.annotations import PublicAPI + +if TYPE_CHECKING: + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + +DEFAULT_NUM_CPUS_PER_TASK = 0.5 + + +@PublicAPI +def get_offline_io_resource_bundles( + config: "AlgorithmConfig", +) -> List[Dict[str, float]]: + # DatasetReader is the only offline I/O component today that + # requires compute resources. + if config.input_ == "dataset": + input_config = config.input_config + # TODO (Kourosh): parallelism is use for reading the dataset, which defaults to + # num_workers. This logic here relies on the information that dataset reader + # will have the same logic. So to remove the information leakage, inside + # Algorithm config, we should set parallelism to num_workers if not specified + # and only deal with parallelism here or in dataset_reader.py. same thing is + # true with cpus_per_task. 
+ parallelism = input_config.get("parallelism", config.get("num_env_runners", 1)) + cpus_per_task = input_config.get( + "num_cpus_per_read_task", DEFAULT_NUM_CPUS_PER_TASK + ) + return [{"CPU": cpus_per_task} for _ in range(parallelism)] + else: + return [] diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/shuffled_input.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/shuffled_input.py new file mode 100644 index 0000000000000000000000000000000000000000..a7c2610185940b7df8b1da07a1b421fb671b5514 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/shuffled_input.py @@ -0,0 +1,42 @@ +import logging +import random + +from ray.rllib.offline.input_reader import InputReader +from ray.rllib.utils.annotations import override, DeveloperAPI +from ray.rllib.utils.typing import SampleBatchType + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class ShuffledInput(InputReader): + """Randomizes data over a sliding window buffer of N batches. + + This increases the randomization of the data, which is useful if the + batches were not in random order to start with. + """ + + @DeveloperAPI + def __init__(self, child: InputReader, n: int = 0): + """Initializes a ShuffledInput instance. + + Args: + child: child input reader to shuffle. + n: If positive, shuffle input over this many batches. 
+ """ + self.n = n + self.child = child + self.buffer = [] + + @override(InputReader) + def next(self) -> SampleBatchType: + if self.n <= 1: + return self.child.next() + if len(self.buffer) < self.n: + logger.info("Filling shuffle buffer to {} batches".format(self.n)) + while len(self.buffer) < self.n: + self.buffer.append(self.child.next()) + logger.info("Shuffle buffer filled") + i = random.randint(0, len(self.buffer) - 1) + self.buffer[i] = self.child.next() + return random.choice(self.buffer) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/offline/wis_estimator.py b/.venv/lib/python3.11/site-packages/ray/rllib/offline/wis_estimator.py new file mode 100644 index 0000000000000000000000000000000000000000..128b50e24b2ace880299458885a5139736f134bc --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/offline/wis_estimator.py @@ -0,0 +1,13 @@ +from ray.rllib.offline.estimators.weighted_importance_sampling import ( + WeightedImportanceSampling, +) +from ray.rllib.utils.deprecation import Deprecated + + +@Deprecated( + new="ray.rllib.offline.estimators.weighted_importance_sampling::" + "WeightedImportanceSampling", + error=True, +) +class WeightedImportanceSamplingEstimator(WeightedImportanceSampling): + pass diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/actor_manager.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/actor_manager.py new file mode 100644 index 0000000000000000000000000000000000000000..1dc1401fed185bd967cc6b10a898131a30a1e2ee --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/actor_manager.py @@ -0,0 +1,927 @@ +from collections import defaultdict +import copy +from dataclasses import dataclass +import logging +import sys +import time +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, Union + +import ray +from ray.actor import ActorHandle +from ray.exceptions import RayError, RayTaskError +from ray.rllib.utils.typing import T +from ray.util.annotations import 
DeveloperAPI + + +logger = logging.getLogger(__name__) + + +@DeveloperAPI +class ResultOrError: + """A wrapper around a result or a RayError thrown during remote task/actor calls. + + This is used to return data from `FaultTolerantActorManager` that allows us to + distinguish between RayErrors (remote actor related) and valid results. + """ + + def __init__(self, result: Any = None, error: Exception = None): + """One and only one of result or error should be set. + + Args: + result: The result of the computation. Note that None is a valid result if + the remote function does not return anything. + error: Alternatively, the error that occurred during the computation. + """ + self._result = result + self._error = ( + # Easier to handle if we show the user the original error. + error.as_instanceof_cause() + if isinstance(error, RayTaskError) + else error + ) + + @property + def ok(self): + return self._error is None + + def get(self): + """Returns the result or the error.""" + if self._error: + return self._error + else: + return self._result + + +@DeveloperAPI +@dataclass +class CallResult: + """Represents a single result from a call to an actor. + + Each CallResult contains the index of the actor that was called + plus the result or error from the call. + """ + + actor_id: int + result_or_error: ResultOrError + tag: str + + @property + def ok(self): + """Passes through the ok property from the result_or_error.""" + return self.result_or_error.ok + + def get(self): + """Passes through the get method from the result_or_error.""" + return self.result_or_error.get() + + +@DeveloperAPI +class RemoteCallResults: + """Represents a list of results from calls to a set of actors. + + CallResults provides convenient APIs to iterate over the results + while skipping errors, etc. + + .. 
testcode:: + :skipif: True + + manager = FaultTolerantActorManager( + actors, max_remote_requests_in_flight_per_actor=2, + ) + results = manager.foreach_actor(lambda w: w.call()) + + # Iterate over all results ignoring errors. + for result in results.ignore_errors(): + print(result.get()) + """ + + class _Iterator: + """An iterator over the results of a remote call.""" + + def __init__(self, call_results: List[CallResult]): + self._call_results = call_results + + def __iter__(self) -> Iterator[CallResult]: + return self + + def __next__(self) -> CallResult: + if not self._call_results: + raise StopIteration + return self._call_results.pop(0) + + def __init__(self): + self.result_or_errors: List[CallResult] = [] + + def add_result(self, actor_id: int, result_or_error: ResultOrError, tag: str): + """Add index of a remote actor plus the call result to the list. + + Args: + actor_id: ID of the remote actor. + result_or_error: The result or error from the call. + tag: A description to identify the call. + """ + self.result_or_errors.append(CallResult(actor_id, result_or_error, tag)) + + def __iter__(self) -> Iterator[ResultOrError]: + """Return an iterator over the results.""" + # Shallow copy the list. + return self._Iterator(copy.copy(self.result_or_errors)) + + def __len__(self) -> int: + return len(self.result_or_errors) + + def ignore_errors(self) -> Iterator[ResultOrError]: + """Return an iterator over the results, skipping all errors.""" + return self._Iterator([r for r in self.result_or_errors if r.ok]) + + def ignore_ray_errors(self) -> Iterator[ResultOrError]: + """Return an iterator over the results, skipping only Ray errors. + + Similar to ignore_errors, but only skips Errors raised because of + remote actor problems (often get restored automatcially). + This is useful for callers that want to handle application errors differently + from Ray errors. 
+ """ + return self._Iterator( + [r for r in self.result_or_errors if not isinstance(r.get(), RayError)] + ) + + +@DeveloperAPI +class FaultAwareApply: + @DeveloperAPI + def ping(self) -> str: + """Ping the actor. Can be used as a health check. + + Returns: + "pong" if actor is up and well. + """ + return "pong" + + @DeveloperAPI + def apply( + self, + func: Callable[[Any, Optional[Any], Optional[Any]], T], + *args, + **kwargs, + ) -> T: + """Calls the given function with this Actor instance. + + A generic interface for applying arbitrary member functions on a + remote actor. + + Args: + func: The function to call, with this actor as first + argument, followed by args, and kwargs. + args: Optional additional args to pass to the function call. + kwargs: Optional additional kwargs to pass to the function call. + + Returns: + The return value of the function call. + """ + try: + return func(self, *args, **kwargs) + except Exception as e: + # Actor should be recreated by Ray. + if self.config.restart_failed_env_runners: + logger.exception(f"Worker exception caught during `apply()`: {e}") + # Small delay to allow logs messages to propagate. + time.sleep(self.config.delay_between_env_runner_restarts_s) + # Kill this worker so Ray Core can restart it. + sys.exit(1) + # Actor should be left dead. + else: + raise e + + +@DeveloperAPI +class FaultTolerantActorManager: + """A manager that is aware of the healthiness of remote actors. + + .. testcode:: + + import time + import ray + from ray.rllib.utils.actor_manager import FaultTolerantActorManager + + @ray.remote + class MyActor: + def apply(self, fn): + return fn(self) + + def do_something(self): + return True + + actors = [MyActor.remote() for _ in range(3)] + manager = FaultTolerantActorManager( + actors, max_remote_requests_in_flight_per_actor=2, + ) + + # Synchronous remote calls. + results = manager.foreach_actor(lambda actor: actor.do_something()) + # Print results ignoring returned errors. 
+ print([r.get() for r in results.ignore_errors()]) + + # Asynchronous remote calls. + manager.foreach_actor_async(lambda actor: actor.do_something()) + time.sleep(2) # Wait for the tasks to finish. + for r in manager.fetch_ready_async_reqs(): + # Handle result and errors. + if r.ok: + print(r.get()) + else: + print("Error: {}".format(r.get())) + """ + + @dataclass + class _ActorState: + """State of a single actor.""" + + # Num of outstanding async requests for this actor. + num_in_flight_async_requests: int = 0 + # Whether this actor is in a healthy state. + is_healthy: bool = True + + def __init__( + self, + actors: Optional[List[ActorHandle]] = None, + max_remote_requests_in_flight_per_actor: int = 2, + init_id: int = 0, + ): + """Construct a FaultTolerantActorManager. + + Args: + actors: A list of ray remote actors to manage on. These actors must have an + ``apply`` method which takes a function with only one parameter (the + actor instance itself). + max_remote_requests_in_flight_per_actor: The maximum number of remote + requests that can be in flight per actor. Any requests made to the pool + that cannot be scheduled because the limit has been reached will be + dropped. This only applies to the asynchronous remote call mode. + init_id: The initial ID to use for the next remote actor. Default is 0. + """ + # For historic reasons, just start remote worker ID from 1, so they never + # collide with local worker ID (0). + self._next_id = init_id + + # Actors are stored in a map and indexed by a unique (int) ID. + self._actors: Dict[int, ActorHandle] = {} + self._remote_actor_states: Dict[int, self._ActorState] = {} + self._restored_actors = set() + self.add_actors(actors or []) + + # For round-robin style async requests, keep track of which actor to send + # a new func next. + self._current_actor_id = self._next_id + + # Maps outstanding async requests to the IDs of the actor IDs that + # are executing them. 
+ self._in_flight_req_to_actor_id: Dict[ray.ObjectRef, int] = {} + + self._max_remote_requests_in_flight_per_actor = ( + max_remote_requests_in_flight_per_actor + ) + + # Useful metric. + self._num_actor_restarts = 0 + + @DeveloperAPI + def actor_ids(self) -> List[int]: + """Returns a list of all worker IDs (healthy or not).""" + return list(self._actors.keys()) + + @DeveloperAPI + def healthy_actor_ids(self) -> List[int]: + """Returns a list of worker IDs that are healthy.""" + return [k for k, v in self._remote_actor_states.items() if v.is_healthy] + + @DeveloperAPI + def add_actors(self, actors: List[ActorHandle]): + """Add a list of actors to the pool. + + Args: + actors: A list of ray remote actors to be added to the pool. + """ + for actor in actors: + self._actors[self._next_id] = actor + self._remote_actor_states[self._next_id] = self._ActorState() + self._next_id += 1 + + @DeveloperAPI + def remove_actor(self, actor_id: int) -> ActorHandle: + """Remove an actor from the pool. + + Args: + actor_id: ID of the actor to remove. + + Returns: + Handle to the actor that was removed. + """ + actor = self._actors[actor_id] + + # Remove the actor from the pool. 
+ del self._actors[actor_id] + del self._remote_actor_states[actor_id] + self._restored_actors.discard(actor_id) + self._remove_async_state(actor_id) + + return actor + + @DeveloperAPI + def num_actors(self) -> int: + """Return the total number of actors in the pool.""" + return len(self._actors) + + @DeveloperAPI + def num_healthy_actors(self) -> int: + """Return the number of healthy remote actors.""" + return sum(s.is_healthy for s in self._remote_actor_states.values()) + + @DeveloperAPI + def total_num_restarts(self) -> int: + """Return the number of remote actors that have been restarted.""" + return self._num_actor_restarts + + @DeveloperAPI + def num_outstanding_async_reqs(self) -> int: + """Return the number of outstanding async requests.""" + return len(self._in_flight_req_to_actor_id) + + @DeveloperAPI + def is_actor_healthy(self, actor_id: int) -> bool: + """Whether a remote actor is in healthy state. + + Args: + actor_id: ID of the remote actor. + + Returns: + True if the actor is healthy, False otherwise. + """ + if actor_id not in self._remote_actor_states: + raise ValueError(f"Unknown actor id: {actor_id}") + return self._remote_actor_states[actor_id].is_healthy + + @DeveloperAPI + def set_actor_state(self, actor_id: int, healthy: bool) -> None: + """Update activate state for a specific remote actor. + + Args: + actor_id: ID of the remote actor. + healthy: Whether the remote actor is healthy. + """ + if actor_id not in self._remote_actor_states: + raise ValueError(f"Unknown actor id: {actor_id}") + + was_healthy = self._remote_actor_states[actor_id].is_healthy + # Set from unhealthy to healthy -> Add to restored set. + if not was_healthy and healthy: + self._restored_actors.add(actor_id) + # Set from healthy to unhealthy -> Remove from restored set. + elif was_healthy and not healthy: + self._restored_actors.discard(actor_id) + + self._remote_actor_states[actor_id].is_healthy = healthy + + if not healthy: + # Remove any async states. 
+ self._remove_async_state(actor_id) + + @DeveloperAPI + def clear(self): + """Clean up managed actors.""" + for actor in self._actors.values(): + ray.kill(actor) + self._actors.clear() + self._remote_actor_states.clear() + self._restored_actors.clear() + self._in_flight_req_to_actor_id.clear() + + @DeveloperAPI + def foreach_actor( + self, + func: Union[Callable[[Any], Any], List[Callable[[Any], Any]]], + *, + healthy_only: bool = True, + remote_actor_ids: Optional[List[int]] = None, + timeout_seconds: Optional[float] = None, + return_obj_refs: bool = False, + mark_healthy: bool = False, + ) -> RemoteCallResults: + """Calls the given function with each actor instance as arg. + + Automatically marks actors unhealthy if they crash during the remote call. + + Args: + func: A single, or a list of Callables, that get applied on the list + of specified remote actors. + healthy_only: If True, applies `func` only to actors currently tagged + "healthy", otherwise to all actors. If `healthy_only=False` and + `mark_healthy=True`, will send `func` to all actors and mark those + actors "healthy" that respond to the request within `timeout_seconds` + and are currently tagged as "unhealthy". + remote_actor_ids: Apply func on a selected set of remote actors. Use None + (default) for all actors. + timeout_seconds: Time to wait (in seconds) for results. Set this to 0.0 for + fire-and-forget. Set this to None (default) to wait infinitely (i.e. for + synchronous execution). + return_obj_refs: whether to return ObjectRef instead of actual results. + Note, for fault tolerance reasons, these returned ObjectRefs should + never be resolved with ray.get() outside of the context of this manager. + mark_healthy: Whether to mark all those actors healthy again that are + currently marked unhealthy AND that returned results from the remote + call (within the given `timeout_seconds`). + Note that actors are NOT set unhealthy, if they simply time out + (only if they return a RayActorError). 
+ Also not that this setting is ignored if `healthy_only=True` (b/c this + setting only affects actors that are currently tagged as unhealthy). + + Returns: + The list of return values of all calls to `func(actor)`. The values may be + actual data returned or exceptions raised during the remote call in the + format of RemoteCallResults. + """ + remote_actor_ids = remote_actor_ids or self.actor_ids() + if healthy_only: + func, remote_actor_ids = self._filter_func_and_remote_actor_id_by_state( + func, remote_actor_ids + ) + + # Send out remote requests. + remote_calls = self._call_actors( + func=func, + remote_actor_ids=remote_actor_ids, + ) + + # Collect remote request results (if available given timeout and/or errors). + _, remote_results = self._fetch_result( + remote_actor_ids=remote_actor_ids, + remote_calls=remote_calls, + tags=[None] * len(remote_calls), + timeout_seconds=timeout_seconds, + return_obj_refs=return_obj_refs, + mark_healthy=mark_healthy, + ) + + return remote_results + + @DeveloperAPI + def foreach_actor_async( + self, + func: Union[Callable[[Any], Any], List[Callable[[Any], Any]]], + tag: str = None, + *, + healthy_only: bool = True, + remote_actor_ids: List[int] = None, + ) -> int: + """Calls given functions against each actors without waiting for results. + + Args: + func: A single Callable applied to all specified remote actors or a list + of Callables, that get applied on the list of specified remote actors. + In the latter case, both list of Callables and list of specified actors + must have the same length. + tag: A tag to identify the results from this async call. + healthy_only: If True, applies `func` only to actors currently tagged + "healthy", otherwise to all actors. If `healthy_only=False` and + later, `self.fetch_ready_async_reqs()` is called with + `mark_healthy=True`, will send `func` to all actors and mark those + actors "healthy" that respond to the request within `timeout_seconds` + and are currently tagged as "unhealthy". 
+ remote_actor_ids: Apply func on a selected set of remote actors. + Note, for fault tolerance reasons, these returned ObjectRefs should + never be resolved with ray.get() outside of the context of this manager. + + Returns: + The number of async requests that are actually fired. + """ + # TODO(avnishn, jungong): so thinking about this a bit more, it would be the + # best if we can attach multiple tags to an async all, like basically this + # parameter should be tags: + # For sync calls, tags would be (). + # For async call users, they can attached multiple tags for a single call, like + # ("rollout_worker", "sync_weight"). + # For async fetch result, we can also specify a single, or list of tags. For + # example, ("eval", "sample") will fetch all the sample() calls on eval + # workers. + if not remote_actor_ids: + remote_actor_ids = self.actor_ids() + + # Perform round robin assignment of all provided calls for any number of our + # actors. Note that this way, some actors might receive more than 1 request in + # this call. + if isinstance(func, list) and len(remote_actor_ids) != len(func): + remote_actor_ids = [ + (self._current_actor_id + i) % self.num_actors() + for i in range(len(func)) + ] + # Update our round-robin pointer. + self._current_actor_id += len(func) % self.num_actors() + + if healthy_only: + func, remote_actor_ids = self._filter_func_and_remote_actor_id_by_state( + func, remote_actor_ids + ) + + num_calls_to_make: Dict[int, int] = defaultdict(lambda: 0) + # Drop calls to actors that are too busy. 
+ if isinstance(func, list): + assert len(func) == len(remote_actor_ids) + limited_func = [] + limited_remote_actor_ids = [] + for i, f in zip(remote_actor_ids, func): + num_outstanding_reqs = self._remote_actor_states[ + i + ].num_in_flight_async_requests + if ( + num_outstanding_reqs + num_calls_to_make[i] + < self._max_remote_requests_in_flight_per_actor + ): + num_calls_to_make[i] += 1 + limited_func.append(f) + limited_remote_actor_ids.append(i) + else: + limited_func = func + limited_remote_actor_ids = [] + for i in remote_actor_ids: + num_outstanding_reqs = self._remote_actor_states[ + i + ].num_in_flight_async_requests + if ( + num_outstanding_reqs + num_calls_to_make[i] + < self._max_remote_requests_in_flight_per_actor + ): + num_calls_to_make[i] += 1 + limited_remote_actor_ids.append(i) + + remote_calls = self._call_actors( + func=limited_func, + remote_actor_ids=limited_remote_actor_ids, + ) + + # Save these as outstanding requests. + for id, call in zip(limited_remote_actor_ids, remote_calls): + self._remote_actor_states[id].num_in_flight_async_requests += 1 + self._in_flight_req_to_actor_id[call] = (tag, id) + + return len(remote_calls) + + @DeveloperAPI + def fetch_ready_async_reqs( + self, + *, + tags: Union[str, List[str], Tuple[str]] = (), + timeout_seconds: Optional[float] = 0.0, + return_obj_refs: bool = False, + mark_healthy: bool = False, + ) -> RemoteCallResults: + """Get results from outstanding async requests that are ready. + + Automatically mark actors unhealthy if they fail to respond. + + Note: If tags is an empty tuple then results from all ready async requests are + returned. + + Args: + timeout_seconds: ray.get() timeout. Default is 0, which only fetched those + results (immediately) that are already ready. + tags: A tag or a list of tags to identify the results from this async call. + return_obj_refs: Whether to return ObjectRef instead of actual results. 
+ mark_healthy: Whether to mark all those actors healthy again that are + currently marked unhealthy AND that returned results from the remote + call (within the given `timeout_seconds`). + Note that actors are NOT set to unhealthy, if they simply time out, + meaning take a longer time to fulfil the remote request. We only ever + mark an actor unhealthy, if they raise a RayActorError inside the remote + request. + Also note that this settings is ignored if the preceding + `foreach_actor_async()` call used the `healthy_only=True` argument (b/c + `mark_healthy` only affects actors that are currently tagged as + unhealthy). + + Returns: + A list of return values of all calls to `func(actor)` that are ready. + The values may be actual data returned or exceptions raised during the + remote call in the format of RemoteCallResults. + """ + # Construct the list of in-flight requests filtered by tag. + remote_calls, remote_actor_ids, valid_tags = self._filter_calls_by_tag(tags) + ready, remote_results = self._fetch_result( + remote_actor_ids=remote_actor_ids, + remote_calls=remote_calls, + tags=valid_tags, + timeout_seconds=timeout_seconds, + return_obj_refs=return_obj_refs, + mark_healthy=mark_healthy, + ) + + for obj_ref, result in zip(ready, remote_results): + # Decrease outstanding request on this actor by 1. + self._remote_actor_states[result.actor_id].num_in_flight_async_requests -= 1 + # Also, remove this call here from the in-flight list, + # obj_refs may have already been removed when we disable an actor. + if obj_ref in self._in_flight_req_to_actor_id: + del self._in_flight_req_to_actor_id[obj_ref] + + return remote_results + + @staticmethod + def handle_remote_call_result_errors( + results_or_errors: RemoteCallResults, + *, + ignore_ray_errors: bool, + ) -> None: + """Checks given results for application errors and raises them if necessary. + + Args: + results_or_errors: The results or errors to check. 
+ ignore_ray_errors: Whether to ignore RayErrors within the elements of + `results_or_errors`. + """ + for result_or_error in results_or_errors: + # Good result. + if result_or_error.ok: + continue + # RayError, but we ignore it. + elif ignore_ray_errors: + logger.exception(result_or_error.get()) + # Raise RayError. + else: + raise result_or_error.get() + + @DeveloperAPI + def probe_unhealthy_actors( + self, + timeout_seconds: Optional[float] = None, + mark_healthy: bool = False, + ) -> List[int]: + """Ping all unhealthy actors to try bringing them back. + + Args: + timeout_seconds: Timeout in seconds (to avoid pinging hanging workers + indefinitely). + mark_healthy: Whether to mark all those actors healthy again that are + currently marked unhealthy AND that respond to the `ping` remote request + (within the given `timeout_seconds`). + Note that actors are NOT set to unhealthy, if they simply time out, + meaning take a longer time to fulfil the remote request. We only ever + mark and actor unhealthy, if they return a RayActorError from the remote + request. + Also note that this settings is ignored if `healthy_only=True` (b/c this + setting only affects actors that are currently tagged as unhealthy). + + Returns: + A list of actor IDs that were restored by the `ping.remote()` call PLUS + those actors that were previously restored via other remote requests. + The cached set of such previously restored actors will be erased in this + call. + """ + # Collect recently restored actors (from `self._fetch_result` calls other than + # the one triggered here via the `ping`). + restored_actors = list(self._restored_actors) + self._restored_actors.clear() + + # Probe all unhealthy actors via a simple `ping()`. + unhealthy_actor_ids = [ + actor_id + for actor_id in self.actor_ids() + if not self.is_actor_healthy(actor_id) + ] + # No unhealthy actors currently -> Return recently restored ones. 
+ if not unhealthy_actor_ids: + return restored_actors + + # Some unhealthy actors -> `ping()` all of them to trigger a new fetch and + # capture all restored ones. + remote_results = self.foreach_actor( + func=lambda actor: actor.ping(), + remote_actor_ids=unhealthy_actor_ids, + healthy_only=False, # We specifically want to ping unhealthy actors. + timeout_seconds=timeout_seconds, + mark_healthy=mark_healthy, + ) + + # Return previously restored actors AND actors restored via the `ping()` call. + return restored_actors + [ + result.actor_id for result in remote_results if result.ok + ] + + def _call_actors( + self, + func: Union[Callable[[Any], Any], List[Callable[[Any], Any]]], + *, + remote_actor_ids: List[int] = None, + ) -> List[ray.ObjectRef]: + """Apply functions on a list of remote actors. + + Args: + func: A single, or a list of Callables, that get applied on the list + of specified remote actors. + remote_actor_ids: Apply func on this selected set of remote actors. + + Returns: + A list of ObjectRefs returned from the remote calls. + """ + if isinstance(func, list): + assert len(remote_actor_ids) == len( + func + ), "Funcs must have the same number of callables as actor indices." + + if remote_actor_ids is None: + remote_actor_ids = self.actor_ids() + + if isinstance(func, list): + calls = [ + self._actors[i].apply.remote(f) for i, f in zip(remote_actor_ids, func) + ] + else: + calls = [self._actors[i].apply.remote(func) for i in remote_actor_ids] + + return calls + + @DeveloperAPI + def _fetch_result( + self, + *, + remote_actor_ids: List[int], + remote_calls: List[ray.ObjectRef], + tags: List[str], + timeout_seconds: Optional[float] = None, + return_obj_refs: bool = False, + mark_healthy: bool = False, + ) -> Tuple[List[ray.ObjectRef], RemoteCallResults]: + """Try fetching results from remote actor calls. + + Mark whether an actor is healthy or not accordingly. + + Args: + remote_actor_ids: IDs of the actors these remote + calls were fired against. 
+ remote_calls: List of remote calls to fetch. + tags: List of tags used for identifying the remote calls. + timeout_seconds: Timeout (in sec) for the ray.wait() call. Default is None, + meaning wait indefinitely for all results. + return_obj_refs: Whether to return ObjectRef instead of actual results. + mark_healthy: Whether to mark certain actors healthy based on the results + of these remote calls. Useful, for example, to make sure actors + do not come back without proper state restoration. + + Returns: + A list of ready ObjectRefs mapping to the results of those calls. + """ + # Notice that we do not return the refs to any unfinished calls to the + # user, since it is not safe to handle such remote actor calls outside the + # context of this actor manager. These requests are simply dropped. + timeout = float(timeout_seconds) if timeout_seconds is not None else None + + # This avoids calling ray.init() in the case of 0 remote calls. + # This is useful if the number of remote workers is 0. + if not remote_calls: + return [], RemoteCallResults() + + readies, _ = ray.wait( + remote_calls, + num_returns=len(remote_calls), + timeout=timeout, + # Make sure remote results are fetched locally in parallel. + fetch_local=not return_obj_refs, + ) + + # Remote data should already be fetched to local object store at this point. + remote_results = RemoteCallResults() + for ready in readies: + # Find the corresponding actor ID for this remote call. + actor_id = remote_actor_ids[remote_calls.index(ready)] + tag = tags[remote_calls.index(ready)] + + # If caller wants ObjectRefs, return directly without resolving. + if return_obj_refs: + remote_results.add_result(actor_id, ResultOrError(result=ready), tag) + continue + + # Try getting the ready results. + try: + result = ray.get(ready) + + # Any error type other than `RayError` happening during ray.get() -> + # Throw exception right here (we don't know how to handle these non-remote + # worker issues and should therefore crash). 
+ except RayError as e: + # Return error to the user. + remote_results.add_result(actor_id, ResultOrError(error=e), tag) + + # Mark the actor as unhealthy, take it out of service, and wait for + # Ray Core to restore it. + if self.is_actor_healthy(actor_id): + logger.error( + f"Ray error ({str(e)}), taking actor {actor_id} out of service." + ) + self.set_actor_state(actor_id, healthy=False) + + # If no errors, add result to `RemoteCallResults` to be returned. + else: + # Return valid result to the user. + remote_results.add_result(actor_id, ResultOrError(result=result), tag) + + # Actor came back from an unhealthy state. Mark this actor as healthy + # and add it to our healthy set. + if mark_healthy and not self.is_actor_healthy(actor_id): + logger.warning( + f"Bringing previously unhealthy, now-healthy actor {actor_id} " + "back into service." + ) + self.set_actor_state(actor_id, healthy=True) + self._num_actor_restarts += 1 + + # Make sure, to-be-returned results are sound. + assert len(readies) == len(remote_results) + + return readies, remote_results + + def _filter_func_and_remote_actor_id_by_state( + self, + func: Union[Callable[[Any], Any], List[Callable[[Any], Any]]], + remote_actor_ids: List[int], + ): + """Filter out func and remote worker ids by actor state. + + Args: + func: A single, or a list of Callables. + remote_actor_ids: IDs of potential remote workers to apply func on. + + Returns: + A tuple of (filtered func, filtered remote worker ids). + """ + if isinstance(func, list): + assert len(remote_actor_ids) == len( + func + ), "Func must have the same number of callables as remote actor ids." + # We are given a list of functions to apply. + # Need to filter the functions together with worker IDs. 
+ temp_func = [] + temp_remote_actor_ids = [] + for f, i in zip(func, remote_actor_ids): + if self.is_actor_healthy(i): + temp_func.append(f) + temp_remote_actor_ids.append(i) + func = temp_func + remote_actor_ids = temp_remote_actor_ids + else: + # Simply filter the worker IDs. + remote_actor_ids = [i for i in remote_actor_ids if self.is_actor_healthy(i)] + + return func, remote_actor_ids + + def _filter_calls_by_tag( + self, tags: Union[str, List[str], Tuple[str]] + ) -> Tuple[List[ray.ObjectRef], List[ActorHandle], List[str]]: + """Return all the in flight requests that match the given tags, if any. + + Args: + tags: A str or a list/tuple of str. If tags is empty, return all the in + flight requests. + + Returns: + A tuple consisting of a list of the remote calls that match the tag(s), + a list of the corresponding remote actor IDs for these calls (same length), + and a list of the tags corresponding to these calls (same length). + """ + if isinstance(tags, str): + tags = {tags} + elif isinstance(tags, (list, tuple)): + tags = set(tags) + else: + raise ValueError( + f"tags must be either a str or a list/tuple of str, got {type(tags)}." + ) + remote_calls = [] + remote_actor_ids = [] + valid_tags = [] + for call, (tag, actor_id) in self._in_flight_req_to_actor_id.items(): + # the default behavior is to return all ready results. + if len(tags) == 0 or tag in tags: + remote_calls.append(call) + remote_actor_ids.append(actor_id) + valid_tags.append(tag) + + return remote_calls, remote_actor_ids, valid_tags + + def _remove_async_state(self, actor_id: int): + """Remove internal async state of for a given actor. + + This is called when an actor is removed from the pool or being marked + unhealthy. + + Args: + actor_id: The id of the actor. + """ + # Remove any outstanding async requests for this actor. + # Use `list` here to not change a looped generator while we mutate the + # underlying dict. 
+        # BUG FIX: the dict maps ObjectRef -> (tag, actor_id). The original
+        # `for id, req in ...items()` bound the ObjectRef to `id` and the
+        # (tag, actor_id) tuple to `req`, so `id == actor_id` never matched
+        # and stale in-flight requests were never purged. Unpack correctly.
+        for req, (_, id) in list(self._in_flight_req_to_actor_id.items()):
+            if id == actor_id:
+                del self._in_flight_req_to_actor_id[req]
+
+    def actors(self):
+        # TODO(jungong) : remove this API once EnvRunnerGroup.remote_workers()
+        # and EnvRunnerGroup._remote_workers() are removed.
+        return self._actors
diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/compression.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/compression.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd5e3e6975b4535b6a10a713d85d03ba2093c180
--- /dev/null
+++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/compression.py
@@ -0,0 +1,90 @@
+from ray.rllib.utils.annotations import DeveloperAPI
+
+import logging
+import time
+import base64
+import numpy as np
+from ray import cloudpickle as pickle
+
+logger = logging.getLogger(__name__)
+
+try:
+    import lz4.frame
+
+    LZ4_ENABLED = True
+except ImportError:
+    logger.warning(
+        "lz4 not available, disabling sample compression. "
+        "This will significantly impact RLlib performance. "
+        "To install lz4, run `pip install lz4`."
+    )
+    LZ4_ENABLED = False
+
+
+@DeveloperAPI
+def compression_supported():
+    return LZ4_ENABLED
+
+
+@DeveloperAPI
+def pack(data):
+    if LZ4_ENABLED:
+        data = pickle.dumps(data)
+        data = lz4.frame.compress(data)
+        # TODO(ekl) we shouldn't need to base64 encode this data, but this
+        # seems to not survive a transfer through the object store if we don't.
+ data = base64.b64encode(data).decode("ascii") + return data + + +@DeveloperAPI +def pack_if_needed(data): + if isinstance(data, np.ndarray): + data = pack(data) + return data + + +@DeveloperAPI +def unpack(data): + if LZ4_ENABLED: + data = base64.b64decode(data) + data = lz4.frame.decompress(data) + data = pickle.loads(data) + return data + + +@DeveloperAPI +def unpack_if_needed(data): + if is_compressed(data): + data = unpack(data) + return data + + +@DeveloperAPI +def is_compressed(data): + return isinstance(data, bytes) or isinstance(data, str) + + +# Intel(R) Core(TM) i7-4600U CPU @ 2.10GHz +# Compression speed: 753.664 MB/s +# Compression ratio: 87.4839812046 +# Decompression speed: 910.9504 MB/s +if __name__ == "__main__": + size = 32 * 80 * 80 * 4 + data = np.ones(size).reshape((32, 80, 80, 4)) + + count = 0 + start = time.time() + while time.time() - start < 1: + pack(data) + count += 1 + compressed = pack(data) + print("Compression speed: {} MB/s".format(count * size * 4 / 1e6)) + print("Compression ratio: {}".format(round(size * 4 / len(compressed), 2))) + + count = 0 + start = time.time() + while time.time() - start < 1: + unpack(compressed) + count += 1 + print("Decompression speed: {} MB/s".format(count * size * 4 / 1e6)) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/filter.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/filter.py new file mode 100644 index 0000000000000000000000000000000000000000..d969abddb119e970621e0d95a2f00cadbf5de95a --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/filter.py @@ -0,0 +1,420 @@ +import logging +import threading + +import numpy as np +import tree # pip install dm_tree + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.numpy import SMALL_NUMBER +from ray.rllib.utils.typing import TensorStructType +from ray.rllib.utils.serialization import _serialize_ndarray, _deserialize_ndarray +from 
ray.rllib.utils.deprecation import deprecation_warning + +logger = logging.getLogger(__name__) + + +@OldAPIStack +class Filter: + """Processes input, possibly statefully.""" + + def apply_changes(self, other: "Filter", *args, **kwargs) -> None: + """Updates self with "new state" from other filter.""" + raise NotImplementedError + + def copy(self) -> "Filter": + """Creates a new object with same state as self. + + Returns: + A copy of self. + """ + raise NotImplementedError + + def sync(self, other: "Filter") -> None: + """Copies all state from other filter to self.""" + raise NotImplementedError + + def reset_buffer(self) -> None: + """Creates copy of current state and resets accumulated state""" + raise NotImplementedError + + def as_serializable(self) -> "Filter": + raise NotImplementedError + + @Deprecated(new="Filter.reset_buffer()", error=True) + def clear_buffer(self): + pass + + +@OldAPIStack +class NoFilter(Filter): + is_concurrent = True + + def __call__(self, x: TensorStructType, update=True): + # Process no further if already np.ndarray, dict, or tuple. + if isinstance(x, (np.ndarray, dict, tuple)): + return x + + try: + return np.asarray(x) + except Exception: + raise ValueError("Failed to convert to array", x) + + def apply_changes(self, other: "NoFilter", *args, **kwargs) -> None: + pass + + def copy(self) -> "NoFilter": + return self + + def sync(self, other: "NoFilter") -> None: + pass + + def reset_buffer(self) -> None: + pass + + def as_serializable(self) -> "NoFilter": + return self + + +# http://www.johndcook.com/blog/standard_deviation/ +@OldAPIStack +class RunningStat: + def __init__(self, shape=()): + self.num_pushes = 0 + self.mean_array = np.zeros(shape) + self.std_array = np.zeros(shape) + + def copy(self): + other = RunningStat() + # TODO: Remove these safe-guards if not needed anymore. 
+ other.num_pushes = self.num_pushes if hasattr(self, "num_pushes") else self._n + other.mean_array = ( + np.copy(self.mean_array) + if hasattr(self, "mean_array") + else np.copy(self._M) + ) + other.std_array = ( + np.copy(self.std_array) if hasattr(self, "std_array") else np.copy(self._S) + ) + return other + + def push(self, x): + x = np.asarray(x) + # Unvectorized update of the running statistics. + if x.shape != self.mean_array.shape: + raise ValueError( + "Unexpected input shape {}, expected {}, value = {}".format( + x.shape, self.mean_array.shape, x + ) + ) + self.num_pushes += 1 + if self.num_pushes == 1: + self.mean_array[...] = x + else: + delta = x - self.mean_array + self.mean_array[...] += delta / self.num_pushes + self.std_array[...] += ( + (delta / self.num_pushes) * delta * (self.num_pushes - 1) + ) + + def update(self, other): + n1 = float(self.num_pushes) + n2 = float(other.num_pushes) + n = n1 + n2 + if n == 0: + # Avoid divide by zero, which creates nans + return + delta = self.mean_array - other.mean_array + delta2 = delta * delta + m = (n1 * self.mean_array + n2 * other.mean_array) / n + s = self.std_array + other.std_array + (delta2 / n) * n1 * n2 + self.num_pushes = n + self.mean_array = m + self.std_array = s + + def __repr__(self): + return "(n={}, mean_mean={}, mean_std={})".format( + self.n, np.mean(self.mean), np.mean(self.std) + ) + + @property + def n(self): + return self.num_pushes + + @property + def mean(self): + return self.mean_array + + @property + def var(self): + return ( + self.std_array / (self.num_pushes - 1) + if self.num_pushes > 1 + else np.square(self.mean_array) + ).astype(np.float32) + + @property + def std(self): + return np.sqrt(self.var) + + @property + def shape(self): + return self.mean_array.shape + + def to_state(self): + return { + "num_pushes": self.num_pushes, + "mean_array": _serialize_ndarray(self.mean_array), + "std_array": _serialize_ndarray(self.std_array), + } + + @staticmethod + def from_state(state): 
+ running_stats = RunningStat() + running_stats.num_pushes = state["num_pushes"] + running_stats.mean_array = _deserialize_ndarray(state["mean_array"]) + running_stats.std_array = _deserialize_ndarray(state["std_array"]) + return running_stats + + +@OldAPIStack +class MeanStdFilter(Filter): + """Keeps track of a running mean for seen states""" + + is_concurrent = False + + def __init__(self, shape, demean=True, destd=True, clip=10.0): + self.shape = shape + # We don't have a preprocessor, if shape is None (Discrete) or + # flat_shape is Tuple[np.ndarray] or Dict[str, np.ndarray] + # (complex inputs). + flat_shape = tree.flatten(self.shape) + self.no_preprocessor = shape is None or ( + isinstance(self.shape, (dict, tuple)) + and len(flat_shape) > 0 + and isinstance(flat_shape[0], np.ndarray) + ) + # If preprocessing (flattening dicts/tuples), make sure shape + # is an np.ndarray, so we don't confuse it with a complex Tuple + # space's shape structure (which is a Tuple[np.ndarray]). + if not self.no_preprocessor: + self.shape = np.array(self.shape) + self.demean = demean + self.destd = destd + self.clip = clip + # Running stats. + self.running_stats = tree.map_structure(lambda s: RunningStat(s), self.shape) + + # In distributed rollouts, each worker sees different states. + # The buffer is used to keep track of deltas amongst all the + # observation filters. + self.buffer = None + self.reset_buffer() + + def reset_buffer(self) -> None: + self.buffer = tree.map_structure(lambda s: RunningStat(s), self.shape) + + def apply_changes( + self, other: "MeanStdFilter", with_buffer: bool = False, *args, **kwargs + ) -> None: + """Applies updates from the buffer of another filter. + + Args: + other: Other filter to apply info from + with_buffer: Flag for specifying if the buffer should be + copied from other. + + .. testcode:: + :skipif: True + + a = MeanStdFilter(()) + a(1) + a(2) + print([a.running_stats.n, a.running_stats.mean, a.buffer.n]) + + .. 
testoutput:: + + [2, 1.5, 2] + + .. testcode:: + :skipif: True + + b = MeanStdFilter(()) + b(10) + a.apply_changes(b, with_buffer=False) + print([a.running_stats.n, a.running_stats.mean, a.buffer.n]) + + .. testoutput:: + + [3, 4.333333333333333, 2] + + .. testcode:: + :skipif: True + + a.apply_changes(b, with_buffer=True) + print([a.running_stats.n, a.running_stats.mean, a.buffer.n]) + + .. testoutput:: + + [4, 5.75, 1] + """ + tree.map_structure( + lambda rs, other_rs: rs.update(other_rs), self.running_stats, other.buffer + ) + if with_buffer: + self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer) + + def copy(self) -> "MeanStdFilter": + """Returns a copy of `self`.""" + other = MeanStdFilter(self.shape) + other.sync(self) + return other + + def as_serializable(self) -> "MeanStdFilter": + return self.copy() + + def sync(self, other: "MeanStdFilter") -> None: + """Syncs all fields together from other filter. + + .. testcode:: + :skipif: True + + a = MeanStdFilter(()) + a(1) + a(2) + print([a.running_stats.n, a.running_stats.mean, a.buffer.n]) + + .. testoutput:: + + [2, array(1.5), 2] + + .. testcode:: + :skipif: True + + b = MeanStdFilter(()) + b(10) + print([b.running_stats.n, b.running_stats.mean, b.buffer.n]) + + .. testoutput:: + + [1, array(10.0), 1] + + .. testcode:: + :skipif: True + + a.sync(b) + print([a.running_stats.n, a.running_stats.mean, a.buffer.n]) + + .. testoutput:: + + [1, array(10.0), 1] + """ + self.demean = other.demean + self.destd = other.destd + self.clip = other.clip + self.running_stats = tree.map_structure( + lambda rs: rs.copy(), other.running_stats + ) + self.buffer = tree.map_structure(lambda b: b.copy(), other.buffer) + + def __call__(self, x: TensorStructType, update: bool = True) -> TensorStructType: + if self.no_preprocessor: + x = tree.map_structure(lambda x_: np.asarray(x_), x) + else: + x = np.asarray(x) + + def _helper(x, rs, buffer, shape): + # Discrete|MultiDiscrete spaces -> No normalization. 
+ if shape is None: + return x + + # Keep dtype as is througout this filter. + orig_dtype = x.dtype + + if update: + if len(x.shape) == len(rs.shape) + 1: + # The vectorized case. + for i in range(x.shape[0]): + rs.push(x[i]) + buffer.push(x[i]) + else: + # The unvectorized case. + rs.push(x) + buffer.push(x) + if self.demean: + x = x - rs.mean + if self.destd: + x = x / (rs.std + SMALL_NUMBER) + if self.clip: + x = np.clip(x, -self.clip, self.clip) + return x.astype(orig_dtype) + + if self.no_preprocessor: + return tree.map_structure_up_to( + x, _helper, x, self.running_stats, self.buffer, self.shape + ) + else: + return _helper(x, self.running_stats, self.buffer, self.shape) + + +@OldAPIStack +class ConcurrentMeanStdFilter(MeanStdFilter): + is_concurrent = True + + def __init__(self, *args, **kwargs): + super(ConcurrentMeanStdFilter, self).__init__(*args, **kwargs) + deprecation_warning( + old="ConcurrentMeanStdFilter", + error=False, + help="ConcurrentMeanStd filters are only used for testing and will " + "therefore be deprecated in the course of moving to the " + "Connetors API, where testing of filters will be done by other " + "means.", + ) + + self._lock = threading.RLock() + + def lock_wrap(func): + def wrapper(*args, **kwargs): + with self._lock: + return func(*args, **kwargs) + + return wrapper + + self.__getattribute__ = lock_wrap(self.__getattribute__) + + def as_serializable(self) -> "MeanStdFilter": + """Returns non-concurrent version of current class""" + other = MeanStdFilter(self.shape) + other.sync(self) + return other + + def copy(self) -> "ConcurrentMeanStdFilter": + """Returns a copy of Filter.""" + other = ConcurrentMeanStdFilter(self.shape) + other.sync(self) + return other + + def __repr__(self) -> str: + return "ConcurrentMeanStdFilter({}, {}, {}, {}, {}, {})".format( + self.shape, + self.demean, + self.destd, + self.clip, + self.running_stats, + self.buffer, + ) + + +@OldAPIStack +def get_filter(filter_config, shape): + if filter_config == 
"MeanStdFilter": + return MeanStdFilter(shape, clip=None) + elif filter_config == "ConcurrentMeanStdFilter": + return ConcurrentMeanStdFilter(shape, clip=None) + elif filter_config == "NoFilter": + return NoFilter() + elif callable(filter_config): + return filter_config(shape) + else: + raise Exception("Unknown observation_filter: " + str(filter_config)) diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/minibatch_utils.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/minibatch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..57a849da796fcab36a4cb3afd100dcb5d87ff465 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/minibatch_utils.py @@ -0,0 +1,331 @@ +import math +from typing import List, Optional + +from ray.rllib.policy.sample_batch import MultiAgentBatch, concat_samples +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import DeveloperAPI +from ray.rllib.utils.typing import EpisodeType + + +@DeveloperAPI +class MiniBatchIteratorBase: + """The base class for all minibatch iterators.""" + + def __init__( + self, + batch: MultiAgentBatch, + *, + num_epochs: int = 1, + shuffle_batch_per_epoch: bool = True, + minibatch_size: int, + num_total_minibatches: int = 0, + ) -> None: + """Initializes a MiniBatchIteratorBase instance. + + Args: + batch: The input multi-agent batch. + num_epochs: The number of complete passes over the entire train batch. Each + pass might be further split into n minibatches (if `minibatch_size` + provided). The train batch is generated from the given `episodes` + through the Learner connector pipeline. + minibatch_size: The size of minibatches to use to further split the train + batch into per epoch. The train batch is generated from the given + `episodes` through the Learner connector pipeline. + num_total_minibatches: The total number of minibatches to loop through + (over all `num_epochs` epochs). 
It's only required to set this to != 0 + in multi-agent + multi-GPU situations, in which the MultiAgentEpisodes + themselves are roughly sharded equally, however, they might contain + SingleAgentEpisodes with very lopsided length distributions. Thus, + without this fixed, pre-computed value, one Learner might go through a + different number of minibatche passes than others causing a deadlock. + """ + pass + + +@DeveloperAPI +class MiniBatchCyclicIterator(MiniBatchIteratorBase): + """This implements a simple multi-agent minibatch iterator. + + This iterator will split the input multi-agent batch into minibatches where the + size of batch for each module_id (aka policy_id) is equal to minibatch_size. If the + input batch is smaller than minibatch_size, then the iterator will cycle through + the batch until it has covered `num_epochs` epochs. + """ + + def __init__( + self, + batch: MultiAgentBatch, + *, + num_epochs: int = 1, + minibatch_size: int, + shuffle_batch_per_epoch: bool = True, + num_total_minibatches: int = 0, + ) -> None: + """Initializes a MiniBatchCyclicIterator instance.""" + super().__init__( + batch, + num_epochs=num_epochs, + minibatch_size=minibatch_size, + shuffle_batch_per_epoch=shuffle_batch_per_epoch, + ) + + self._batch = batch + self._minibatch_size = minibatch_size + self._num_epochs = num_epochs + self._shuffle_batch_per_epoch = shuffle_batch_per_epoch + + # mapping from module_id to the start index of the batch + self._start = {mid: 0 for mid in batch.policy_batches.keys()} + # mapping from module_id to the number of epochs covered for each module_id + self._num_covered_epochs = {mid: 0 for mid in batch.policy_batches.keys()} + + self._minibatch_count = 0 + self._num_total_minibatches = num_total_minibatches + + def __iter__(self): + while ( + # Make sure each item in the total batch gets at least iterated over + # `self._num_epochs` times. 
+ ( + self._num_total_minibatches == 0 + and min(self._num_covered_epochs.values()) < self._num_epochs + ) + # Make sure we reach at least the given minimum number of mini-batches. + or ( + self._num_total_minibatches > 0 + and self._minibatch_count < self._num_total_minibatches + ) + ): + minibatch = {} + for module_id, module_batch in self._batch.policy_batches.items(): + + if len(module_batch) == 0: + raise ValueError( + f"The batch for module_id {module_id} is empty! " + "This will create an infinite loop because we need to cover " + "the same number of samples for each module_id." + ) + s = self._start[module_id] # start + + # TODO (sven): Fix this bug for LSTMs: + # In an RNN-setting, the Learner connector already has zero-padded + # and added a timerank to the batch. Thus, n_step would still be based + # on the BxT dimension, rather than the new B dimension (excluding T), + # which then leads to minibatches way too large. + # However, changing this already would break APPO/IMPALA w/o LSTMs as + # these setups require sequencing, BUT their batches are not yet time- + # ranked (this is done only in their loss functions via the + # `make_time_major` utility). + n_steps = self._minibatch_size + + samples_to_concat = [] + + # get_len is a function that returns the length of a batch + # if we are not slicing the batch in the batch dimension B, then + # the length of the batch is simply the length of the batch + # o.w the length of the batch is the length list of seq_lens. + if module_batch._slice_seq_lens_in_B: + assert module_batch.get(SampleBatch.SEQ_LENS) is not None, ( + "MiniBatchCyclicIterator requires SampleBatch.SEQ_LENS" + "to be present in the batch for slicing a batch in the batch " + "dimension B." 
+ ) + + def get_len(b): + return len(b[SampleBatch.SEQ_LENS]) + + n_steps = int( + get_len(module_batch) + * (self._minibatch_size / len(module_batch)) + ) + + else: + + def get_len(b): + return len(b) + + # Cycle through the batch until we have enough samples. + while s + n_steps >= get_len(module_batch): + sample = module_batch[s:] + samples_to_concat.append(sample) + len_sample = get_len(sample) + assert len_sample > 0, "Length of a sample must be > 0!" + n_steps -= len_sample + s = 0 + self._num_covered_epochs[module_id] += 1 + # Shuffle the individual single-agent batch, if required. + # This should happen once per minibatch iteration in order to make + # each iteration go through a different set of minibatches. + if self._shuffle_batch_per_epoch: + module_batch.shuffle() + + e = s + n_steps # end + if e > s: + samples_to_concat.append(module_batch[s:e]) + + # concatenate all the samples, we should have minibatch_size of sample + # after this step + minibatch[module_id] = concat_samples(samples_to_concat) + # roll minibatch to zero when we reach the end of the batch + self._start[module_id] = e + + # Note (Kourosh): env_steps is the total number of env_steps that this + # multi-agent batch is covering. It should be simply inherited from the + # original multi-agent batch. + minibatch = MultiAgentBatch(minibatch, len(self._batch)) + yield minibatch + + self._minibatch_count += 1 + + +class MiniBatchDummyIterator(MiniBatchIteratorBase): + def __init__(self, batch: MultiAgentBatch, **kwargs): + super().__init__(batch, **kwargs) + self._batch = batch + + def __iter__(self): + yield self._batch + + +@DeveloperAPI +class ShardBatchIterator: + """Iterator for sharding batch into num_shards batches. + + Args: + batch: The input multi-agent batch. + num_shards: The number of shards to split the batch into. + + Yields: + A MultiAgentBatch of size len(batch) / num_shards. 
+ """ + + def __init__(self, batch: MultiAgentBatch, num_shards: int): + self._batch = batch + self._num_shards = num_shards + + def __iter__(self): + for i in range(self._num_shards): + # TODO (sven): The following way of sharding a multi-agent batch destroys + # the relationship of the different agents' timesteps to each other. + # Thus, in case the algorithm requires agent-synchronized data (aka. + # "lockstep"), the `ShardBatchIterator` cannot be used. + batch_to_send = {} + for pid, sub_batch in self._batch.policy_batches.items(): + batch_size = math.ceil(len(sub_batch) / self._num_shards) + start = batch_size * i + end = min(start + batch_size, len(sub_batch)) + batch_to_send[pid] = sub_batch[int(start) : int(end)] + # TODO (Avnish): int(batch_size) ? How should we shard MA batches really? + new_batch = MultiAgentBatch(batch_to_send, int(batch_size)) + yield new_batch + + +@DeveloperAPI +class ShardEpisodesIterator: + """Iterator for sharding a list of Episodes into `num_shards` lists of Episodes.""" + + def __init__( + self, + episodes: List[EpisodeType], + num_shards: int, + len_lookback_buffer: Optional[int] = None, + ): + """Initializes a ShardEpisodesIterator instance. + + Args: + episodes: The input list of Episodes. + num_shards: The number of shards to split the episodes into. + len_lookback_buffer: An optional length of a lookback buffer to enforce + on the returned shards. When spitting an episode, the second piece + might need a lookback buffer (into the first piece) depending on the + user's settings. 
+ """ + self._episodes = sorted(episodes, key=len, reverse=True) + self._num_shards = num_shards + self._len_lookback_buffer = len_lookback_buffer + self._total_length = sum(len(e) for e in episodes) + self._target_lengths = [0 for _ in range(self._num_shards)] + remaining_length = self._total_length + for s in range(self._num_shards): + len_ = remaining_length // (num_shards - s) + self._target_lengths[s] = len_ + remaining_length -= len_ + + def __iter__(self) -> List[EpisodeType]: + """Runs one iteration through this sharder. + + Yields: + A sub-list of Episodes of size roughly `len(episodes) / num_shards`. The + yielded sublists might have slightly different total sums of episode + lengths, in order to not have to drop even a single timestep. + """ + sublists = [[] for _ in range(self._num_shards)] + lengths = [0 for _ in range(self._num_shards)] + episode_index = 0 + + while episode_index < len(self._episodes): + episode = self._episodes[episode_index] + min_index = lengths.index(min(lengths)) + + # Add the whole episode if it fits within the target length + if lengths[min_index] + len(episode) <= self._target_lengths[min_index]: + sublists[min_index].append(episode) + lengths[min_index] += len(episode) + episode_index += 1 + # Otherwise, slice the episode + else: + remaining_length = self._target_lengths[min_index] - lengths[min_index] + if remaining_length > 0: + slice_part, remaining_part = ( + # Note that the first slice will automatically "inherit" the + # lookback buffer size of the episode. + episode[:remaining_length], + # However, the second slice might need a user defined lookback + # buffer (into the first slice). 
+ episode.slice( + slice(remaining_length, None), + len_lookback_buffer=self._len_lookback_buffer, + ), + ) + sublists[min_index].append(slice_part) + lengths[min_index] += len(slice_part) + self._episodes[episode_index] = remaining_part + else: + assert remaining_length == 0 + sublists[min_index].append(episode) + episode_index += 1 + + for sublist in sublists: + yield sublist + + +@DeveloperAPI +class ShardObjectRefIterator: + """Iterator for sharding a list of ray ObjectRefs into num_shards sub-lists. + + Args: + object_refs: The input list of ray ObjectRefs. + num_shards: The number of shards to split the references into. + + Yields: + A sub-list of ray ObjectRefs with lengths as equal as possible. + """ + + def __init__(self, object_refs, num_shards: int): + self._object_refs = object_refs + self._num_shards = num_shards + + def __iter__(self): + # Calculate the size of each sublist + n = len(self._object_refs) + sublist_size = n // self._num_shards + remaining_elements = n % self._num_shards + + start = 0 + for i in range(self._num_shards): + # Determine the end index for the current sublist + end = start + sublist_size + (1 if i < remaining_elements else 0) + # Append the sublist to the result + yield self._object_refs[start:end] + # Update the start index for the next sublist + start = end diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/policy.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/policy.py new file mode 100644 index 0000000000000000000000000000000000000000..a5b6b2ccfda6512fbbd5e148304cecd1859cced6 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/policy.py @@ -0,0 +1,303 @@ +import gymnasium as gym +import logging +import numpy as np +from typing import ( + Callable, + Dict, + List, + Optional, + Tuple, + Type, + Union, + TYPE_CHECKING, +) +import tree # pip install dm_tree + + +import ray.cloudpickle as pickle +from ray.rllib.core.rl_module import validate_module_id +from 
ray.rllib.models.preprocessors import ATARI_OBS_SHAPE +from ray.rllib.policy.policy import PolicySpec +from ray.rllib.policy.sample_batch import SampleBatch +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.deprecation import Deprecated +from ray.rllib.utils.framework import try_import_tf +from ray.rllib.utils.typing import ( + ActionConnectorDataType, + AgentConnectorDataType, + AgentConnectorsOutput, + PartialAlgorithmConfigDict, + PolicyState, + TensorStructType, + TensorType, +) +from ray.util import log_once + +if TYPE_CHECKING: + from ray.rllib.policy.policy import Policy + +logger = logging.getLogger(__name__) + +tf1, tf, tfv = try_import_tf() + + +@OldAPIStack +def create_policy_for_framework( + policy_id: str, + policy_class: Type["Policy"], + merged_config: PartialAlgorithmConfigDict, + observation_space: gym.Space, + action_space: gym.Space, + worker_index: int = 0, + session_creator: Optional[Callable[[], "tf1.Session"]] = None, + seed: Optional[int] = None, +): + """Framework-specific policy creation logics. + + Args: + policy_id: Policy ID. + policy_class: Policy class type. + merged_config: Complete policy config. + observation_space: Observation space of env. + action_space: Action space of env. + worker_index: Index of worker holding this policy. Default is 0. + session_creator: An optional tf1.Session creation callable. + seed: Optional random seed. + """ + from ray.rllib.algorithms.algorithm_config import AlgorithmConfig + + if isinstance(merged_config, AlgorithmConfig): + merged_config = merged_config.to_dict() + + # add policy_id to merged_config + merged_config["__policy_id"] = policy_id + + framework = merged_config.get("framework", "tf") + # Tf. + if framework in ["tf2", "tf"]: + var_scope = policy_id + (f"_wk{worker_index}" if worker_index else "") + # For tf static graph, build every policy in its own graph + # and create a new session for it. 
+ if framework == "tf": + with tf1.Graph().as_default(): + # Session creator function provided manually -> Use this one to + # create the tf1 session. + if session_creator: + sess = session_creator() + # Use a default session creator, based only on our `tf_session_args` in + # the config. + else: + sess = tf1.Session( + config=tf1.ConfigProto(**merged_config["tf_session_args"]) + ) + + with sess.as_default(): + # Set graph-level seed. + if seed is not None: + tf1.set_random_seed(seed) + with tf1.variable_scope(var_scope): + return policy_class( + observation_space, action_space, merged_config + ) + # For tf-eager: no graph, no session. + else: + with tf1.variable_scope(var_scope): + return policy_class(observation_space, action_space, merged_config) + # Non-tf: No graph, no session. + else: + return policy_class(observation_space, action_space, merged_config) + + +@OldAPIStack +def parse_policy_specs_from_checkpoint( + path: str, +) -> Tuple[PartialAlgorithmConfigDict, Dict[str, PolicySpec], Dict[str, PolicyState]]: + """Read and parse policy specifications from a checkpoint file. + + Args: + path: Path to a policy checkpoint. + + Returns: + A tuple of: base policy config, dictionary of policy specs, and + dictionary of policy states. + """ + with open(path, "rb") as f: + checkpoint_dict = pickle.load(f) + # Policy data is contained as a serialized binary blob under their + # ID keys. 
+ w = pickle.loads(checkpoint_dict["worker"]) + + policy_config = w["policy_config"] + policy_states = w.get("policy_states", w["state"]) + serialized_policy_specs = w["policy_specs"] + policy_specs = { + id: PolicySpec.deserialize(spec) for id, spec in serialized_policy_specs.items() + } + + return policy_config, policy_specs, policy_states + + +@OldAPIStack +def local_policy_inference( + policy: "Policy", + env_id: str, + agent_id: str, + obs: TensorStructType, + reward: Optional[float] = None, + terminated: Optional[bool] = None, + truncated: Optional[bool] = None, + info: Optional[Dict] = None, + explore: bool = None, + timestep: Optional[int] = None, +) -> TensorStructType: + """Run a connector enabled policy using environment observation. + + policy_inference manages policy and agent/action connectors, + so the user does not have to care about RNN state buffering or + extra fetch dictionaries. + Note that connectors are intentionally run separately from + compute_actions_from_input_dict(), so we can have the option + of running per-user connectors on the client side in a + server-client deployment. + + Args: + policy: Policy object used in inference. + env_id: Environment ID. RLlib builds environments' trajectories internally with + connectors based on this, i.e. one trajectory per (env_id, agent_id) tuple. + agent_id: Agent ID. RLlib builds agents' trajectories internally with connectors + based on this, i.e. one trajectory per (env_id, agent_id) tuple. + obs: Environment observation to base the action on. + reward: Reward that is potentially used during inference. If not required, + may be left empty. Some policies have ViewRequirements that require this. + This can be set to zero at the first inference step - for example after + calling gmy.Env.reset. + terminated: `Terminated` flag that is potentially used during inference. If not + required, may be left None. Some policies have ViewRequirements that + require this extra information. 
+ truncated: `Truncated` flag that is potentially used during inference. If not + required, may be left None. Some policies have ViewRequirements that + require this extra information. + info: Info that is potentially used durin inference. If not required, + may be left empty. Some policies have ViewRequirements that require this. + explore: Whether to pick an exploitation or exploration action + (default: None -> use self.config["explore"]). + timestep: The current (sampling) time step. + + Returns: + List of outputs from policy forward pass. + """ + assert ( + policy.agent_connectors + ), "policy_inference only works with connector enabled policies." + + __check_atari_obs_space(obs) + + # Put policy in inference mode, so we don't spend time on training + # only transformations. + policy.agent_connectors.in_eval() + policy.action_connectors.in_eval() + + # TODO(jungong) : support multiple env, multiple agent inference. + input_dict = {SampleBatch.NEXT_OBS: obs} + if reward is not None: + input_dict[SampleBatch.REWARDS] = reward + if terminated is not None: + input_dict[SampleBatch.TERMINATEDS] = terminated + if truncated is not None: + input_dict[SampleBatch.TRUNCATEDS] = truncated + if info is not None: + input_dict[SampleBatch.INFOS] = info + + acd_list: List[AgentConnectorDataType] = [ + AgentConnectorDataType(env_id, agent_id, input_dict) + ] + ac_outputs: List[AgentConnectorsOutput] = policy.agent_connectors(acd_list) + outputs = [] + for ac in ac_outputs: + policy_output = policy.compute_actions_from_input_dict( + ac.data.sample_batch, + explore=explore, + timestep=timestep, + ) + + # Note (Kourosh): policy output is batched, the AgentConnectorDataType should + # not be batched during inference. 
This is the assumption made in AgentCollector + policy_output = tree.map_structure(lambda x: x[0], policy_output) + + action_connector_data = ActionConnectorDataType( + env_id, agent_id, ac.data.raw_dict, policy_output + ) + + if policy.action_connectors: + acd = policy.action_connectors(action_connector_data) + actions = acd.output + else: + actions = policy_output[0] + + outputs.append(actions) + + # Notify agent connectors with this new policy output. + # Necessary for state buffering agent connectors, for example. + policy.agent_connectors.on_policy_output(action_connector_data) + return outputs + + +@OldAPIStack +def compute_log_likelihoods_from_input_dict( + policy: "Policy", batch: Union[SampleBatch, Dict[str, TensorStructType]] +): + """Returns log likelihood for actions in given batch for policy. + + Computes likelihoods by passing the observations through the current + policy's `compute_log_likelihoods()` method + + Args: + batch: The SampleBatch or MultiAgentBatch to calculate action + log likelihoods from. This batch/batches must contain OBS + and ACTIONS keys. + + Returns: + The probabilities of the actions in the batch, given the + observations and the policy. 
+ """ + num_state_inputs = 0 + for k in batch.keys(): + if k.startswith("state_in_"): + num_state_inputs += 1 + state_keys = ["state_in_{}".format(i) for i in range(num_state_inputs)] + log_likelihoods: TensorType = policy.compute_log_likelihoods( + actions=batch[SampleBatch.ACTIONS], + obs_batch=batch[SampleBatch.OBS], + state_batches=[batch[k] for k in state_keys], + prev_action_batch=batch.get(SampleBatch.PREV_ACTIONS), + prev_reward_batch=batch.get(SampleBatch.PREV_REWARDS), + actions_normalized=policy.config.get("actions_in_input_normalized", False), + ) + return log_likelihoods + + +@Deprecated(new="Policy.from_checkpoint([checkpoint path], [policy IDs]?)", error=True) +def load_policies_from_checkpoint(path, policy_ids=None): + pass + + +def __check_atari_obs_space(obs): + # TODO(Artur): Remove this after we have migrated deepmind style preprocessing into + # connectors (and don't auto-wrap in RW anymore) + if any( + o.shape == ATARI_OBS_SHAPE if isinstance(o, np.ndarray) else False + for o in tree.flatten(obs) + ): + if log_once("warn_about_possibly_non_wrapped_atari_env"): + logger.warning( + "The observation you fed into local_policy_inference() has " + "dimensions (210, 160, 3), which is the standard for atari " + "environments. If RLlib raises an error including a related " + "dimensionality mismatch, you may need to use " + "ray.rllib.env.wrappers.atari_wrappers.wrap_deepmind to wrap " + "you environment." 
+ ) + + +# @OldAPIStack +validate_policy_id = validate_module_id diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/sgd.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/sgd.py new file mode 100644 index 0000000000000000000000000000000000000000..3e126c0a2f450147d9d3f9653ae0b320da778ac7 --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/sgd.py @@ -0,0 +1,136 @@ +"""Utils for minibatch SGD across multiple RLlib policies.""" + +import logging +import numpy as np +import random + +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.policy.sample_batch import SampleBatch, MultiAgentBatch +from ray.rllib.utils.metrics.learner_info import LearnerInfoBuilder + +logger = logging.getLogger(__name__) + + +@OldAPIStack +def standardized(array: np.ndarray): + """Normalize the values in an array. + + Args: + array (np.ndarray): Array of values to normalize. + + Returns: + array with zero mean and unit standard deviation. + """ + return (array - array.mean()) / max(1e-4, array.std()) + + +@OldAPIStack +def minibatches(samples: SampleBatch, sgd_minibatch_size: int, shuffle: bool = True): + """Return a generator yielding minibatches from a sample batch. + + Args: + samples: SampleBatch to split up. + sgd_minibatch_size: Size of minibatches to return. + shuffle: Whether to shuffle the order of the generated minibatches. + Note that in case of a non-recurrent policy, the incoming batch + is globally shuffled first regardless of this setting, before + the minibatches are generated from it! + + Yields: + SampleBatch: Each of size `sgd_minibatch_size`. 
+ """ + if not sgd_minibatch_size: + yield samples + return + + if isinstance(samples, MultiAgentBatch): + raise NotImplementedError( + "Minibatching not implemented for multi-agent in simple mode" + ) + + if "state_in_0" not in samples and "state_out_0" not in samples: + samples.shuffle() + + all_slices = samples._get_slice_indices(sgd_minibatch_size) + data_slices, state_slices = all_slices + + if len(state_slices) == 0: + if shuffle: + random.shuffle(data_slices) + for i, j in data_slices: + yield samples[i:j] + else: + all_slices = list(zip(data_slices, state_slices)) + if shuffle: + # Make sure to shuffle data and states while linked together. + random.shuffle(all_slices) + for (i, j), (si, sj) in all_slices: + yield samples.slice(i, j, si, sj) + + +@OldAPIStack +def do_minibatch_sgd( + samples, + policies, + local_worker, + num_sgd_iter, + sgd_minibatch_size, + standardize_fields, +): + """Execute minibatch SGD. + + Args: + samples: Batch of samples to optimize. + policies: Dictionary of policies to optimize. + local_worker: Master rollout worker instance. + num_sgd_iter: Number of epochs of optimization to take. + sgd_minibatch_size: Size of minibatches to use for optimization. + standardize_fields: List of sample field names that should be + normalized prior to optimization. + + Returns: + averaged info fetches over the last SGD epoch taken. + """ + + # Handle everything as if multi-agent. + samples = samples.as_multi_agent() + + # Use LearnerInfoBuilder as a unified way to build the final + # results dict from `learn_on_loaded_batch` call(s). + # This makes sure results dicts always have the same structure + # no matter the setup (multi-GPU, multi-agent, minibatch SGD, + # tf vs torch). 
+ learner_info_builder = LearnerInfoBuilder(num_devices=1) + for policy_id, policy in policies.items(): + if policy_id not in samples.policy_batches: + continue + + batch = samples.policy_batches[policy_id] + for field in standardize_fields: + batch[field] = standardized(batch[field]) + + # Check to make sure that the sgd_minibatch_size is not smaller + # than max_seq_len otherwise this will cause indexing errors while + # performing sgd when using a RNN or Attention model + if ( + policy.is_recurrent() + and policy.config["model"]["max_seq_len"] > sgd_minibatch_size + ): + raise ValueError( + "`sgd_minibatch_size` ({}) cannot be smaller than" + "`max_seq_len` ({}).".format( + sgd_minibatch_size, policy.config["model"]["max_seq_len"] + ) + ) + + for i in range(num_sgd_iter): + for minibatch in minibatches(batch, sgd_minibatch_size): + results = ( + local_worker.learn_on_batch( + MultiAgentBatch({policy_id: minibatch}, minibatch.count) + ) + )[policy_id] + learner_info_builder.add_learn_on_batch_results(results, policy_id) + + learner_info = learner_info_builder.finalize() + return learner_info diff --git a/.venv/lib/python3.11/site-packages/ray/rllib/utils/tf_run_builder.py b/.venv/lib/python3.11/site-packages/ray/rllib/utils/tf_run_builder.py new file mode 100644 index 0000000000000000000000000000000000000000..1a4116f245203e7cedbc9348c893305186dd566d --- /dev/null +++ b/.venv/lib/python3.11/site-packages/ray/rllib/utils/tf_run_builder.py @@ -0,0 +1,115 @@ +import logging +import os +import time + +from ray.util.debug import log_once +from ray.rllib.utils.annotations import OldAPIStack +from ray.rllib.utils.framework import try_import_tf + +tf1, tf, tfv = try_import_tf() +logger = logging.getLogger(__name__) + + +@OldAPIStack +class _TFRunBuilder: + """Used to incrementally build up a TensorFlow run. + + This is particularly useful for batching ops from multiple different + policies in the multi-agent setting. 
+ """ + + def __init__(self, session, debug_name): + self.session = session + self.debug_name = debug_name + self.feed_dict = {} + self.fetches = [] + self._executed = None + + def add_feed_dict(self, feed_dict): + assert not self._executed + for k in feed_dict: + if k in self.feed_dict: + raise ValueError("Key added twice: {}".format(k)) + self.feed_dict.update(feed_dict) + + def add_fetches(self, fetches): + assert not self._executed + base_index = len(self.fetches) + self.fetches.extend(fetches) + return list(range(base_index, len(self.fetches))) + + def get(self, to_fetch): + if self._executed is None: + try: + self._executed = _run_timeline( + self.session, + self.fetches, + self.debug_name, + self.feed_dict, + os.environ.get("TF_TIMELINE_DIR"), + ) + except Exception as e: + logger.exception( + "Error fetching: {}, feed_dict={}".format( + self.fetches, self.feed_dict + ) + ) + raise e + if isinstance(to_fetch, int): + return self._executed[to_fetch] + elif isinstance(to_fetch, list): + return [self.get(x) for x in to_fetch] + elif isinstance(to_fetch, tuple): + return tuple(self.get(x) for x in to_fetch) + else: + raise ValueError("Unsupported fetch type: {}".format(to_fetch)) + + +_count = 0 + + +def _run_timeline(sess, ops, debug_name, feed_dict=None, timeline_dir=None): + if feed_dict is None: + feed_dict = {} + + if timeline_dir: + from tensorflow.python.client import timeline + + try: + run_options = tf1.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + except AttributeError: + run_options = None + # In local mode, tf1.RunOptions is not available, see #26511 + if log_once("tf1.RunOptions_not_available"): + logger.exception( + "Can not access tf.RunOptions.FULL_TRACE. This may be because " + "you have used `ray.init(local_mode=True)`. RLlib will use " + "timeline without `options=tf.RunOptions.FULL_TRACE`." 
+ ) + run_metadata = tf1.RunMetadata() + start = time.time() + fetches = sess.run( + ops, options=run_options, run_metadata=run_metadata, feed_dict=feed_dict + ) + trace = timeline.Timeline(step_stats=run_metadata.step_stats) + global _count + outf = os.path.join( + timeline_dir, + "timeline-{}-{}-{}.json".format(debug_name, os.getpid(), _count % 10), + ) + _count += 1 + trace_file = open(outf, "w") + logger.info( + "Wrote tf timeline ({} s) to {}".format( + time.time() - start, os.path.abspath(outf) + ) + ) + trace_file.write(trace.generate_chrome_trace_format()) + else: + if log_once("tf_timeline"): + logger.info( + "Executing TF run without tracing. To dump TF timeline traces " + "to disk, set the TF_TIMELINE_DIR environment variable." + ) + fetches = sess.run(ops, feed_dict=feed_dict) + return fetches