Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_v2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/registry.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/util.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__init__.py +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/lambdas.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/mean_std_filter.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/obs_preproc.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/state_buffer.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/clip_reward.py +56 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/env_sampling.py +30 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/lambdas.py +86 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/mean_std_filter.py +187 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/obs_preproc.py +69 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/pipeline.py +72 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/state_buffer.py +120 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/synced_filter.py +52 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/view_requirement.py +135 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__init__.py +43 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/add_columns_from_episodes_to_train_batch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/add_next_observations_from_episodes_to_train_batch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/compute_returns_to_go.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/general_advantage_estimation.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/get_actions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/module_to_env_pipeline.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/normalize_and_clip_actions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/remove_single_ts_time_rank_from_batch.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/unbatch_to_individual_items.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/action_dist.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/catalog.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/distributions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/modelv2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/preprocessors.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/repeated_values.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/utils.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/attention_net.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/fcnet.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/recurrent_net.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/tf_action_dist.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/tf_distributions.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/tf_modelv2.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/visionnet.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__init__.py +17 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__pycache__/__init__.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__pycache__/gru_gate.cpython-311.pyc +0 -0
- .venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__pycache__/multi_head_attention.cpython-311.pyc +0 -0
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector.cpython-311.pyc
ADDED
|
Binary file (21.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/connector_v2.cpython-311.pyc
ADDED
|
Binary file (47.9 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/registry.cpython-311.pyc
ADDED
|
Binary file (1.85 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/__pycache__/util.cpython-311.pyc
ADDED
|
Binary file (8.07 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__init__.py
ADDED
|
File without changes
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/lambdas.cpython-311.pyc
ADDED
|
Binary file (4.16 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/mean_std_filter.cpython-311.pyc
ADDED
|
Binary file (10.2 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/obs_preproc.cpython-311.pyc
ADDED
|
Binary file (4.11 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/__pycache__/state_buffer.cpython-311.pyc
ADDED
|
Binary file (6.36 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/clip_reward.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
|
| 5 |
+
from ray.rllib.connectors.connector import (
|
| 6 |
+
AgentConnector,
|
| 7 |
+
ConnectorContext,
|
| 8 |
+
)
|
| 9 |
+
from ray.rllib.connectors.registry import register_connector
|
| 10 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 11 |
+
from ray.rllib.utils.typing import AgentConnectorDataType
|
| 12 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
@OldAPIStack
|
| 16 |
+
class ClipRewardAgentConnector(AgentConnector):
|
| 17 |
+
def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
|
| 18 |
+
super().__init__(ctx)
|
| 19 |
+
assert (
|
| 20 |
+
not sign or not limit
|
| 21 |
+
), "should not enable both sign and limit reward clipping."
|
| 22 |
+
self.sign = sign
|
| 23 |
+
self.limit = limit
|
| 24 |
+
|
| 25 |
+
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
|
| 26 |
+
d = ac_data.data
|
| 27 |
+
assert (
|
| 28 |
+
type(d) is dict
|
| 29 |
+
), "Single agent data must be of type Dict[str, TensorStructType]"
|
| 30 |
+
|
| 31 |
+
if SampleBatch.REWARDS not in d:
|
| 32 |
+
# Nothing to clip. May happen for initial obs.
|
| 33 |
+
return ac_data
|
| 34 |
+
|
| 35 |
+
if self.sign:
|
| 36 |
+
d[SampleBatch.REWARDS] = np.sign(d[SampleBatch.REWARDS])
|
| 37 |
+
elif self.limit:
|
| 38 |
+
d[SampleBatch.REWARDS] = np.clip(
|
| 39 |
+
d[SampleBatch.REWARDS],
|
| 40 |
+
a_min=-self.limit,
|
| 41 |
+
a_max=self.limit,
|
| 42 |
+
)
|
| 43 |
+
return ac_data
|
| 44 |
+
|
| 45 |
+
def to_state(self):
|
| 46 |
+
return ClipRewardAgentConnector.__name__, {
|
| 47 |
+
"sign": self.sign,
|
| 48 |
+
"limit": self.limit,
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
@staticmethod
|
| 52 |
+
def from_state(ctx: ConnectorContext, params: Any):
|
| 53 |
+
return ClipRewardAgentConnector(ctx, **params)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
register_connector(ClipRewardAgentConnector.__name__, ClipRewardAgentConnector)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/env_sampling.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from ray.rllib.connectors.connector import (
|
| 4 |
+
AgentConnector,
|
| 5 |
+
ConnectorContext,
|
| 6 |
+
)
|
| 7 |
+
from ray.rllib.connectors.registry import register_connector
|
| 8 |
+
from ray.rllib.utils.typing import AgentConnectorDataType
|
| 9 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
@OldAPIStack
|
| 13 |
+
class EnvSamplingAgentConnector(AgentConnector):
|
| 14 |
+
def __init__(self, ctx: ConnectorContext, sign=False, limit=None):
|
| 15 |
+
super().__init__(ctx)
|
| 16 |
+
self.observation_space = ctx.observation_space
|
| 17 |
+
|
| 18 |
+
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
|
| 19 |
+
# EnvSamplingAgentConnector is a no-op connector.
|
| 20 |
+
return ac_data
|
| 21 |
+
|
| 22 |
+
def to_state(self):
|
| 23 |
+
return EnvSamplingAgentConnector.__name__, {}
|
| 24 |
+
|
| 25 |
+
@staticmethod
|
| 26 |
+
def from_state(ctx: ConnectorContext, params: Any):
|
| 27 |
+
return EnvSamplingAgentConnector(ctx, **params)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
register_connector(EnvSamplingAgentConnector.__name__, EnvSamplingAgentConnector)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/lambdas.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Callable, Type
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import tree # dm_tree
|
| 5 |
+
|
| 6 |
+
from ray.rllib.connectors.connector import (
|
| 7 |
+
AgentConnector,
|
| 8 |
+
ConnectorContext,
|
| 9 |
+
)
|
| 10 |
+
from ray.rllib.connectors.registry import register_connector
|
| 11 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 12 |
+
from ray.rllib.utils.typing import (
|
| 13 |
+
AgentConnectorDataType,
|
| 14 |
+
AgentConnectorsOutput,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@OldAPIStack
|
| 20 |
+
def register_lambda_agent_connector(
|
| 21 |
+
name: str, fn: Callable[[Any], Any]
|
| 22 |
+
) -> Type[AgentConnector]:
|
| 23 |
+
"""A util to register any simple transforming function as an AgentConnector
|
| 24 |
+
|
| 25 |
+
The only requirement is that fn should take a single data object and return
|
| 26 |
+
a single data object.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
name: Name of the resulting actor connector.
|
| 30 |
+
fn: The function that transforms env / agent data.
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
A new AgentConnector class that transforms data using fn.
|
| 34 |
+
"""
|
| 35 |
+
|
| 36 |
+
class LambdaAgentConnector(AgentConnector):
|
| 37 |
+
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
|
| 38 |
+
return AgentConnectorDataType(
|
| 39 |
+
ac_data.env_id, ac_data.agent_id, fn(ac_data.data)
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
def to_state(self):
|
| 43 |
+
return name, None
|
| 44 |
+
|
| 45 |
+
@staticmethod
|
| 46 |
+
def from_state(ctx: ConnectorContext, params: Any):
|
| 47 |
+
return LambdaAgentConnector(ctx)
|
| 48 |
+
|
| 49 |
+
LambdaAgentConnector.__name__ = name
|
| 50 |
+
LambdaAgentConnector.__qualname__ = name
|
| 51 |
+
|
| 52 |
+
register_connector(name, LambdaAgentConnector)
|
| 53 |
+
|
| 54 |
+
return LambdaAgentConnector
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
@OldAPIStack
|
| 58 |
+
def flatten_data(data: AgentConnectorsOutput):
|
| 59 |
+
assert isinstance(
|
| 60 |
+
data, AgentConnectorsOutput
|
| 61 |
+
), "Single agent data must be of type AgentConnectorsOutput"
|
| 62 |
+
|
| 63 |
+
raw_dict = data.raw_dict
|
| 64 |
+
sample_batch = data.sample_batch
|
| 65 |
+
|
| 66 |
+
flattened = {}
|
| 67 |
+
for k, v in sample_batch.items():
|
| 68 |
+
if k in [SampleBatch.INFOS, SampleBatch.ACTIONS] or k.startswith("state_out_"):
|
| 69 |
+
# Do not flatten infos, actions, and state_out_ columns.
|
| 70 |
+
flattened[k] = v
|
| 71 |
+
continue
|
| 72 |
+
if v is None:
|
| 73 |
+
# Keep the same column shape.
|
| 74 |
+
flattened[k] = None
|
| 75 |
+
continue
|
| 76 |
+
flattened[k] = np.array(tree.flatten(v))
|
| 77 |
+
flattened = SampleBatch(flattened, is_training=False)
|
| 78 |
+
|
| 79 |
+
return AgentConnectorsOutput(raw_dict, flattened)
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# Agent connector to build and return a flattened observation SampleBatch
|
| 83 |
+
# in addition to the original input dict.
|
| 84 |
+
FlattenDataAgentConnector = OldAPIStack(
|
| 85 |
+
register_lambda_agent_connector("FlattenDataAgentConnector", flatten_data)
|
| 86 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/mean_std_filter.py
ADDED
|
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, List
|
| 2 |
+
from gymnasium.spaces import Discrete, MultiDiscrete
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import tree
|
| 6 |
+
|
| 7 |
+
from ray.rllib.connectors.agent.synced_filter import SyncedFilterAgentConnector
|
| 8 |
+
from ray.rllib.connectors.connector import AgentConnector
|
| 9 |
+
from ray.rllib.connectors.connector import (
|
| 10 |
+
ConnectorContext,
|
| 11 |
+
)
|
| 12 |
+
from ray.rllib.connectors.registry import register_connector
|
| 13 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 14 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 15 |
+
from ray.rllib.utils.filter import Filter
|
| 16 |
+
from ray.rllib.utils.filter import MeanStdFilter, ConcurrentMeanStdFilter
|
| 17 |
+
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
| 18 |
+
from ray.rllib.utils.typing import AgentConnectorDataType
|
| 19 |
+
from ray.rllib.utils.filter import RunningStat
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
@OldAPIStack
|
| 23 |
+
class MeanStdObservationFilterAgentConnector(SyncedFilterAgentConnector):
|
| 24 |
+
"""A connector used to mean-std-filter observations.
|
| 25 |
+
|
| 26 |
+
Incoming observations are filtered such that the output of this filter is on
|
| 27 |
+
average zero and has a standard deviation of 1. This filtering is applied
|
| 28 |
+
separately per element of the observation space.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
def __init__(
|
| 32 |
+
self,
|
| 33 |
+
ctx: ConnectorContext,
|
| 34 |
+
demean: bool = True,
|
| 35 |
+
destd: bool = True,
|
| 36 |
+
clip: float = 10.0,
|
| 37 |
+
):
|
| 38 |
+
SyncedFilterAgentConnector.__init__(self, ctx)
|
| 39 |
+
# We simply use the old MeanStdFilter until non-connector env_runner is fully
|
| 40 |
+
# deprecated to avoid duplicate code
|
| 41 |
+
|
| 42 |
+
filter_shape = tree.map_structure(
|
| 43 |
+
lambda s: (
|
| 44 |
+
None
|
| 45 |
+
if isinstance(s, (Discrete, MultiDiscrete)) # noqa
|
| 46 |
+
else np.array(s.shape)
|
| 47 |
+
),
|
| 48 |
+
get_base_struct_from_space(ctx.observation_space),
|
| 49 |
+
)
|
| 50 |
+
self.filter = MeanStdFilter(filter_shape, demean=demean, destd=destd, clip=clip)
|
| 51 |
+
|
| 52 |
+
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
|
| 53 |
+
d = ac_data.data
|
| 54 |
+
assert (
|
| 55 |
+
type(d) is dict
|
| 56 |
+
), "Single agent data must be of type Dict[str, TensorStructType]"
|
| 57 |
+
if SampleBatch.OBS in d:
|
| 58 |
+
d[SampleBatch.OBS] = self.filter(
|
| 59 |
+
d[SampleBatch.OBS], update=self._is_training
|
| 60 |
+
)
|
| 61 |
+
if SampleBatch.NEXT_OBS in d:
|
| 62 |
+
d[SampleBatch.NEXT_OBS] = self.filter(
|
| 63 |
+
d[SampleBatch.NEXT_OBS], update=self._is_training
|
| 64 |
+
)
|
| 65 |
+
|
| 66 |
+
return ac_data
|
| 67 |
+
|
| 68 |
+
def to_state(self):
|
| 69 |
+
# Flattening is deterministic
|
| 70 |
+
flattened_rs = tree.flatten(self.filter.running_stats)
|
| 71 |
+
flattened_buffer = tree.flatten(self.filter.buffer)
|
| 72 |
+
return MeanStdObservationFilterAgentConnector.__name__, {
|
| 73 |
+
"shape": self.filter.shape,
|
| 74 |
+
"no_preprocessor": self.filter.no_preprocessor,
|
| 75 |
+
"demean": self.filter.demean,
|
| 76 |
+
"destd": self.filter.destd,
|
| 77 |
+
"clip": self.filter.clip,
|
| 78 |
+
"running_stats": [s.to_state() for s in flattened_rs],
|
| 79 |
+
"buffer": [s.to_state() for s in flattened_buffer],
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
# demean, destd, clip, and a state dict
|
| 83 |
+
@staticmethod
|
| 84 |
+
def from_state(
|
| 85 |
+
ctx: ConnectorContext,
|
| 86 |
+
params: List[Any] = None,
|
| 87 |
+
demean: bool = True,
|
| 88 |
+
destd: bool = True,
|
| 89 |
+
clip: float = 10.0,
|
| 90 |
+
):
|
| 91 |
+
connector = MeanStdObservationFilterAgentConnector(ctx, demean, destd, clip)
|
| 92 |
+
if params:
|
| 93 |
+
connector.filter.shape = params["shape"]
|
| 94 |
+
connector.filter.no_preprocessor = params["no_preprocessor"]
|
| 95 |
+
connector.filter.demean = params["demean"]
|
| 96 |
+
connector.filter.destd = params["destd"]
|
| 97 |
+
connector.filter.clip = params["clip"]
|
| 98 |
+
|
| 99 |
+
# Unflattening is deterministic
|
| 100 |
+
running_stats = [RunningStat.from_state(s) for s in params["running_stats"]]
|
| 101 |
+
connector.filter.running_stats = tree.unflatten_as(
|
| 102 |
+
connector.filter.shape, running_stats
|
| 103 |
+
)
|
| 104 |
+
|
| 105 |
+
# Unflattening is deterministic
|
| 106 |
+
buffer = [RunningStat.from_state(s) for s in params["buffer"]]
|
| 107 |
+
connector.filter.buffer = tree.unflatten_as(connector.filter.shape, buffer)
|
| 108 |
+
|
| 109 |
+
return connector
|
| 110 |
+
|
| 111 |
+
def reset_state(self) -> None:
|
| 112 |
+
"""Creates copy of current state and resets accumulated state"""
|
| 113 |
+
if not self._is_training:
|
| 114 |
+
raise ValueError(
|
| 115 |
+
"State of {} can only be changed when trainin.".format(self.__name__)
|
| 116 |
+
)
|
| 117 |
+
self.filter.reset_buffer()
|
| 118 |
+
|
| 119 |
+
def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
|
| 120 |
+
"""Updates self with state from other filter."""
|
| 121 |
+
# inline this as soon as we deprecate ordinary filter with non-connector
|
| 122 |
+
# env_runner
|
| 123 |
+
if not self._is_training:
|
| 124 |
+
raise ValueError(
|
| 125 |
+
"Changes can only be applied to {} when trainin.".format(self.__name__)
|
| 126 |
+
)
|
| 127 |
+
return self.filter.apply_changes(other, *args, **kwargs)
|
| 128 |
+
|
| 129 |
+
def copy(self) -> "Filter":
|
| 130 |
+
"""Creates a new object with same state as self.
|
| 131 |
+
|
| 132 |
+
This is a legacy Filter method that we need to keep around for now
|
| 133 |
+
|
| 134 |
+
Returns:
|
| 135 |
+
A copy of self.
|
| 136 |
+
"""
|
| 137 |
+
# inline this as soon as we deprecate ordinary filter with non-connector
|
| 138 |
+
# env_runner
|
| 139 |
+
return self.filter.copy()
|
| 140 |
+
|
| 141 |
+
def sync(self, other: "AgentConnector") -> None:
|
| 142 |
+
"""Copies all state from other filter to self."""
|
| 143 |
+
# inline this as soon as we deprecate ordinary filter with non-connector
|
| 144 |
+
# env_runner
|
| 145 |
+
if not self._is_training:
|
| 146 |
+
raise ValueError(
|
| 147 |
+
"{} can only be synced when trainin.".format(self.__name__)
|
| 148 |
+
)
|
| 149 |
+
return self.filter.sync(other.filter)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
@OldAPIStack
|
| 153 |
+
class ConcurrentMeanStdObservationFilterAgentConnector(
|
| 154 |
+
MeanStdObservationFilterAgentConnector
|
| 155 |
+
):
|
| 156 |
+
"""A concurrent version of the MeanStdObservationFilterAgentConnector.
|
| 157 |
+
|
| 158 |
+
This version's filter has all operations wrapped by a threading.RLock.
|
| 159 |
+
It can therefore be safely used by multiple threads.
|
| 160 |
+
"""
|
| 161 |
+
|
| 162 |
+
def __init__(self, ctx: ConnectorContext, demean=True, destd=True, clip=10.0):
|
| 163 |
+
SyncedFilterAgentConnector.__init__(self, ctx)
|
| 164 |
+
# We simply use the old MeanStdFilter until non-connector env_runner is fully
|
| 165 |
+
# deprecated to avoid duplicate code
|
| 166 |
+
|
| 167 |
+
filter_shape = tree.map_structure(
|
| 168 |
+
lambda s: (
|
| 169 |
+
None
|
| 170 |
+
if isinstance(s, (Discrete, MultiDiscrete)) # noqa
|
| 171 |
+
else np.array(s.shape)
|
| 172 |
+
),
|
| 173 |
+
get_base_struct_from_space(ctx.observation_space),
|
| 174 |
+
)
|
| 175 |
+
self.filter = ConcurrentMeanStdFilter(
|
| 176 |
+
filter_shape, demean=True, destd=True, clip=10.0
|
| 177 |
+
)
|
| 178 |
+
|
| 179 |
+
|
| 180 |
+
register_connector(
|
| 181 |
+
MeanStdObservationFilterAgentConnector.__name__,
|
| 182 |
+
MeanStdObservationFilterAgentConnector,
|
| 183 |
+
)
|
| 184 |
+
register_connector(
|
| 185 |
+
ConcurrentMeanStdObservationFilterAgentConnector.__name__,
|
| 186 |
+
ConcurrentMeanStdObservationFilterAgentConnector,
|
| 187 |
+
)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/obs_preproc.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any
|
| 2 |
+
|
| 3 |
+
from ray.rllib.connectors.connector import (
|
| 4 |
+
AgentConnector,
|
| 5 |
+
ConnectorContext,
|
| 6 |
+
)
|
| 7 |
+
from ray.rllib.connectors.registry import register_connector
|
| 8 |
+
from ray.rllib.models.preprocessors import get_preprocessor, NoPreprocessor
|
| 9 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 10 |
+
from ray.rllib.utils.typing import AgentConnectorDataType
|
| 11 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
@OldAPIStack
|
| 15 |
+
class ObsPreprocessorConnector(AgentConnector):
|
| 16 |
+
"""A connector that wraps around existing RLlib observation preprocessors.
|
| 17 |
+
|
| 18 |
+
This includes:
|
| 19 |
+
- OneHotPreprocessor for Discrete and Multi-Discrete spaces.
|
| 20 |
+
- GenericPixelPreprocessor and AtariRamPreprocessor for Atari spaces.
|
| 21 |
+
- TupleFlatteningPreprocessor and DictFlatteningPreprocessor for flattening
|
| 22 |
+
arbitrary nested input observations.
|
| 23 |
+
- RepeatedValuesPreprocessor for padding observations from RLlib Repeated
|
| 24 |
+
observation space.
|
| 25 |
+
"""
|
| 26 |
+
|
| 27 |
+
def __init__(self, ctx: ConnectorContext):
|
| 28 |
+
super().__init__(ctx)
|
| 29 |
+
|
| 30 |
+
if hasattr(ctx.observation_space, "original_space"):
|
| 31 |
+
# ctx.observation_space is the space this Policy deals with.
|
| 32 |
+
# We need to preprocess data from the original observation space here.
|
| 33 |
+
obs_space = ctx.observation_space.original_space
|
| 34 |
+
else:
|
| 35 |
+
obs_space = ctx.observation_space
|
| 36 |
+
|
| 37 |
+
self._preprocessor = get_preprocessor(obs_space)(
|
| 38 |
+
obs_space, ctx.config.get("model", {})
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
def is_identity(self):
|
| 42 |
+
"""Returns whether this preprocessor connector is a no-op preprocessor."""
|
| 43 |
+
return isinstance(self._preprocessor, NoPreprocessor)
|
| 44 |
+
|
| 45 |
+
def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
|
| 46 |
+
d = ac_data.data
|
| 47 |
+
assert type(d) is dict, (
|
| 48 |
+
"Single agent data must be of type Dict[str, TensorStructType] but is of "
|
| 49 |
+
"type {}".format(type(d))
|
| 50 |
+
)
|
| 51 |
+
|
| 52 |
+
if SampleBatch.OBS in d:
|
| 53 |
+
d[SampleBatch.OBS] = self._preprocessor.transform(d[SampleBatch.OBS])
|
| 54 |
+
if SampleBatch.NEXT_OBS in d:
|
| 55 |
+
d[SampleBatch.NEXT_OBS] = self._preprocessor.transform(
|
| 56 |
+
d[SampleBatch.NEXT_OBS]
|
| 57 |
+
)
|
| 58 |
+
|
| 59 |
+
return ac_data
|
| 60 |
+
|
| 61 |
+
def to_state(self):
|
| 62 |
+
return ObsPreprocessorConnector.__name__, None
|
| 63 |
+
|
| 64 |
+
@staticmethod
|
| 65 |
+
def from_state(ctx: ConnectorContext, params: Any):
|
| 66 |
+
return ObsPreprocessorConnector(ctx)
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
register_connector(ObsPreprocessorConnector.__name__, ObsPreprocessorConnector)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/pipeline.py
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Any, List
|
| 3 |
+
from collections import defaultdict
|
| 4 |
+
|
| 5 |
+
from ray.rllib.connectors.connector import (
|
| 6 |
+
AgentConnector,
|
| 7 |
+
Connector,
|
| 8 |
+
ConnectorContext,
|
| 9 |
+
ConnectorPipeline,
|
| 10 |
+
)
|
| 11 |
+
from ray.rllib.connectors.registry import get_connector, register_connector
|
| 12 |
+
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
|
| 13 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 14 |
+
from ray.util.timer import _Timer
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
logger = logging.getLogger(__name__)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
@OldAPIStack
|
| 21 |
+
class AgentConnectorPipeline(ConnectorPipeline, AgentConnector):
|
| 22 |
+
def __init__(self, ctx: ConnectorContext, connectors: List[Connector]):
|
| 23 |
+
super().__init__(ctx, connectors)
|
| 24 |
+
self.timers = defaultdict(_Timer)
|
| 25 |
+
|
| 26 |
+
def reset(self, env_id: str):
|
| 27 |
+
for c in self.connectors:
|
| 28 |
+
c.reset(env_id)
|
| 29 |
+
|
| 30 |
+
def on_policy_output(self, output: ActionConnectorDataType):
|
| 31 |
+
for c in self.connectors:
|
| 32 |
+
c.on_policy_output(output)
|
| 33 |
+
|
| 34 |
+
def __call__(
|
| 35 |
+
self, acd_list: List[AgentConnectorDataType]
|
| 36 |
+
) -> List[AgentConnectorDataType]:
|
| 37 |
+
ret = acd_list
|
| 38 |
+
for c in self.connectors:
|
| 39 |
+
timer = self.timers[str(c)]
|
| 40 |
+
with timer:
|
| 41 |
+
ret = c(ret)
|
| 42 |
+
return ret
|
| 43 |
+
|
| 44 |
+
def to_state(self):
|
| 45 |
+
children = []
|
| 46 |
+
for c in self.connectors:
|
| 47 |
+
state = c.to_state()
|
| 48 |
+
assert isinstance(state, tuple) and len(state) == 2, (
|
| 49 |
+
"Serialized connector state must be in the format of "
|
| 50 |
+
f"Tuple[name: str, params: Any]. Instead we got {state}"
|
| 51 |
+
f"for connector {c.__name__}."
|
| 52 |
+
)
|
| 53 |
+
children.append(state)
|
| 54 |
+
return AgentConnectorPipeline.__name__, children
|
| 55 |
+
|
| 56 |
+
@staticmethod
|
| 57 |
+
def from_state(ctx: ConnectorContext, params: List[Any]):
|
| 58 |
+
assert (
|
| 59 |
+
type(params) is list
|
| 60 |
+
), "AgentConnectorPipeline takes a list of connector params."
|
| 61 |
+
connectors = []
|
| 62 |
+
for state in params:
|
| 63 |
+
try:
|
| 64 |
+
name, subparams = state
|
| 65 |
+
connectors.append(get_connector(name, ctx, subparams))
|
| 66 |
+
except Exception as e:
|
| 67 |
+
logger.error(f"Failed to de-serialize connector state: {state}")
|
| 68 |
+
raise e
|
| 69 |
+
return AgentConnectorPipeline(ctx, connectors)
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
register_connector(AgentConnectorPipeline.__name__, AgentConnectorPipeline)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/state_buffer.py
ADDED
|
@@ -0,0 +1,120 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
import logging
|
| 3 |
+
import pickle
|
| 4 |
+
from typing import Any
|
| 5 |
+
|
| 6 |
+
import numpy as np
|
| 7 |
+
from ray.rllib.utils.annotations import override
|
| 8 |
+
import tree # dm_tree
|
| 9 |
+
|
| 10 |
+
from ray.rllib.connectors.connector import (
|
| 11 |
+
AgentConnector,
|
| 12 |
+
Connector,
|
| 13 |
+
ConnectorContext,
|
| 14 |
+
)
|
| 15 |
+
from ray import cloudpickle
|
| 16 |
+
from ray.rllib.connectors.registry import register_connector
|
| 17 |
+
from ray.rllib.core.columns import Columns
|
| 18 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 19 |
+
from ray.rllib.utils.spaces.space_utils import get_base_struct_from_space
|
| 20 |
+
from ray.rllib.utils.typing import ActionConnectorDataType, AgentConnectorDataType
|
| 21 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
logger = logging.getLogger(__name__)
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@OldAPIStack
class StateBufferConnector(AgentConnector):
    """Buffers per-agent policy outputs and feeds them into the next input.

    For every (env_id, agent_id) pair this connector stashes the latest
    policy output tuple ``(action, states, fetches)`` (see
    ``on_policy_output``). On the next ``transform`` call it injects:

    - the previous action (or a zero action if none exists yet),
    - the previous RNN state outputs (or ``ctx.initial_states``),
    - any extra action fetches,

    into the agent's data dict so the policy can compute the next action.
    """

    def __init__(self, ctx: ConnectorContext, states: Any = None):
        """Initializes a StateBufferConnector.

        Args:
            ctx: Connector context; provides initial RNN states and the
                action space used to build zero actions.
            states: Optional cloudpickle-serialized buffered states, as
                returned by ``to_state()``. Only needed when restoring a
                stashed policy during the rollout of a single episode.
        """
        super().__init__(ctx)

        self._initial_states = ctx.initial_states
        self._action_space_struct = get_base_struct_from_space(ctx.action_space)

        # env_id -> agent_id -> (action, states, fetches); unseen agents
        # default to a tuple of Nones.
        self._states = defaultdict(lambda: defaultdict(lambda: (None, None, None)))
        self._enable_new_api_stack = False
        # TODO(jungong) : we would not need this if policies are never stashed
        # during the rollout of a single episode.
        if states:
            try:
                self._states = cloudpickle.loads(states)
            except pickle.UnpicklingError:
                # StateBufferConnector states are only needed for rare cases
                # like stashing then restoring a policy during the rollout of
                # a single episode.
                # It is ok to ignore the error for most of the cases here.
                logger.info(
                    "Can not restore StateBufferConnector states. This warning can "
                    "usually be ignore, unless it is from restoring a stashed policy."
                )

    @override(Connector)
    def in_eval(self):
        # No eval-specific behavior; just forward to the base class.
        super().in_eval()

    def reset(self, env_id: str):
        """Drops buffered states for ``env_id``.

        States should not be carried over between episodes.
        """
        if env_id in self._states:
            del self._states[env_id]

    def on_policy_output(self, ac_data: ActionConnectorDataType):
        """Buffers the latest policy output for the next ``transform`` call."""
        self._states[ac_data.env_id][ac_data.agent_id] = ac_data.output

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Injects buffered action / states / fetches into the agent data dict."""
        d = ac_data.data
        assert (
            type(d) is dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        # Fixed: the message used to contain stray "f" characters inside the
        # f-string ("env_id(f{env_id})"), garbling the rendered output.
        assert (
            env_id is not None and agent_id is not None
        ), f"StateBufferConnector requires env_id({env_id}) and agent_id({agent_id})"

        action, states, fetches = self._states[env_id][agent_id]

        if action is not None:
            d[SampleBatch.ACTIONS] = action  # Last action
        else:
            # Default zero action, matching the action space structure.
            d[SampleBatch.ACTIONS] = tree.map_structure(
                lambda s: np.zeros_like(s.sample(), s.dtype)
                if hasattr(s, "dtype")
                else np.zeros_like(s.sample()),
                self._action_space_struct,
            )

        if states is None:
            states = self._initial_states
        if self._enable_new_api_stack:
            if states:
                d[Columns.STATE_OUT] = states
        else:
            for i, v in enumerate(states):
                d["state_out_{}".format(i)] = v

        # Also add extra fetches if available.
        if fetches:
            d.update(fetches)

        return ac_data

    def to_state(self):
        """Serializes buffered states with cloudpickle.

        Note(jungong) : it is ok to use cloudpickle here for states because:
        1. self._states may contain arbitrary data objects, and will be hard
           to serialize otherwise.
        2. serialized states are only useful if a policy is stashed and
           restored during the rollout of a single episode. So it is ok to
           use cloudpickle for such non-persistent data bits.
        """
        states = cloudpickle.dumps(self._states)
        return StateBufferConnector.__name__, states

    @staticmethod
    def from_state(ctx: ConnectorContext, params: Any):
        """Re-creates a StateBufferConnector from serialized ``params``."""
        return StateBufferConnector(ctx, params)


register_connector(StateBufferConnector.__name__, StateBufferConnector)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/synced_filter.py
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.connectors.connector import (
|
| 2 |
+
AgentConnector,
|
| 3 |
+
ConnectorContext,
|
| 4 |
+
)
|
| 5 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 6 |
+
from ray.rllib.utils.filter import Filter
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
@OldAPIStack
class SyncedFilterAgentConnector(AgentConnector):
    """An agent connector that filters with synchronized parameters.

    Every filter-related call is delegated to ``self.filter``, which a
    concrete subclass is expected to provide.
    """

    def __init__(self, ctx: ConnectorContext, *args, **kwargs):
        super().__init__(ctx)
        # This connector accepts no extra positional or keyword arguments;
        # reject anything beyond the context explicitly.
        if args or kwargs:
            raise ValueError(
                "SyncedFilterAgentConnector does not take any additional arguments, "
                "but got args=`{}` and kwargs={}.".format(args, kwargs)
            )

    def apply_changes(self, other: "Filter", *args, **kwargs) -> None:
        """Updates self with state from other filter."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter
        # with non-connector env_runner.
        return self.filter.apply_changes(other, *args, **kwargs)

    def copy(self) -> "Filter":
        """Creates a new object with same state as self.

        This is a legacy Filter method that we need to keep around for now.

        Returns:
            A copy of self.
        """
        # Delegate to the wrapped filter; inline this once the ordinary
        # (non-connector) filter code path is deprecated.
        duplicate = self.filter.copy()
        return duplicate

    def sync(self, other: "AgentConnector") -> None:
        """Copies all state from other filter to self."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter
        # with non-connector env_runner.
        return self.filter.sync(other.filter)

    def reset_state(self) -> None:
        """Creates copy of current state and resets accumulated state."""
        raise NotImplementedError

    def as_serializable(self) -> "Filter":
        """Returns a serializable version of the underlying filter."""
        # TODO: (artur) inline this as soon as we deprecate ordinary filter
        # with non-connector env_runner.
        return self.filter.as_serializable()
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/agent/view_requirement.py
ADDED
|
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from collections import defaultdict
|
| 2 |
+
from typing import Any
|
| 3 |
+
|
| 4 |
+
from ray.rllib.connectors.connector import (
|
| 5 |
+
AgentConnector,
|
| 6 |
+
ConnectorContext,
|
| 7 |
+
)
|
| 8 |
+
from ray.rllib.connectors.registry import register_connector
|
| 9 |
+
from ray.rllib.policy.sample_batch import SampleBatch
|
| 10 |
+
from ray.rllib.utils.typing import (
|
| 11 |
+
AgentConnectorDataType,
|
| 12 |
+
AgentConnectorsOutput,
|
| 13 |
+
)
|
| 14 |
+
from ray.rllib.utils.annotations import OldAPIStack
|
| 15 |
+
from ray.rllib.evaluation.collectors.agent_collector import AgentCollector
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
@OldAPIStack
class ViewRequirementAgentConnector(AgentConnector):
    """This connector does 2 things:

    1. It filters data columns based on view_requirements for training and
       inference.
    2. It buffers the right amount of history for computing the sample batch
       for action computation.

    The output of this connector is an ``AgentConnectorsOutput``, which
    basically is a tuple of 2 things:
    {
        "raw_dict": {"obs": ...}
        "sample_batch": SampleBatch
    }
    raw_dict, which contains raw data for the latest time slice,
    can be used to construct a complete episode by Sampler for training purpose.
    The "for_action" SampleBatch can be used to directly call the policy.
    """

    def __init__(self, ctx: ConnectorContext):
        """Initializes a ViewRequirementAgentConnector from ``ctx``."""
        super().__init__(ctx)

        self._view_requirements = ctx.view_requirements
        _enable_new_api_stack = False

        # a dict of env_id to a dict of agent_id to an AgentCollector object
        self.agent_collectors = defaultdict(
            lambda: defaultdict(
                lambda: AgentCollector(
                    self._view_requirements,
                    max_seq_len=ctx.config["model"]["max_seq_len"],
                    # NOTE(review): keyword deliberately spelled
                    # "intial_states" (sic) — it must match AgentCollector's
                    # parameter name; do not "fix" the spelling here alone.
                    intial_states=ctx.initial_states,
                    disable_action_flattening=ctx.config.get(
                        "_disable_action_flattening", False
                    ),
                    is_policy_recurrent=ctx.is_policy_recurrent,
                    # Note(jungong): We only leverage AgentCollector for building sample
                    # batches for computing actions.
                    # So regardless of whether this ViewRequirement connector is in
                    # training or inference mode, we should tell these AgentCollectors
                    # to behave in inference mode, so they don't accumulate episode data
                    # that is not useful for inference.
                    is_training=False,
                    _enable_new_api_stack=_enable_new_api_stack,
                )
            )
        )

    def reset(self, env_id: str):
        """Drops all agent collectors for ``env_id`` (episode boundary)."""
        if env_id in self.agent_collectors:
            del self.agent_collectors[env_id]

    def transform(self, ac_data: AgentConnectorDataType) -> AgentConnectorDataType:
        """Buffers the incoming step and builds an inference sample batch."""
        d = ac_data.data
        assert (
            type(d) is dict
        ), "Single agent data must be of type Dict[str, TensorStructType]"

        env_id = ac_data.env_id
        agent_id = ac_data.agent_id
        # TODO: we don't keep episode_id around so use env_id as episode_id ?
        episode_id = env_id if SampleBatch.EPS_ID not in d else d[SampleBatch.EPS_ID]

        # Fixed: the second line of this message was missing its f-prefix,
        # so "{agent_id}" was rendered literally instead of the actual id.
        assert env_id is not None and agent_id is not None, (
            f"ViewRequirementAgentConnector requires env_id({env_id}) "
            f"and agent_id({agent_id})"
        )

        assert (
            self._view_requirements
        ), "ViewRequirements required by ViewRequirementAgentConnector"

        # Note(jungong) : we need to keep the entire input dict here.
        # A column may be used by postprocessing (GAE) even if its
        # view_requirement.used_for_training is False.
        training_dict = d

        agent_collector = self.agent_collectors[env_id][agent_id]

        if SampleBatch.NEXT_OBS not in d:
            raise ValueError(f"connector data {d} should contain next_obs.")
        # TODO(avnishn; kourosh) Unsure how agent_index is necessary downstream
        # since there is no mapping from agent_index to agent_id that exists.
        # need to remove this from the SampleBatch later.
        # fall back to using dummy index if no index is available
        if SampleBatch.AGENT_INDEX in d:
            agent_index = d[SampleBatch.AGENT_INDEX]
        else:
            try:
                agent_index = float(agent_id)
            except ValueError:
                agent_index = -1
        if agent_collector.is_empty():
            agent_collector.add_init_obs(
                episode_id=episode_id,
                agent_index=agent_index,
                env_id=env_id,
                init_obs=d[SampleBatch.NEXT_OBS],
                init_infos=d.get(SampleBatch.INFOS),
            )
        else:
            agent_collector.add_action_reward_next_obs(d)
        sample_batch = agent_collector.build_for_inference()

        return AgentConnectorDataType(
            env_id, agent_id, AgentConnectorsOutput(training_dict, sample_batch)
        )

    def to_state(self):
        # Collector buffers are transient; nothing needs serializing.
        return ViewRequirementAgentConnector.__name__, None

    @staticmethod
    def from_state(ctx: ConnectorContext, params: Any):
        """Re-creates a ViewRequirementAgentConnector from ``ctx`` only."""
        return ViewRequirementAgentConnector(ctx)


register_connector(
    ViewRequirementAgentConnector.__name__, ViewRequirementAgentConnector
)
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__init__.py
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.connectors.common.add_observations_from_episodes_to_batch import (
|
| 2 |
+
AddObservationsFromEpisodesToBatch,
|
| 3 |
+
)
|
| 4 |
+
from ray.rllib.connectors.common.add_states_from_episodes_to_batch import (
|
| 5 |
+
AddStatesFromEpisodesToBatch,
|
| 6 |
+
)
|
| 7 |
+
from ray.rllib.connectors.common.add_time_dim_to_batch_and_zero_pad import (
|
| 8 |
+
AddTimeDimToBatchAndZeroPad,
|
| 9 |
+
)
|
| 10 |
+
from ray.rllib.connectors.common.agent_to_module_mapping import AgentToModuleMapping
|
| 11 |
+
from ray.rllib.connectors.common.batch_individual_items import BatchIndividualItems
|
| 12 |
+
from ray.rllib.connectors.common.numpy_to_tensor import NumpyToTensor
|
| 13 |
+
from ray.rllib.connectors.learner.add_columns_from_episodes_to_train_batch import (
|
| 14 |
+
AddColumnsFromEpisodesToTrainBatch,
|
| 15 |
+
)
|
| 16 |
+
from ray.rllib.connectors.learner.add_next_observations_from_episodes_to_train_batch import ( # noqa
|
| 17 |
+
AddNextObservationsFromEpisodesToTrainBatch,
|
| 18 |
+
)
|
| 19 |
+
from ray.rllib.connectors.learner.add_one_ts_to_episodes_and_truncate import (
|
| 20 |
+
AddOneTsToEpisodesAndTruncate,
|
| 21 |
+
)
|
| 22 |
+
from ray.rllib.connectors.learner.compute_returns_to_go import ComputeReturnsToGo
|
| 23 |
+
from ray.rllib.connectors.learner.general_advantage_estimation import (
|
| 24 |
+
GeneralAdvantageEstimation,
|
| 25 |
+
)
|
| 26 |
+
from ray.rllib.connectors.learner.learner_connector_pipeline import (
|
| 27 |
+
LearnerConnectorPipeline,
|
| 28 |
+
)
|
| 29 |
+
|
| 30 |
+
__all__ = [
|
| 31 |
+
"AddColumnsFromEpisodesToTrainBatch",
|
| 32 |
+
"AddNextObservationsFromEpisodesToTrainBatch",
|
| 33 |
+
"AddObservationsFromEpisodesToBatch",
|
| 34 |
+
"AddOneTsToEpisodesAndTruncate",
|
| 35 |
+
"AddStatesFromEpisodesToBatch",
|
| 36 |
+
"AddTimeDimToBatchAndZeroPad",
|
| 37 |
+
"AgentToModuleMapping",
|
| 38 |
+
"BatchIndividualItems",
|
| 39 |
+
"ComputeReturnsToGo",
|
| 40 |
+
"GeneralAdvantageEstimation",
|
| 41 |
+
"LearnerConnectorPipeline",
|
| 42 |
+
"NumpyToTensor",
|
| 43 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.91 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/add_columns_from_episodes_to_train_batch.cpython-311.pyc
ADDED
|
Binary file (7.62 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/add_next_observations_from_episodes_to_train_batch.cpython-311.pyc
ADDED
|
Binary file (4.92 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/compute_returns_to_go.cpython-311.pyc
ADDED
|
Binary file (3.14 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/learner/__pycache__/general_advantage_estimation.cpython-311.pyc
ADDED
|
Binary file (8.74 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (1.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/get_actions.cpython-311.pyc
ADDED
|
Binary file (4.55 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/module_to_env_pipeline.cpython-311.pyc
ADDED
|
Binary file (714 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/normalize_and_clip_actions.cpython-311.pyc
ADDED
|
Binary file (7.29 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/remove_single_ts_time_rank_from_batch.cpython-311.pyc
ADDED
|
Binary file (3.6 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/connectors/module_to_env/__pycache__/unbatch_to_individual_items.cpython-311.pyc
ADDED
|
Binary file (4.47 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (596 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/action_dist.cpython-311.pyc
ADDED
|
Binary file (5.15 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/catalog.cpython-311.pyc
ADDED
|
Binary file (35.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/distributions.cpython-311.pyc
ADDED
|
Binary file (10.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/modelv2.cpython-311.pyc
ADDED
|
Binary file (21.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/preprocessors.cpython-311.pyc
ADDED
|
Binary file (26.7 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/repeated_values.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/__pycache__/utils.cpython-311.pyc
ADDED
|
Binary file (9.93 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (584 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/attention_net.cpython-311.pyc
ADDED
|
Binary file (28.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/fcnet.cpython-311.pyc
ADDED
|
Binary file (7.03 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/recurrent_net.cpython-311.pyc
ADDED
|
Binary file (14.4 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/tf_action_dist.cpython-311.pyc
ADDED
|
Binary file (51.8 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/tf_distributions.cpython-311.pyc
ADDED
|
Binary file (34.1 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/tf_modelv2.cpython-311.pyc
ADDED
|
Binary file (8.26 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/__pycache__/visionnet.cpython-311.pyc
ADDED
|
Binary file (10.5 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__init__.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from ray.rllib.models.tf.layers.gru_gate import GRUGate
|
| 2 |
+
from ray.rllib.models.tf.layers.noisy_layer import NoisyLayer
|
| 3 |
+
from ray.rllib.models.tf.layers.relative_multi_head_attention import (
|
| 4 |
+
PositionalEmbedding,
|
| 5 |
+
RelativeMultiHeadAttention,
|
| 6 |
+
)
|
| 7 |
+
from ray.rllib.models.tf.layers.skip_connection import SkipConnection
|
| 8 |
+
from ray.rllib.models.tf.layers.multi_head_attention import MultiHeadAttention
|
| 9 |
+
|
| 10 |
+
__all__ = [
|
| 11 |
+
"GRUGate",
|
| 12 |
+
"MultiHeadAttention",
|
| 13 |
+
"NoisyLayer",
|
| 14 |
+
"PositionalEmbedding",
|
| 15 |
+
"RelativeMultiHeadAttention",
|
| 16 |
+
"SkipConnection",
|
| 17 |
+
]
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (814 Bytes). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__pycache__/gru_gate.cpython-311.pyc
ADDED
|
Binary file (4.21 kB). View file
|
|
|
.venv/lib/python3.11/site-packages/ray/rllib/models/tf/layers/__pycache__/multi_head_attention.cpython-311.pyc
ADDED
|
Binary file (4.23 kB). View file
|
|
|