Duibonduil commited on
Commit
10b3362
·
verified ·
1 Parent(s): 236dcf5

Upload 4 files

Browse files
examples/tools/gym_tool/actions.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+ from examples.tools.tool_action import GymAction
4
+ from aworld.core.tool.action_factory import ActionFactory
5
+ from aworld.core.tool.action import ExecutableAction
6
+
7
+
8
+ @ActionFactory.register(name=GymAction.PLAY.value.name,
9
+ desc=GymAction.PLAY.value.desc,
10
+ tool_name="openai_gym")
11
+ class Play(ExecutableAction):
12
+ """"""
examples/tools/gym_tool/async_openai_gym.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+
4
+ from typing import Dict, Any, Tuple, SupportsFloat, Union, List
5
+
6
+ from pydantic import BaseModel
7
+
8
+ from aworld.config import ConfigDict
9
+ from examples.tools.tool_action import GymAction
10
+ from aworld.core.common import ActionModel, Observation, ActionResult
11
+ from aworld.core.tool.base import AsyncTool, ToolFactory
12
+ from aworld.utils.import_package import import_packages
13
+ from aworld.tools.utils import build_observation
14
+
15
+
16
+ class ActionType(object):
17
+ DISCRETE = 'discrete'
18
+ CONTINUOUS = 'continuous'
19
+
20
+
21
+ @ToolFactory.register(name="openai_gym", desc="gym classic control game", asyn=True, supported_action=GymAction)
22
+ class OpenAIGym(AsyncTool):
23
+ def __init__(self, conf: Union[Dict[str, Any], ConfigDict, BaseModel], **kwargs) -> None:
24
+ """Gym environment constructor.
25
+
26
+ Args:
27
+ env_id: gym environment full name
28
+ wrappers: gym environment wrapper list
29
+ """
30
+ import_packages(['pygame', 'gymnasium'])
31
+ super(OpenAIGym, self).__init__(conf, **kwargs)
32
+ self.env_id = self.conf.get("env_id")
33
+ self._render = self.conf.get('render', True)
34
+ if self._render:
35
+ kwargs['render_mode'] = self.conf.get('render_mode', True)
36
+ kwargs.pop('name', None)
37
+ self.env = self._gym_env_wrappers(self.env_id, self.conf.get("wrappers", []), **kwargs)
38
+ self.action_space = self.env.action_space
39
+
40
+ async def do_step(self, actions: List[ActionModel], **kwargs) -> Tuple[
41
+ Observation, SupportsFloat, bool, bool, Dict[str, Any]]:
42
+ if self._render:
43
+ await self.render()
44
+ action = actions[0].params['result']
45
+ action = OpenAIGym.transform_action(action=action)
46
+ state, reward, terminal, truncate, info = self.env.step(action)
47
+ info.update(kwargs)
48
+ self._finished = terminal
49
+
50
+ action_results = []
51
+ for _ in actions:
52
+ action_results.append(ActionResult(content=OpenAIGym.transform_state(state=state), success=True))
53
+ return (build_observation(observer=self.name(),
54
+ action_result=action_results,
55
+ ability=GymAction.PLAY.value.name,
56
+ content=OpenAIGym.transform_state(state=state),
57
+ env_id=self.env_id,
58
+ done=terminal,
59
+ **kwargs),
60
+ reward,
61
+ terminal,
62
+ truncate,
63
+ info)
64
+
65
+ async def render(self):
66
+ return self.env.render()
67
+
68
+ async def close(self):
69
+ if self.env:
70
+ self.env.close()
71
+ self.env = None
72
+
73
+ async def reset(self, *, seed: int | None = None, options: Dict[str, str] | None = None) -> Tuple[
74
+ Any, Dict[str, Any]]:
75
+ state = self.env.reset()
76
+ return build_observation(observer=self.name(),
77
+ ability=GymAction.PLAY.value.name,
78
+ content=OpenAIGym.transform_state(state=state),
79
+ env_id=self.env_id,
80
+ done=False), {}
81
+
82
+ def _action_dim(self):
83
+ from gymnasium import spaces
84
+
85
+ if isinstance(self.env.action_space, spaces.Discrete):
86
+ self.action_type = ActionType.DISCRETE
87
+ return self.env.action_space.n
88
+ elif isinstance(self.env.action_space, spaces.Box):
89
+ self.action_type = ActionType.CONTINUOUS
90
+ return self.env.action_space.shape[0]
91
+ else:
92
+ raise Exception('unsupported env.action_space: {}'.format(self.env.action_space))
93
+
94
+ def _state_dim(self):
95
+ if len(self.env.observation_space.shape) == 1:
96
+ return self.env.observation_space.shape[0]
97
+ else:
98
+ raise Exception('unsupported observation_space.shape: {}'.format(self.env.observation_space))
99
+
100
+ def _gym_env_wrappers(self, env_id, wrappers: list = [], **kwargs):
101
+ import gymnasium
102
+
103
+ env = gymnasium.make(env_id, **kwargs)
104
+
105
+ if wrappers:
106
+ for wrapper in wrappers:
107
+ env = wrapper(env)
108
+
109
+ return env
110
+
111
+ @staticmethod
112
+ def transform_state(state: Any):
113
+ if isinstance(state, tuple):
114
+ states = dict()
115
+ for n, state in enumerate(state):
116
+ state = OpenAIGym.transform_state(state=state)
117
+ if isinstance(state, dict):
118
+ for name, state in state.items():
119
+ states['gym{}-{}'.format(n, name)] = state
120
+ else:
121
+ states['gym{}'.format(n)] = state
122
+ return states
123
+ elif isinstance(state, dict):
124
+ states = dict()
125
+ for state_name, state in state.items():
126
+ state = OpenAIGym.transform_state(state=state)
127
+ if isinstance(state, dict):
128
+ for name, state in state.items():
129
+ states['{}-{}'.format(state_name, name)] = state
130
+ else:
131
+ states['{}'.format(state_name)] = state
132
+ return states
133
+ else:
134
+ return state
135
+
136
+ @staticmethod
137
+ def transform_action(action: Any):
138
+ if not isinstance(action, dict):
139
+ return action
140
+ else:
141
+ actions = dict()
142
+ for name, action in action.items():
143
+ if '-' in name:
144
+ name, inner_name = name.split('-', 1)
145
+ if name not in actions:
146
+ actions[name] = dict()
147
+ actions[name][inner_name] = action
148
+ else:
149
+ actions[name] = action
150
+ for name, action in actions.items():
151
+ if isinstance(action, dict):
152
+ actions[name] = OpenAIGym.transform_action(action=action)
153
+ return actions
examples/tools/gym_tool/openai_gym.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding: utf-8
2
+ # Copyright (c) 2025 inclusionAI.
3
+
4
+ from typing import Dict, Any, Tuple, SupportsFloat, List, Union
5
+
6
+ from aworld.config import ConfigDict, ToolConfig
7
+ from examples.tools.tool_action import GymAction
8
+ from aworld.core.common import Observation, ActionModel, ActionResult
9
+ from aworld.core.tool.base import Tool, ToolFactory
10
+ from aworld.utils.import_package import import_packages
11
+ from aworld.tools.utils import build_observation
12
+
13
+
14
+ class ActionType(object):
15
+ DISCRETE = 'discrete'
16
+ CONTINUOUS = 'continuous'
17
+
18
+
19
+ @ToolFactory.register(name="openai_gym",
20
+ desc="gym classic control game",
21
+ supported_action=GymAction,
22
+ conf_file_name=f'openai_gym_tool.yaml')
23
+ class OpenAIGym(Tool):
24
+ def __init__(self, conf: Union[Dict[str, Any], ConfigDict, ToolConfig], **kwargs) -> None:
25
+ """Gym environment constructor.
26
+
27
+ Args:
28
+ env_id: gym environment full name
29
+ wrappers: gym environment wrapper list
30
+ """
31
+ import_packages(['pygame', 'gymnasium'])
32
+ super(OpenAIGym, self).__init__(conf, **kwargs)
33
+ self.env_id = self.conf.get("env_id")
34
+ self._render = self.conf.get('render', True)
35
+ if self._render:
36
+ kwargs['render_mode'] = self.conf.get('render_mode', 'human')
37
+ kwargs.pop('name', None)
38
+ self.env = self._gym_env_wrappers(self.env_id, self.conf.get("wrappers", []), **kwargs)
39
+ self.action_space = self.env.action_space
40
+
41
+ def do_step(self, actions: List[ActionModel], **kwargs) -> Tuple[
42
+ Observation, SupportsFloat, bool, bool, Dict[str, Any]]:
43
+ if self._render:
44
+ self.render()
45
+ action = actions[0].params['result']
46
+ action = OpenAIGym.transform_action(action=action)
47
+ state, reward, terminal, truncate, info = self.env.step(action)
48
+ info.update(kwargs)
49
+ self._finished = terminal
50
+
51
+ action_results = []
52
+ for _ in actions:
53
+ action_results.append(ActionResult(content=OpenAIGym.transform_state(state=state), success=True))
54
+ return (build_observation(observer=self.name(),
55
+ action_result=action_results,
56
+ ability=GymAction.PLAY.value.name,
57
+ content=OpenAIGym.transform_state(state=state),
58
+ env_id=self.env_id,
59
+ done=terminal,
60
+ **kwargs),
61
+ reward,
62
+ terminal,
63
+ truncate,
64
+ info)
65
+
66
+ def render(self):
67
+ return self.env.render()
68
+
69
+ def close(self):
70
+ if self.env:
71
+ self.env.close()
72
+ self.env = None
73
+
74
+ def reset(self, *, seed: int | None = None, options: Dict[str, str] | None = None) -> Tuple[Any, Dict[str, Any]]:
75
+ state = self.env.reset()
76
+ return build_observation(observer=self.name(),
77
+ ability=GymAction.PLAY.value.name,
78
+ content=OpenAIGym.transform_state(state=state),
79
+ env_id=self.env_id,
80
+ done=False), {}
81
+
82
+ def _action_dim(self):
83
+ from gymnasium import spaces
84
+
85
+ if isinstance(self.env.action_space, spaces.Discrete):
86
+ self.action_type = ActionType.DISCRETE
87
+ return self.env.action_space.n
88
+ elif isinstance(self.env.action_space, spaces.Box):
89
+ self.action_type = ActionType.CONTINUOUS
90
+ return self.env.action_space.shape[0]
91
+ else:
92
+ raise Exception('unsupported env.action_space: {}'.format(self.env.action_space))
93
+
94
+ def _state_dim(self):
95
+ if len(self.env.observation_space.shape) == 1:
96
+ return self.env.observation_space.shape[0]
97
+ else:
98
+ raise Exception('unsupported observation_space.shape: {}'.format(self.env.observation_space))
99
+
100
+ def _gym_env_wrappers(self, env_id, wrappers: list = [], **kwargs):
101
+ import gymnasium
102
+
103
+ env = gymnasium.make(env_id, **kwargs)
104
+
105
+ if wrappers:
106
+ for wrapper in wrappers:
107
+ env = wrapper(env)
108
+
109
+ return env
110
+
111
+ @staticmethod
112
+ def transform_state(state: Any):
113
+ if isinstance(state, tuple):
114
+ states = dict()
115
+ for n, state in enumerate(state):
116
+ state = OpenAIGym.transform_state(state=state)
117
+ if isinstance(state, dict):
118
+ for name, state in state.items():
119
+ states['gym{}-{}'.format(n, name)] = state
120
+ else:
121
+ states['gym{}'.format(n)] = state
122
+ return states
123
+ elif isinstance(state, dict):
124
+ states = dict()
125
+ for state_name, state in state.items():
126
+ state = OpenAIGym.transform_state(state=state)
127
+ if isinstance(state, dict):
128
+ for name, state in state.items():
129
+ states['{}-{}'.format(state_name, name)] = state
130
+ else:
131
+ states['{}'.format(state_name)] = state
132
+ return states
133
+ else:
134
+ return state
135
+
136
+ @staticmethod
137
+ def transform_action(action: Any):
138
+ if not isinstance(action, dict):
139
+ return action
140
+ else:
141
+ actions = dict()
142
+ for name, action in action.items():
143
+ if '-' in name:
144
+ name, inner_name = name.split('-', 1)
145
+ if name not in actions:
146
+ actions[name] = dict()
147
+ actions[name][inner_name] = action
148
+ else:
149
+ actions[name] = action
150
+ for name, action in actions.items():
151
+ if isinstance(action, dict):
152
+ actions[name] = OpenAIGym.transform_action(action=action)
153
+ return actions
examples/tools/gym_tool/requirements.txt ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ gymnasium~=1.1.0
2
+ pygame~=2.6.1