tarantula11 committed on
Commit
d742a8e
·
verified ·
1 Parent(s): 8c57b4a

Upload submission from kinitro-agent-template

Browse files
Files changed (8) hide show
  1. .gitignore +6 -4
  2. agent.capnp +9 -9
  3. agent.py +1 -0
  4. agent_server.py +56 -44
  5. evaluation.py +499 -0
  6. main.py +230 -9
  7. pyproject.toml +7 -1
  8. uv.lock +0 -0
.gitignore CHANGED
@@ -1,3 +1,5 @@
 
 
1
  # Byte-compiled / optimized / DLL files
2
  __pycache__/
3
  *.py[codz]
@@ -131,7 +133,7 @@ __pypackages__/
131
  celerybeat-schedule
132
  celerybeat.pid
133
 
134
- # Redis
135
  *.rdb
136
  *.aof
137
  *.pid
@@ -195,9 +197,9 @@ cython_debug/
195
  .abstra/
196
 
197
  # Visual Studio Code
198
- # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
199
  # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
200
- # and can be added to the global gitignore or merged into this file. However, if you prefer,
201
  # you could uncomment the following to ignore the entire vscode folder
202
  # .vscode/
203
 
@@ -213,4 +215,4 @@ marimo/_lsp/
213
  __marimo__/
214
 
215
  # Streamlit
216
- .streamlit/secrets.toml
 
1
+ runs/
2
+
3
  # Byte-compiled / optimized / DLL files
4
  __pycache__/
5
  *.py[codz]
 
133
  celerybeat-schedule
134
  celerybeat.pid
135
 
136
+ # Redis
137
  *.rdb
138
  *.aof
139
  *.pid
 
197
  .abstra/
198
 
199
  # Visual Studio Code
200
+ # Visual Studio Code specific template is maintained in a separate VisualStudioCode.gitignore
201
  # that can be found at https://github.com/github/gitignore/blob/main/Global/VisualStudioCode.gitignore
202
+ # and can be added to the global gitignore or merged into this file. However, if you prefer,
203
  # you could uncomment the following to ignore the entire vscode folder
204
  # .vscode/
205
 
 
215
  __marimo__/
216
 
217
  # Streamlit
218
+ .streamlit/secrets.toml
agent.capnp CHANGED
@@ -1,13 +1,13 @@
1
- @0x893bac407c81b48c;
2
 
3
  interface Agent {
 
 
 
 
4
 
5
- struct Tensor {
6
- data @0 :Data; # tensor bytes tensor.numpy().tobytes()
7
- shape @1 :List(UInt64); # tensor shape list(tensor.shape())
8
- dtype @2 :Text; # data type name tensor.dtype()
9
- }
10
-
11
- act @0 (obs :Data) -> (action :Tensor);
12
- reset @1 () -> ();
13
  }
 
1
+ @0xbf5147e1a2a3a3b1;
2
 
3
  interface Agent {
4
+ ping @0 (message :Text) -> (response :Text);
5
+ act @1 (obs :Tensor) -> (action :Tensor);
6
+ reset @2 ();
7
+ }
8
 
9
+ struct Tensor {
10
+ data @0 :Data;
11
+ shape @1 :List(Int32);
12
+ dtype @2 :Text;
 
 
 
 
13
  }
agent.py CHANGED
@@ -8,6 +8,7 @@ import logging
8
  from typing import Any, Dict
9
 
10
  import gymnasium as gym
 
11
  import metaworld
12
  import numpy as np
13
  import torch
 
8
  from typing import Any, Dict
9
 
10
  import gymnasium as gym
11
+
12
  import metaworld
13
  import numpy as np
14
  import torch
agent_server.py CHANGED
@@ -1,15 +1,16 @@
1
  #!/usr/bin/env python3
2
  """
3
  Cap'n Proto RPC Server for Agent Interface
 
4
  """
5
 
6
  import asyncio
7
  import logging
8
  import os
9
- import pickle
 
10
  import numpy as np
11
  import torch
12
- import capnp
13
 
14
  # Load the schema
15
  schema_file = os.path.join(os.path.dirname(__file__), "agent.capnp")
@@ -17,6 +18,12 @@ agent_capnp = capnp.load(schema_file)
17
 
18
  logger = logging.getLogger(__name__)
19
 
 
 
 
 
 
 
20
 
21
  class AgentServer(agent_capnp.Agent.Server):
22
  """Cap'n Proto server implementation for AgentInterface"""
@@ -27,73 +34,80 @@ class AgentServer(agent_capnp.Agent.Server):
27
  self.logger.info("AgentServer initialized with agent: %s", type(agent).__name__)
28
 
29
  async def act(self, obs, **kwargs):
30
- """Handle act RPC call"""
31
  try:
32
- # Deserialize observation from bytes
33
- observation = pickle.loads(obs)
34
-
35
- # Call the agent's act method
36
- action_tensor = self.agent.act(observation)
37
-
38
- # Convert to numpy if it's a torch tensor
 
 
 
 
 
 
 
 
 
 
 
39
  if isinstance(action_tensor, torch.Tensor):
40
- action_numpy = action_tensor.detach().cpu().numpy()
41
  else:
42
- action_numpy = np.array(action_tensor)
43
-
44
- # Prepare tensor response
45
- response = agent_capnp.Agent.Tensor.new_message()
46
- response.data = action_numpy.tobytes()
47
- response.shape = list(action_numpy.shape)
48
- response.dtype = str(action_numpy.dtype)
49
 
 
 
 
 
 
50
  return response
51
- except Exception as e:
52
- self.logger.error(f"Error in act: {e}", exc_info=True)
53
  raise
54
 
55
  async def reset(self, **kwargs):
56
- """Handle reset RPC call"""
57
  try:
58
  self.agent.reset()
59
- except Exception as e:
60
- self.logger.error(f"Error in reset: {e}", exc_info=True)
61
  raise
62
 
 
 
 
 
63
 
64
- async def serve(agent, address="127.0.0.1", port=8000):
65
  """Serve the agent using asyncio approach"""
66
 
67
  async def new_connection(stream):
68
- """Handler for each new client connection"""
69
  try:
70
- # Create TwoPartyServer for this connection
71
- server = capnp.TwoPartyServer(stream, bootstrap=AgentServer(agent))
72
-
73
- # Wait for the connection to disconnect
 
74
  await server.on_disconnect()
 
 
75
 
76
- except Exception as e:
77
- logger.error(f"Error handling connection: {e}", exc_info=True)
78
-
79
- # Create the server
80
  server = await capnp.AsyncIoStream.create_server(new_connection, address, port)
81
-
82
- logger.info(f"Agent RPC server listening on {address}:{port}")
83
 
84
  try:
85
- # Keep the server running
86
  async with server:
87
  await server.serve_forever()
88
- except Exception as e:
89
- logger.error(f"Server error: {e}", exc_info=True)
90
  finally:
91
  logger.info("Server shutting down")
92
 
93
 
94
- def start_server(agent, address="127.0.0.1", port=8000):
95
- """Start server with proper asyncio event loop handling"""
96
-
97
  async def run_server_with_kj():
98
  async with capnp.kj_loop():
99
  await serve(agent, address, port)
@@ -104,9 +118,7 @@ def start_server(agent, address="127.0.0.1", port=8000):
104
  logger.info("Server stopped by user")
105
 
106
 
107
- def run_server_in_process(agent, address="127.0.0.1", port=8000):
108
- """Entry point for running server in a separate process"""
109
-
110
  async def run_with_kj():
111
  async with capnp.kj_loop():
112
  await serve(agent, address, port)
 
1
  #!/usr/bin/env python3
2
  """
3
  Cap'n Proto RPC Server for Agent Interface
4
+ Receives observation as Agent.Tensor (no pickle).
5
  """
6
 
7
  import asyncio
8
  import logging
9
  import os
10
+
11
+ import capnp
12
  import numpy as np
13
  import torch
 
14
 
15
  # Load the schema
16
  schema_file = os.path.join(os.path.dirname(__file__), "agent.capnp")
 
18
 
19
  logger = logging.getLogger(__name__)
20
 
21
+ # Default network configuration
22
+ DEFAULT_RPC_ADDRESS = "127.0.0.1"
23
+ DEFAULT_RPC_PORT = 8000
24
+
25
+ _TRAVERSAL_WORDS = 100 * 1024 * 1024 # match client; tune appropriately
26
+
27
 
28
  class AgentServer(agent_capnp.Agent.Server):
29
  """Cap'n Proto server implementation for AgentInterface"""
 
34
  self.logger.info("AgentServer initialized with agent: %s", type(agent).__name__)
35
 
36
  async def act(self, obs, **kwargs):
37
+ """Handle act RPC call. 'obs' is expected to be an Agent.Tensor struct."""
38
  try:
39
+ # obs is a struct with .data, .shape, .dtype
40
+ byte_len = len(obs.data) if obs and obs.data is not None else 0
41
+ self.logger.debug(
42
+ "Server.act invoked; incoming obs bytes=%d shape=%s dtype=%s",
43
+ byte_len,
44
+ list(obs.shape) if obs else None,
45
+ obs.dtype if obs else None,
46
+ )
47
+
48
+ # reconstruct numpy observation
49
+ obs_np = np.frombuffer(obs.data, dtype=np.dtype(obs.dtype)).reshape(
50
+ tuple(obs.shape)
51
+ )
52
+
53
+ # call the underlying agent synchronously (user's agent.act should accept ndarray)
54
+ action_tensor = self.agent.act(obs_np)
55
+
56
+ # convert to numpy
57
  if isinstance(action_tensor, torch.Tensor):
58
+ action_np = action_tensor.detach().cpu().numpy()
59
  else:
60
+ action_np = np.array(action_tensor)
 
 
 
 
 
 
61
 
62
+ # Build response Tensor
63
+ response = agent_capnp.Tensor.new_message()
64
+ response.data = action_np.tobytes()
65
+ response.shape = [int(s) for s in action_np.shape]
66
+ response.dtype = str(action_np.dtype)
67
  return response
68
+ except Exception:
69
+ self.logger.exception("Exception in AgentServer.act")
70
  raise
71
 
72
  async def reset(self, **kwargs):
 
73
  try:
74
  self.agent.reset()
75
+ except Exception:
76
+ self.logger.exception("Error in reset")
77
  raise
78
 
79
+ async def ping(self, message, **kwargs):
80
+ self.logger.info(f"Ping received: {message}")
81
+ return "pong"
82
+
83
 
84
async def serve(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
    """Serve the agent over Cap'n Proto RPC on an asyncio socket server."""

    async def new_connection(stream):
        # One TwoPartyServer per client; errors are logged, never propagated,
        # so a bad client cannot take the listener down.
        try:
            rpc_server = capnp.TwoPartyServer(
                stream,
                bootstrap=AgentServer(agent),
                traversal_limit_in_words=_TRAVERSAL_WORDS,
            )
            await rpc_server.on_disconnect()
        except Exception:
            logger.exception("Error handling connection")

    server = await capnp.AsyncIoStream.create_server(new_connection, address, port)
    logger.info("Agent RPC server listening on %s:%d", address, port)

    try:
        async with server:
            await server.serve_forever()
    except Exception:
        logger.exception("Server error")
    finally:
        logger.info("Server shutting down")
108
 
109
 
110
+ def start_server(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
 
 
111
  async def run_server_with_kj():
112
  async with capnp.kj_loop():
113
  await serve(agent, address, port)
 
118
  logger.info("Server stopped by user")
119
 
120
 
121
+ def run_server_in_process(agent, address=DEFAULT_RPC_ADDRESS, port=DEFAULT_RPC_PORT):
 
 
122
  async def run_with_kj():
123
  async with capnp.kj_loop():
124
  await serve(agent, address, port)
evaluation.py ADDED
@@ -0,0 +1,499 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import logging
3
+ import os
4
+ import sys
5
+ import time
6
+ from datetime import datetime
7
+ from typing import Dict, Optional
8
+
9
+ import gymnasium as gym
10
+ import metaworld
11
+ import numpy as np
12
+ from agent import RLAgent
13
+
14
+ from torch.utils.tensorboard import SummaryWriter
15
+
16
+
17
class AgentEvaluator:
    """
    Evaluator for running and assessing the agent in MetaWorld environments.
    Includes TensorBoard logging for performance monitoring.
    """

    def __init__(
        self,
        task_name: str = "reach-v3",
        render_mode: str = "human",
        max_episodes: int = 5,
        max_steps_per_episode: int = 200,
        seed: Optional[int] = None,
        use_tensorboard: bool = True,
        log_dir: Optional[str] = None,
    ):
        """
        Initialize the evaluator.

        Args:
            task_name: Name of the MetaWorld task to run
            render_mode: Rendering mode ("human" for GUI, "rgb_array" for headless)
            max_episodes: Maximum number of episodes to run
            max_steps_per_episode: Maximum steps per episode
            seed: Random seed for reproducibility (random seed drawn when None)
            use_tensorboard: Whether to enable TensorBoard logging
            log_dir: Directory for TensorBoard logs (auto-generated if None)
        """
        self.task_name = task_name
        self.render_mode = render_mode
        self.max_episodes = max_episodes
        self.max_steps_per_episode = max_steps_per_episode
        # BUG FIX: `seed or np.random.randint(...)` silently discarded an
        # explicit seed of 0 (falsy); compare against None instead.
        self.seed = seed if seed is not None else np.random.randint(0, 1000000)
        self.use_tensorboard = use_tensorboard

        self.logger = logging.getLogger(__name__)
        self.env = None
        self.agent = None

        # Statistics tracking
        self.episode_rewards = []
        self.episode_lengths = []
        self.success_rate = 0.0

        # TensorBoard setup.
        # BUG FIX: the original __init__ contained a complete accidental
        # duplicate of this entire body (docstring, every assignment, and a
        # second SummaryWriter pointing at a second timestamped log dir);
        # the duplicate has been removed.
        self.tb_writer = None
        if self.use_tensorboard:
            if log_dir is None:
                timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
                log_dir = f"runs/{self.task_name}_{timestamp}"

            os.makedirs(log_dir, exist_ok=True)
            self.tb_writer = SummaryWriter(log_dir)
            self.logger.info(f"TensorBoard logging enabled: {log_dir}")
            self.logger.info(f"View logs with: tensorboard --logdir {log_dir}")

    def setup_environment(self) -> "gym.Env":
        """
        Set up the MetaWorld environment with MuJoCo rendering.

        Returns:
            Configured gymnasium environment
        """
        try:
            # Create MetaWorld environment. NOTE(review): both branches are
            # currently identical; kept in case reach-v3 needs special setup.
            if self.task_name == "reach-v3":
                # Use the reach task that matches our agent's policy
                mt1 = metaworld.MT1(self.task_name, seed=self.seed)
                env = mt1.train_classes[self.task_name]()
                task = mt1.train_tasks[0]
                env.set_task(task)
            else:
                # For other tasks, try to create them directly
                mt1 = metaworld.MT1(self.task_name, seed=self.seed)
                env = mt1.train_classes[self.task_name]()
                task = mt1.train_tasks[0]
                env.set_task(task)

            # Wrap with gymnasium if needed
            if not isinstance(env, gym.Env):
                env = gym.make(env.spec.id if hasattr(env, "spec") else self.task_name)

            # Configure rendering
            if hasattr(env, "render_mode"):
                env.render_mode = self.render_mode

            self.logger.info(f"Environment created: {self.task_name}")
            self.logger.info(f"Observation space: {env.observation_space}")
            self.logger.info(f"Action space: {env.action_space}")

            return env

        except Exception as e:
            self.logger.error(f"Failed to create environment {self.task_name}: {e}")
            self.logger.info("Falling back to reach-v3 environment")

            # Fallback to a simple reach environment
            mt1 = metaworld.MT1("reach-v3", seed=self.seed)
            env = mt1.train_classes["reach-v3"]()
            task = mt1.train_tasks[0]
            env.set_task(task)

            return env

    def setup_agent(self, env: "gym.Env") -> "RLAgent":
        """
        Set up the agent with the environment's observation and action spaces.

        Args:
            env: The gymnasium environment

        Returns:
            Configured RLAgent
        """
        agent = RLAgent(
            observation_space=env.observation_space,
            action_space=env.action_space,
            seed=self.seed,
            max_episode_steps=self.max_steps_per_episode,
        )

        self.logger.info("Agent initialized successfully")
        return agent

    def run_episode(self, episode_num: int) -> Dict[str, float]:
        """
        Run a single episode and return statistics.

        Args:
            episode_num: Episode number for logging

        Returns:
            Dictionary containing episode statistics
        """
        # Per-episode seed keeps episodes distinct but reproducible.
        obs, info = self.env.reset(seed=self.seed + episode_num)
        self.agent.reset()

        episode_reward = 0.0
        episode_length = 0
        success = False
        step_rewards = []

        self.logger.info(f"Starting episode {episode_num + 1}")

        for step in range(self.max_steps_per_episode):
            try:
                # Get action from agent
                action_tensor = self.agent.act(obs)

                # Convert to numpy array if needed
                if hasattr(action_tensor, "numpy"):
                    action = action_tensor.numpy()
                elif hasattr(action_tensor, "detach"):
                    action = action_tensor.detach().numpy()
                else:
                    action = np.array(action_tensor)

                # Take step in environment
                obs, reward, terminated, truncated, info = self.env.step(action)

                # Render the environment for human viewing
                if self.render_mode == "human":
                    self.env.render()
                    time.sleep(0.02)  # Small delay to make visualization smoother

                episode_reward += reward
                episode_length += 1
                step_rewards.append(reward)

                # Log to TensorBoard (step-level metrics)
                if self.tb_writer:
                    global_step = episode_num * self.max_steps_per_episode + step
                    self.tb_writer.add_scalar("Step/Reward", reward, global_step)
                    self.tb_writer.add_scalar(
                        "Step/CumulativeReward", episode_reward, global_step
                    )

                # Check for success (MetaWorld specific)
                if hasattr(info, "get") and info.get("success", False):
                    success = True

                # Log progress occasionally
                if step % 50 == 0:
                    self.logger.debug(
                        f"Episode {episode_num + 1}, Step {step}: "
                        f"Reward {reward:.3f}, Total {episode_reward:.3f}"
                    )

                if terminated or truncated:
                    break

            except Exception as e:
                # Best-effort: a failing step ends the episode, not the run.
                self.logger.error(f"Error during step {step}: {e}")
                break

        # Log episode-level metrics to TensorBoard
        if self.tb_writer:
            self.tb_writer.add_scalar("Episode/Reward", episode_reward, episode_num)
            self.tb_writer.add_scalar("Episode/Length", episode_length, episode_num)
            self.tb_writer.add_scalar("Episode/Success", float(success), episode_num)
            if step_rewards:
                self.tb_writer.add_scalar(
                    "Episode/AvgStepReward", np.mean(step_rewards), episode_num
                )
                self.tb_writer.add_scalar(
                    "Episode/MaxStepReward", np.max(step_rewards), episode_num
                )
                self.tb_writer.add_scalar(
                    "Episode/MinStepReward", np.min(step_rewards), episode_num
                )

        episode_stats = {
            "reward": episode_reward,
            "length": episode_length,
            "success": success,
        }

        self.logger.info(
            f"Episode {episode_num + 1} completed: "
            f"Reward {episode_reward:.3f}, "
            f"Length {episode_length}, "
            f"Success {success}"
        )

        return episode_stats

    def run_evaluation(self):
        """
        Run the complete evaluation session.
        """
        self.logger.info("Starting agent evaluation")

        # Setup environment and agent
        self.env = self.setup_environment()
        self.agent = self.setup_agent(self.env)

        # Run episodes
        total_successes = 0

        for episode in range(self.max_episodes):
            episode_stats = self.run_episode(episode)

            self.episode_rewards.append(episode_stats["reward"])
            self.episode_lengths.append(episode_stats["length"])

            if episode_stats["success"]:
                total_successes += 1

        # Calculate final statistics
        self.success_rate = total_successes / self.max_episodes
        avg_reward = np.mean(self.episode_rewards)
        avg_length = np.mean(self.episode_lengths)
        std_reward = np.std(self.episode_rewards)
        std_length = np.std(self.episode_lengths)

        # Log summary metrics to TensorBoard
        if self.tb_writer:
            self.tb_writer.add_scalar("Summary/AvgReward", avg_reward, 0)
            self.tb_writer.add_scalar("Summary/StdReward", std_reward, 0)
            self.tb_writer.add_scalar("Summary/AvgLength", avg_length, 0)
            self.tb_writer.add_scalar("Summary/StdLength", std_length, 0)
            self.tb_writer.add_scalar("Summary/SuccessRate", self.success_rate, 0)

            # Add histogram of rewards and lengths
            self.tb_writer.add_histogram(
                "Summary/RewardDistribution", np.array(self.episode_rewards), 0
            )
            self.tb_writer.add_histogram(
                "Summary/LengthDistribution", np.array(self.episode_lengths), 0
            )

            # Add hyperparameters
            self.tb_writer.add_hparams(
                {
                    "task": self.task_name,
                    "episodes": self.max_episodes,
                    "max_steps": self.max_steps_per_episode,
                    "seed": self.seed,
                    "render_mode": self.render_mode,
                },
                {
                    "avg_reward": avg_reward,
                    "success_rate": self.success_rate,
                    "avg_length": avg_length,
                },
            )

            self.tb_writer.flush()
            self.tb_writer.close()

        self.logger.info("=" * 50)
        self.logger.info("EVALUATION SUMMARY")
        self.logger.info("=" * 50)
        self.logger.info(f"Task: {self.task_name}")
        self.logger.info(f"Episodes: {self.max_episodes}")
        self.logger.info(f"Average Reward: {avg_reward:.3f} ± {std_reward:.3f}")
        self.logger.info(f"Average Length: {avg_length:.1f} ± {std_length:.1f}")
        self.logger.info(f"Success Rate: {self.success_rate:.1%}")
        if self.tb_writer:
            self.logger.info(
                "TensorBoard logs saved. View with: tensorboard --logdir runs/"
            )
        self.logger.info("=" * 50)

        # Close environment
        if self.env:
            self.env.close()

        return {
            "task": self.task_name,
            "episodes": self.max_episodes,
            "avg_reward": avg_reward,
            "std_reward": std_reward,
            "avg_length": avg_length,
            "std_length": std_length,
            "success_rate": self.success_rate,
            "episode_rewards": self.episode_rewards,
            "episode_lengths": self.episode_lengths,
        }

    def list_available_tasks(self):
        """
        List all available MetaWorld tasks.
        """
        try:
            # Get all MT1 tasks.
            # NOTE(review): `MT1.get_train_tasks()` may not exist in current
            # metaworld releases — the except branch below covers that case.
            mt1_tasks = metaworld.MT1.get_train_tasks()
            self.logger.info("Available MetaWorld MT1 tasks:")
            for i, task in enumerate(mt1_tasks, 1):
                self.logger.info(f"  {i}. {task}")

            # Get all MT10 tasks
            mt10 = metaworld.MT10()
            self.logger.info("\nAvailable MetaWorld MT10 tasks:")
            for i, task in enumerate(mt10.train_classes.keys(), 1):
                self.logger.info(f"  {i}. {task}")

        except Exception as e:
            self.logger.error(f"Error listing tasks: {e}")
            self.logger.info("Some common MetaWorld tasks:")
            common_tasks = [
                "reach-v3",
                "push-v3",
                "pick-place-v3",
                "door-open-v3",
                "drawer-open-v3",
                "button-press-topdown-v3",
                "peg-insert-side-v3",
            ]
            for i, task in enumerate(common_tasks, 1):
                self.logger.info(f"  {i}. {task}")
408
+
409
+
410
def setup_logging(level=logging.INFO):
    """Configure logging for the evaluator."""
    fmt = "%(asctime)s | %(levelname)s | %(name)s | %(message)s"
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(level=level, format=fmt, handlers=[stdout_handler])
417
+
418
+
419
def main():
    """Main entry point for the evaluator."""
    cli = argparse.ArgumentParser(
        description="Evaluate the MetaWorld agent in MuJoCo"
    )
    cli.add_argument(
        "--task",
        type=str,
        default="reach-v3",
        help="MetaWorld task name (default: reach-v3)",
    )
    cli.add_argument(
        "--episodes",
        type=int,
        default=5,
        help="Number of episodes to run (default: 5)",
    )
    cli.add_argument(
        "--steps",
        type=int,
        default=200,
        help="Maximum steps per episode (default: 200)",
    )
    cli.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )
    cli.add_argument(
        "--render-mode",
        type=str,
        default="human",
        choices=["human", "rgb_array"],
        help="Rendering mode (default: human)",
    )
    cli.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR"],
        help="Logging level (default: INFO)",
    )
    cli.add_argument(
        "--list-tasks",
        action="store_true",
        help="List available MetaWorld tasks and exit",
    )

    opts = cli.parse_args()

    # Configure logging before anything else emits records.
    setup_logging(getattr(logging, opts.log_level))

    # Build the evaluator from the parsed CLI options.
    evaluator = AgentEvaluator(
        task_name=opts.task,
        render_mode=opts.render_mode,
        max_episodes=opts.episodes,
        max_steps_per_episode=opts.steps,
        seed=opts.seed,
    )

    # Listing tasks is a query, not an evaluation run.
    if opts.list_tasks:
        evaluator.list_available_tasks()
        return

    try:
        evaluator.run_evaluation()
    except KeyboardInterrupt:
        logging.getLogger(__name__).info("Evaluation stopped by user")
    except Exception as e:
        logging.getLogger(__name__).error(
            f"Error during evaluation: {e}", exc_info=True
        )
        sys.exit(1)


if __name__ == "__main__":
    main()
main.py CHANGED
@@ -1,17 +1,22 @@
1
  #!/usr/bin/env python3
2
  """
3
- Main entry point for the agent server.
4
 
5
- This script creates an agent implementation and starts the RPC server
6
- to handle requests from the evaluator.
 
7
  """
8
 
9
  import argparse
10
  import logging
 
11
  import sys
 
 
 
12
 
13
  from agent import RLAgent
14
- from agent_server import start_server
15
 
16
 
17
  def setup_logging(level=logging.INFO):
@@ -23,16 +28,81 @@ def setup_logging(level=logging.INFO):
23
  )
24
 
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
  def main():
27
  """Main entry point."""
28
- parser = argparse.ArgumentParser(description="Start the agent server")
29
- parser.add_argument(
30
- "--host", type=str, default="*", help="Host to bind the server to"
 
 
 
 
 
 
 
 
31
  )
32
- parser.add_argument(
 
 
 
 
 
 
 
 
 
33
  "--port", type=int, default=8000, help="Port to bind the server to"
34
  )
35
- parser.add_argument(
36
  "--log-level",
37
  type=str,
38
  default="INFO",
@@ -40,13 +110,97 @@ def main():
40
  help="Logging level",
41
  )
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  args = parser.parse_args()
44
 
 
 
 
 
 
45
  # Setup logging
46
  log_level = getattr(logging, args.log_level)
47
  setup_logging(log_level)
48
  logger = logging.getLogger(__name__)
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  logger.info(f"Starting agent server on {args.host}:{args.port}")
51
 
52
  # Create the RLAgent
@@ -62,5 +216,72 @@ def main():
62
  sys.exit(1)
63
 
64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
65
  if __name__ == "__main__":
66
  main()
 
1
  #!/usr/bin/env python3
2
  """
3
+ Main entry point for the agent server and evaluation.
4
 
5
+ This script provides multiple commands:
6
+ - server: Creates an agent implementation and starts the RPC server
7
+ - eval: Runs local evaluation of the agent with visual rendering
8
  """
9
 
10
  import argparse
11
  import logging
12
+ import subprocess
13
  import sys
14
+ import threading
15
+ import time
16
+ import webbrowser
17
 
18
  from agent import RLAgent
19
+ from evaluation import AgentEvaluator
20
 
21
 
22
  def setup_logging(level=logging.INFO):
 
28
  )
29
 
30
 
31
def launch_tensorboard(log_dir, port=6006):
    """Launch TensorBoard in a separate thread."""

    def run_tensorboard():
        try:
            # Wait a moment for initial logs to be written
            time.sleep(2)

            cmd = [
                "tensorboard",
                "--logdir",
                log_dir,
                "--port",
                str(port),
                "--host",
                "localhost",
                "--reload_interval",
                "1",
            ]
            subprocess.run(cmd, check=True, capture_output=True)
        except subprocess.CalledProcessError:
            # TensorBoard failed to start, but don't crash the evaluation
            pass
        except FileNotFoundError:
            # TensorBoard not installed
            pass

    # Run TensorBoard as a daemon so it dies with the main process.
    worker = threading.Thread(target=run_tensorboard, daemon=True)
    worker.start()

    # Give TensorBoard a moment to start
    time.sleep(3)

    url = f"http://localhost:{port}"
    try:
        webbrowser.open(url)
    except Exception:
        # Browser opening failed, but that's okay
        pass

    return url
77
+
78
+
79
  def main():
80
  """Main entry point."""
81
+ parser = argparse.ArgumentParser(
82
+ description="Agent server and evaluation tool",
83
+ formatter_class=argparse.RawDescriptionHelpFormatter,
84
+ epilog="""
85
+ Examples:
86
+ python main.py server --host localhost --port 8000
87
+ python main.py eval --task reach-v3 --episodes 5
88
+ python main.py eval --task push-v3 --episodes 10 --render-mode rgb_array
89
+ python main.py eval --task reach-v3 --episodes 20 --no-tensorboard
90
+ python main.py eval --task door-open-v3 --log-dir custom_logs/
91
+ """,
92
  )
93
+
94
+ # Add subcommands
95
+ subparsers = parser.add_subparsers(dest="command", help="Available commands")
96
+
97
+ # Server subcommand
98
+ server_parser = subparsers.add_parser("server", help="Start the agent server")
99
+ server_parser.add_argument(
100
+ "--host", type=str, default="0.0.0.0", help="Host to bind the server to"
101
+ )
102
+ server_parser.add_argument(
103
  "--port", type=int, default=8000, help="Port to bind the server to"
104
  )
105
+ server_parser.add_argument(
106
  "--log-level",
107
  type=str,
108
  default="INFO",
 
110
  help="Logging level",
111
  )
112
 
113
+ # Evaluation subcommand
114
+ eval_parser = subparsers.add_parser("eval", help="Run local agent evaluation")
115
+ eval_parser.add_argument(
116
+ "--task",
117
+ type=str,
118
+ default="reach-v3",
119
+ help="MetaWorld task name (default: reach-v3)",
120
+ )
121
+ eval_parser.add_argument(
122
+ "--episodes",
123
+ type=int,
124
+ default=5,
125
+ help="Number of episodes to run (default: 5)",
126
+ )
127
+ eval_parser.add_argument(
128
+ "--steps",
129
+ type=int,
130
+ default=200,
131
+ help="Maximum steps per episode (default: 200)",
132
+ )
133
+ eval_parser.add_argument(
134
+ "--seed",
135
+ type=int,
136
+ default=None,
137
+ help="Random seed for reproducibility",
138
+ )
139
+ eval_parser.add_argument(
140
+ "--render-mode",
141
+ type=str,
142
+ default="human",
143
+ choices=["human", "rgb_array"],
144
+ help="Rendering mode (default: human)",
145
+ )
146
+ eval_parser.add_argument(
147
+ "--log-level",
148
+ type=str,
149
+ default="INFO",
150
+ choices=["DEBUG", "INFO", "WARNING", "ERROR"],
151
+ help="Logging level (default: INFO)",
152
+ )
153
+ eval_parser.add_argument(
154
+ "--list-tasks",
155
+ action="store_true",
156
+ help="List available MetaWorld tasks and exit",
157
+ )
158
+ eval_parser.add_argument(
159
+ "--tensorboard",
160
+ action="store_true",
161
+ default=True,
162
+ help="Enable TensorBoard logging (default: True)",
163
+ )
164
+ eval_parser.add_argument(
165
+ "--no-tensorboard",
166
+ action="store_true",
167
+ help="Disable TensorBoard logging",
168
+ )
169
+ eval_parser.add_argument(
170
+ "--log-dir",
171
+ type=str,
172
+ default=None,
173
+ help="TensorBoard log directory (auto-generated if not specified)",
174
+ )
175
+
176
  args = parser.parse_args()
177
 
178
+ # If no command is provided, show help
179
+ if not args.command:
180
+ parser.print_help()
181
+ sys.exit(1)
182
+
183
  # Setup logging
184
  log_level = getattr(logging, args.log_level)
185
  setup_logging(log_level)
186
  logger = logging.getLogger(__name__)
187
 
188
+ if args.command == "server":
189
+ run_server(args, logger)
190
+ elif args.command == "eval":
191
+ run_evaluation(args, logger)
192
+
193
+
194
+ def run_server(args, logger):
195
+ """Run the agent server."""
196
+ # Import server functionality only when needed to avoid capnp dependency for eval
197
+ try:
198
+ from agent_server import start_server
199
+ except ImportError as e:
200
+ logger.error(f"Failed to import server functionality: {e}")
201
+ logger.error("Make sure capnp and other server dependencies are installed")
202
+ sys.exit(1)
203
+
204
  logger.info(f"Starting agent server on {args.host}:{args.port}")
205
 
206
  # Create the RLAgent
 
216
  sys.exit(1)
217
 
218
 
219
def run_evaluation(args, logger):
    """Run a local evaluation of the agent on a MetaWorld task.

    Builds an ``AgentEvaluator`` from the parsed CLI arguments, optionally
    launches TensorBoard for live metrics, then drives the evaluation loop.
    Exits the process with status 1 if the evaluation raises unexpectedly.

    Args:
        args: Parsed ``argparse`` namespace from the ``eval`` subcommand.
        logger: Logger used for progress and error reporting.
    """
    logger.info("Running local evaluation")

    # --no-tensorboard overrides the (default-on) --tensorboard flag.
    tb_enabled = args.tensorboard and not args.no_tensorboard

    # Auto-generate a per-run log directory when none was supplied.
    tb_log_dir = args.log_dir
    if tb_enabled and not tb_log_dir:
        from datetime import datetime

        stamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        tb_log_dir = f"runs/{args.task}_{stamp}"

    evaluator = AgentEvaluator(
        task_name=args.task,
        render_mode=args.render_mode,
        max_episodes=args.episodes,
        max_steps_per_episode=args.steps,
        seed=args.seed,
        use_tensorboard=tb_enabled,
        log_dir=tb_log_dir,
    )

    # --list-tasks is informational only: show the task names and stop.
    if args.list_tasks:
        evaluator.list_available_tasks()
        return

    # Best-effort TensorBoard launch; evaluation proceeds either way.
    tb_url = None
    if tb_enabled and tb_log_dir:
        logger.info("Starting TensorBoard...")
        try:
            tb_url = launch_tensorboard(tb_log_dir)
        except Exception as e:
            logger.warning(f"Failed to start TensorBoard: {e}")
            logger.info("Continuing evaluation without TensorBoard...")
        else:
            logger.info(f"TensorBoard available at: {tb_url}")
            logger.info("TensorBoard will show metrics in real-time during evaluation")

    try:
        evaluator.run_evaluation()
    except KeyboardInterrupt:
        # Treat Ctrl-C as a clean, user-requested stop (exit status 0).
        logger.info("Evaluation stopped by user")
        if tb_url:
            logger.info(f"TensorBoard still available at: {tb_url}")
    except Exception as e:
        logger.error(f"Error during evaluation: {e}", exc_info=True)
        if tb_url:
            logger.info(f"TensorBoard still available at: {tb_url}")
        sys.exit(1)
    else:
        logger.info("Evaluation completed successfully")
        if tb_url:
            logger.info(f"View detailed metrics at: {tb_url}")
            logger.info("TensorBoard will continue running in the background")
285
+
286
# Script entry point: parse CLI arguments and dispatch to the chosen subcommand.
if __name__ == "__main__":
    main()
pyproject.toml CHANGED
@@ -6,7 +6,13 @@ readme = "README.md"
6
  requires-python = ">=3.12"
7
  dependencies = [
8
  "metaworld>=3.0.0",
9
- "torch>=2.8.0"
 
 
 
 
 
 
10
  ]
11
 
12
  [dependency-groups]
 
6
  requires-python = ">=3.12"
7
  dependencies = [
8
  "metaworld>=3.0.0",
9
+ "torch>=2.8.0",
10
+ "gymnasium>=0.29.0",
11
+ "mujoco>=3.0.0",
12
+ "numpy>=1.24.0",
13
+ "pycapnp>=2.1.0",
14
+ "tensorboard>=2.15.0",
15
+ "matplotlib>=3.7.0"
16
  ]
17
 
18
  [dependency-groups]
uv.lock ADDED
The diff for this file is too large to render. See raw diff