rishiad committed on
Commit
3b2789b
·
unverified ·
1 Parent(s): a532d53

feat: setup first submission for kinitro

Browse files
Files changed (7) hide show
  1. agent.capnp +13 -0
  2. agent.py +165 -0
  3. agent_interface.py +53 -0
  4. agent_server.py +70 -0
  5. main.py +66 -0
  6. pyproject.toml +15 -0
  7. uv.lock +0 -0
agent.capnp ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Unique Cap'n Proto file ID. NOTE: the trailing semicolon is required by the
# capnp schema language; it was missing in the original.
@0x893bac407c81b48c;

# RPC contract between the evaluator host and the miner's agent container.
interface Agent {

  # A serialized tensor: raw bytes plus the metadata needed to rebuild it.
  struct Tensor {
    data @0 :Data;           # tensor bytes tensor.numpy().tobytes()
    shape @1 :List(UInt64);  # tensor shape list(tensor.shape())
    dtype @2 :Text;          # data type name tensor.dtype()
  }

  # Given a pickled observation blob, return the agent's action.
  act @0 (obs :Data) -> (action :Tensor);
  # Reset the agent's per-episode state.
  reset @1 () -> ();
}
agent.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Implementation of the AgentInterface for MetaWorld tasks.
3
+
4
+ This agent uses the SawyerPickPlaceV2Policy from MetaWorld as an expert policy.
5
+ """
6
+
7
+ import logging
8
+ from typing import Any, Dict
9
+
10
+ import gymnasium as gym
11
+ import metaworld
12
+ import numpy as np
13
+ import torch
14
+ from agent_interface import AgentInterface
15
+ from metaworld.policies import SawyerPickPlaceV2Policy
16
+
17
+
18
class RLAgent(AgentInterface):
    """
    MetaWorld agent implementation using the SawyerPickPlaceV2Policy expert policy.

    This agent uses the expert policy from MetaWorld for pick and place tasks.
    """

    def __init__(
        self,
        observation_space: gym.Space | None = None,
        action_space: gym.Space | None = None,
        seed: int | None = None,
        **kwargs,
    ):
        """
        Initialize the agent and its expert policy.

        Args:
            observation_space: Environment observation space (defaulted by the base class).
            action_space: Environment action space (defaulted by the base class).
            seed: RNG seed; the base class draws a random one when None.
            kwargs: Extra options; ``max_episode_steps`` (default 200) is read here.
        """
        super().__init__(observation_space, action_space, seed, **kwargs)

        self.logger = logging.getLogger(__name__)
        self.logger.info(f"Initializing MetaWorld agent with seed {self.seed}")

        # Scripted expert policy shipped with MetaWorld for pick-and-place.
        self.policy = SawyerPickPlaceV2Policy()
        self.logger.info("Successfully initialized SawyerPickPlaceV2Policy")

        # Track episode state.
        self.episode_step = 0
        self.max_episode_steps = kwargs.get("max_episode_steps", 200)

        self.logger.info("MetaWorld agent initialized successfully")

    def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor:
        """
        Process the observation and return an action using the MetaWorld expert policy.

        Args:
            obs: Observation from the environment
            kwargs: Additional arguments

        Returns:
            action: Action tensor to take in the environment
        """
        try:
            # Process observation to extract the format needed by the expert policy.
            processed_obs = self._process_observation(obs)

            # Use the expert policy; MetaWorld policies expect numpy arrays.
            action_numpy = self.policy.get_action(processed_obs)
            action_tensor = torch.from_numpy(np.array(action_numpy)).float()

            self.episode_step += 1

            # Log occasionally to avoid spam. (Fix: the original gated two
            # separate debug logs on `episode_step % 50`, once before and once
            # after the increment, so they fired on different step values.)
            if self.episode_step % 50 == 0:
                self.logger.debug(
                    f"Step {self.episode_step}: Action shape {action_tensor.shape}, "
                    f"expert action {action_numpy}"
                )

            return action_tensor

        except Exception as e:
            self.logger.error(f"Error in act method: {e}", exc_info=True)
            # Return zeros as a fallback so the evaluator always gets an action.
            if isinstance(self.action_space, gym.spaces.Box):
                return torch.zeros(self.action_space.shape[0], dtype=torch.float32)
            else:
                # 4 matches the default Box action space from AgentInterface.
                return torch.zeros(4, dtype=torch.float32)

    def _process_observation(self, obs):
        """
        Helper method to process observations for the MetaWorld expert policy.

        Accepts either a dict (various MetaWorld wrapper formats) or an
        array-like, and always returns a float32 numpy array.
        """
        if isinstance(obs, dict):
            # MetaWorld environments can return observations in different formats.
            if "observation" in obs:
                # Standard format for goal-observable environments.
                processed_obs = obs["observation"]
            elif "obs" in obs:
                processed_obs = obs["obs"]
            elif "state_observation" in obs:
                # Some MetaWorld environments use this key.
                processed_obs = obs["state_observation"]
            elif "goal_achieved" in obs:
                # If we have information about goal achievement; falls back to
                # the first value in the dict as the base observation.
                achievement = obs.get("goal_achieved", False)
                base_obs = next(iter(obs.values()))
                self.logger.debug(f"Goal achieved: {achievement}")
                processed_obs = base_obs
            else:
                # Unknown structure: use the first value and log which key it was.
                processed_obs = next(iter(obs.values()))
                self.logger.debug(f"Using observation key: {next(iter(obs.keys()))}")
        else:
            # Already an array-like; use it directly.
            processed_obs = obs

        # Ensure we return a numpy array as expected by MetaWorld policies.
        if not isinstance(processed_obs, np.ndarray):
            try:
                processed_obs = np.array(processed_obs, dtype=np.float32)
            except Exception as e:
                self.logger.error(f"Failed to convert observation to numpy array: {e}")
                # Return a dummy observation if conversion fails.
                if (
                    self.observation_space
                    and hasattr(self.observation_space, "shape")
                    and self.observation_space.shape is not None
                ):
                    processed_obs = np.zeros(
                        self.observation_space.shape, dtype=np.float32
                    )
                else:
                    # Typical MetaWorld observation dimension if all else fails.
                    processed_obs = np.zeros(39, dtype=np.float32)

        return processed_obs

    def reset(self) -> None:
        """
        Reset agent state between episodes.
        """
        self.logger.debug("Resetting agent")
        self.episode_step = 0
        # Any other stateful components would be reset here.

    def _build_model(self):
        """
        Build a neural network model for the agent.

        Placeholder: the expert policy needs no learned model. Kept as a hook
        for where a PyTorch network would be defined by a learning agent.
        """
        # Example of where you might build a simple PyTorch model:
        # model = torch.nn.Sequential(
        #     torch.nn.Linear(self.observation_space.shape[0], 128),
        #     torch.nn.ReLU(),
        #     torch.nn.Linear(128, 64),
        #     torch.nn.ReLU(),
        #     torch.nn.Linear(64, self.action_space.shape[0]),
        # )
        # return model
        pass
agent_interface.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Abstract base class defining the standard interface for all agents.
3
+
4
+ All miner-submitted agents must implement this interface to be evaluated.
5
+ """
6
+
7
+ from abc import ABC, abstractmethod
8
+
9
+ import gymnasium as gym
10
+ import numpy as np
11
+ import torch
12
+
13
+
14
class AgentInterface(ABC):
    """
    Standard interface that all miner implementations must follow.

    This ensures a consistent contract between the evaluator and any submitted agent,
    regardless of the underlying model architecture or implementation details.
    """

    def __init__(
        self,
        observation_space: gym.Space | None = None,
        action_space: gym.Space | None = None,
        seed: int | None = None,
        **kwargs,
    ):
        """
        Store (or default) the spaces and seed the agent's RNG.

        Args:
            observation_space: Defaults to a 100-dim Box in [-1, 1].
            action_space: Defaults to a 4-dim Box in [-1, 1].
            seed: RNG seed; a random one is drawn when None.
            kwargs: Ignored here; available to subclasses.
        """
        self.observation_space = observation_space or gym.spaces.Box(
            low=-1, high=1, shape=(100,), dtype=np.float32
        )
        self.action_space = action_space or gym.spaces.Box(
            low=-1, high=1, shape=(4,), dtype=np.float32
        )
        # Fix: `seed or randint(...)` silently discarded a legitimate seed of 0.
        self.seed = seed if seed is not None else int(np.random.randint(0, 1000000))
        # Fix: seed the generator from self.seed (the original passed the raw
        # `seed` argument, so an unset seed left the RNG unseeded and
        # inconsistent with the stored self.seed).
        self.rng = np.random.default_rng(self.seed)

    @abstractmethod
    def act(self, obs: dict, **kwargs) -> torch.Tensor:
        """
        Take action given current observation and any additional arguments.
        """
        pass

    def reset(self) -> None:
        """
        Reset agent state for new episode.

        This is called at the beginning of each episode. Stateless agents
        can implement this as a no-op. Agents with internal memory/history
        should reset their state here.
        """
        pass
agent_server.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# The agent server runs on the miner container. The host calls these functions

import asyncio
import logging
import pickle

import agent_capnp
import capnp
import numpy as np
import torch

# Fix: use an absolute import. These modules live as flat top-level files
# (agent.py and main.py import them absolutely); the original relative import
# (`from .agent_interface import ...`) raises "attempted relative import with
# no known parent package" when the server is started via main.py.
from agent_interface import AgentInterface
13
+
14
+
15
class AgentServer(agent_capnp.Agent.Server):
    """Cap'n Proto RPC server wrapping an AgentInterface implementation.

    Runs inside the miner container; the evaluator host invokes `act` and
    `reset` remotely per the schema in agent.capnp.
    """

    def __init__(self, agent: AgentInterface):
        # The concrete agent that services all RPC calls.
        self.agent = agent
        self.logger = logging.getLogger(__name__)
        self.logger.info("AgentServer initialized with agent: %s", type(agent).__name__)

    async def act(self, obs, **kwargs):
        """Handle an `act` RPC: unpickle the observation, run the agent,
        and return the action as an agent.capnp Tensor message.

        Errors are logged and re-raised so they propagate to the RPC caller.
        """
        try:
            # Deserialize observation from bytes.
            # SECURITY NOTE(review): pickle.loads on bytes received over RPC can
            # execute arbitrary code if the peer is untrusted. This assumes the
            # only caller is the trusted evaluator host — confirm that boundary.
            observation = pickle.loads(obs)

            # Call the agent's act method
            action_tensor = self.agent.act(observation)

            # Convert to numpy if it's a torch tensor
            if isinstance(action_tensor, torch.Tensor):
                action_numpy = action_tensor.detach().cpu().numpy()
            else:
                action_numpy = np.array(action_tensor)

            # Prepare tensor response; fields mirror the Tensor struct in agent.capnp.
            response = agent_capnp.Agent.Tensor.new_message()
            response.data = action_numpy.tobytes()
            response.shape = list(action_numpy.shape)
            response.dtype = str(action_numpy.dtype)

            return response
        except Exception as e:
            self.logger.error(f"Error in act: {e}", exc_info=True)
            raise

    async def reset(self, **kwargs):
        """Handle a `reset` RPC by delegating to the wrapped agent."""
        try:
            self.agent.reset()
        except Exception as e:
            self.logger.error(f"Error in reset: {e}", exc_info=True)
            raise
52
+
53
+
54
async def serve(agent: AgentInterface, address="*", port=8000):
    """Expose *agent* over a Cap'n Proto two-party RPC server.

    Runs until the server loop terminates; the socket is always closed on exit.
    """
    rpc_server = capnp.TwoPartyServer(address, port, bootstrap=AgentServer(agent))
    logging.info(f"Agent RPC server listening on {address}:{port}")

    try:
        # Block here for the lifetime of the server.
        await rpc_server.run_forever()
    finally:
        # Guarantee the listening socket is released even on cancellation/error.
        rpc_server.close()
63
+
64
+
65
def start_server(agent: AgentInterface, address="*", port=8000):
    """Synchronous entry point: run the async RPC server until interrupted.

    Args:
        agent: The agent implementation to serve.
        address: Bind address ("*" for all interfaces).
        port: TCP port to listen on.
    """
    try:
        # Fix: asyncio.get_event_loop() with no running loop is deprecated
        # (Python 3.10+, and this project requires >= 3.13). asyncio.run()
        # creates and tears down the event loop for us.
        # NOTE(review): if pycapnp requires its kj event-loop context here,
        # wrap serve() accordingly — confirm against the pycapnp version used.
        asyncio.run(serve(agent, address, port))
    except KeyboardInterrupt:
        logging.info("Server stopped by user")
main.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Main entry point for the agent server.
4
+
5
+ This script creates an agent implementation and starts the RPC server
6
+ to handle requests from the evaluator.
7
+ """
8
+
9
+ import argparse
10
+ import logging
11
+ import sys
12
+
13
+ from agent import RLAgent
14
+ from agent_server import start_server
15
+
16
+
17
def setup_logging(level=logging.INFO):
    """Configure root logging: timestamped format, emitted to stdout."""
    stdout_handler = logging.StreamHandler(sys.stdout)
    logging.basicConfig(
        format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
        level=level,
        handlers=[stdout_handler],
    )
24
+
25
+
26
def main():
    """Main entry point."""
    # Command-line interface: host/port to bind, plus log verbosity.
    parser = argparse.ArgumentParser(description="Start the agent server")
    parser.add_argument("--host", type=str, default="*", help="Host to bind the server to")
    parser.add_argument("--port", type=int, default=8000, help="Port to bind the server to")
    parser.add_argument(
        "--log-level",
        type=str,
        default="INFO",
        choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"],
        help="Logging level",
    )
    args = parser.parse_args()

    # Translate the level name into the numeric logging constant and configure.
    setup_logging(getattr(logging, args.log_level))
    logger = logging.getLogger(__name__)
    logger.info(f"Starting agent server on {args.host}:{args.port}")

    # Build the agent and hand it to the blocking RPC server.
    agent = RLAgent()
    try:
        start_server(agent, args.host, args.port)
    except KeyboardInterrupt:
        logger.info("Server stopped by user")
    except Exception as e:
        logger.error(f"Error starting server: {e}", exc_info=True)
        sys.exit(1)
63
+
64
+
65
+ if __name__ == "__main__":
66
+ main()
pyproject.toml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Project manifest for the Storb RL miner submission.
[project]
name = "storb-rl-miner"
version = "0.0.1"
description = "Storb RL Subnet Miner CLI"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.13"
dependencies = [
    # Chain-interaction client, pinned to the storb-tech fork.
    "fiber @ git+https://github.com/storb-tech/fiber.git#egg=fiber[chain]",
    "metaworld>=3.0.0",
    "torch>=2.8.0",
    # Fix: declared below are packages the code imports directly
    # (agent.py: gymnasium, numpy; agent_server.py: capnp/pycapnp) but that
    # were missing from the original dependency list.
    "gymnasium",
    "numpy",
    "pycapnp",
]

# Development-only tooling (PEP 735 dependency group).
[dependency-groups]
dev = ["debugpy>=1.8.9", "py-spy>=0.4.0", "pytest>=8.3.4", "ruff>=0.8.2"]
uv.lock ADDED
The diff for this file is too large to render. See raw diff