"""
Implementation of the AgentInterface for MetaWorld tasks.

This agent uses the SawyerReachV3Policy from MetaWorld as an expert policy.
"""
import logging
from typing import Any, Dict
import gymnasium as gym
import metaworld
import numpy as np
import torch
from agent_interface import AgentInterface
from metaworld.policies.sawyer_reach_v3_policy import SawyerReachV3Policy
class RLAgent(AgentInterface):
    """
    MetaWorld agent implementation using the SawyerReachV3Policy expert policy.

    Wraps the scripted expert reach policy from MetaWorld and exposes it
    through the AgentInterface contract (``act`` / ``reset``). Observations
    are coerced into the flat numpy array format the expert policy expects,
    and actions are returned as float32 torch tensors.
    """

    def __init__(
        self,
        observation_space: gym.Space | None = None,
        action_space: gym.Space | None = None,
        seed: int | None = None,
        **kwargs,
    ):
        """
        Initialize the agent and its expert policy.

        Args:
            observation_space: Environment observation space (optional).
            action_space: Environment action space (optional); used to size
                the zero-action fallback in ``act``.
            seed: RNG seed, forwarded to the AgentInterface base class.
            kwargs: Extra options; ``max_episode_steps`` (default 200) is
                read here, everything else is forwarded to the base class.
        """
        super().__init__(observation_space, action_space, seed, **kwargs)
        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing MetaWorld agent with seed %s", self.seed)

        self.policy = SawyerReachV3Policy()
        self.logger.info("Successfully initialized SawyerReachV3Policy")

        # Track per-episode state; reset() rewinds the step counter.
        self.episode_step = 0
        self.max_episode_steps = kwargs.get("max_episode_steps", 200)

        self.logger.info("MetaWorld agent initialized successfully")

    def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor:
        """
        Process the observation and return an action from the expert policy.

        Args:
            obs: Observation from the environment (dict or array-like).
            kwargs: Additional arguments (currently unused).

        Returns:
            Action tensor (float32) to take in the environment. On any
            internal error a zero action of the appropriate size is returned
            so the rollout can continue.
        """
        try:
            # Extract the flat numpy observation the expert policy expects.
            processed_obs = self._process_observation(obs)

            # MetaWorld policies consume and produce numpy arrays.
            action_numpy = self.policy.get_action(processed_obs)
            action_tensor = torch.from_numpy(np.array(action_numpy)).float()

            self.episode_step += 1

            # Single periodic debug log (the original had two separate
            # % 50 checks that both fired); lazy %-args avoid formatting
            # cost when DEBUG is disabled.
            if self.episode_step % 50 == 0:
                self.logger.debug(
                    "Step %d: expert action %s (shape %s)",
                    self.episode_step,
                    action_numpy,
                    tuple(action_tensor.shape),
                )

            return action_tensor
        except Exception as e:
            self.logger.error("Error in act method: %s", e, exc_info=True)
            # Return zeros as a fallback so the caller keeps running.
            if isinstance(self.action_space, gym.spaces.Box):
                return torch.zeros(self.action_space.shape[0], dtype=torch.float32)
            # Assumes a 4-dim Sawyer action (dx, dy, dz, gripper) — TODO confirm.
            return torch.zeros(4, dtype=torch.float32)

    def _process_observation(self, obs):
        """
        Coerce an environment observation into the flat numpy float32 array
        format expected by MetaWorld expert policies.

        Accepts either a dict (tries the known MetaWorld keys in order) or
        an array-like; falls back to a zero observation if conversion fails.
        """
        if isinstance(obs, dict):
            # MetaWorld environments can expose the state under several keys.
            if "observation" in obs:
                # Standard format for goal-observable environments.
                processed_obs = obs["observation"]
            elif "obs" in obs:
                processed_obs = obs["obs"]
            elif "state_observation" in obs:
                # Some MetaWorld environments use this key.
                processed_obs = obs["state_observation"]
            else:
                if "goal_achieved" in obs:
                    self.logger.debug(
                        "Goal achieved: %s", obs.get("goal_achieved", False)
                    )
                # Unknown structure: use the first value that is not the goal
                # flag itself (the original could mistakenly pick the flag as
                # the observation when it was the dict's first entry).
                key = next(
                    (k for k in obs if k != "goal_achieved"), next(iter(obs))
                )
                self.logger.debug("Using observation key: %s", key)
                processed_obs = obs[key]
        else:
            # Already a numpy array or similar; use directly.
            processed_obs = obs

        # Ensure a numpy array, as expected by MetaWorld policies.
        if not isinstance(processed_obs, np.ndarray):
            try:
                processed_obs = np.array(processed_obs, dtype=np.float32)
            except Exception as e:
                self.logger.error(
                    "Failed to convert observation to numpy array: %s", e
                )
                # Fall back to a zero observation of a plausible shape.
                if (
                    self.observation_space is not None
                    and getattr(self.observation_space, "shape", None) is not None
                ):
                    processed_obs = np.zeros(
                        self.observation_space.shape, dtype=np.float32
                    )
                else:
                    # 39 is the typical MetaWorld state dimension — TODO confirm.
                    processed_obs = np.zeros(39, dtype=np.float32)
        return processed_obs

    def reset(self) -> None:
        """
        Reset agent state between episodes (rewinds the step counter).
        """
        self.logger.debug("Resetting agent")
        self.episode_step = 0
        # Any other stateful components would be reset here.

    def _build_model(self):
        """
        Build a neural network model for the agent.

        Placeholder: the expert-policy agent needs no learned model. A
        trainable variant would define its network architecture here, e.g.:

            model = torch.nn.Sequential(
                torch.nn.Linear(self.observation_space.shape[0], 128),
                torch.nn.ReLU(),
                torch.nn.Linear(128, 64),
                torch.nn.ReLU(),
                torch.nn.Linear(64, self.action_space.shape[0]),
            )
            return model
        """
        pass