File size: 6,244 Bytes
725af76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
"""
Implementation of the AgentInterface for MetaWorld tasks.

This agent uses the SawyerReachV3Policy from MetaWorld as an expert policy.
"""

import logging
from typing import Any, Dict

import gymnasium as gym
import metaworld
import numpy as np
import torch
from agent_interface import AgentInterface
from metaworld.policies.sawyer_reach_v3_policy import SawyerReachV3Policy


class RLAgent(AgentInterface):
    """
    MetaWorld agent implementation using the SawyerReachV3Policy expert policy.

    Observations are normalised to a numpy array, handed to the scripted
    policy's ``get_action``, and the result is returned as a float32 torch
    tensor. Any failure while computing an action is logged and replaced by
    an all-zero action so that a rollout is not aborted.
    """

    # Keys probed, in priority order, when the observation arrives as a dict.
    _OBSERVATION_KEYS = ("observation", "obs", "state_observation")
    # Key that carries a goal flag rather than the observation itself.
    _GOAL_FLAG_KEY = "goal_achieved"
    # Last-resort sizes used when the corresponding space gives no shape.
    # 39 is the typical MetaWorld observation dimension; 4 is presumably the
    # matching action dimension -- TODO confirm against the environment.
    _FALLBACK_OBS_DIM = 39
    _FALLBACK_ACTION_DIM = 4

    def __init__(
        self,
        observation_space: gym.Space | None = None,
        action_space: gym.Space | None = None,
        seed: int | None = None,
        **kwargs,
    ):
        """
        Create the agent and its scripted expert policy.

        Args:
            observation_space: Observation space of the environment, if known.
            action_space: Action space of the environment, if known.
            seed: Optional seed, forwarded to ``AgentInterface``.
            **kwargs: Forwarded to ``AgentInterface``. ``max_episode_steps``
                (default 200) is also stored on the instance.
        """
        # NOTE(review): this class reads self.seed, self.action_space and
        # self.observation_space, so the base initialiser is assumed to set
        # them -- confirm against AgentInterface.
        super().__init__(observation_space, action_space, seed, **kwargs)

        self.logger = logging.getLogger(__name__)
        self.logger.info("Initializing MetaWorld agent with seed %s", self.seed)

        self.policy = SawyerReachV3Policy()
        self.logger.info("Successfully initialized SawyerReachV3Policy")

        # Track episode state
        self.episode_step = 0
        # Stored for reference only; nothing in this class enforces the limit.
        self.max_episode_steps = kwargs.get("max_episode_steps", 200)

        self.logger.info("MetaWorld agent initialized successfully")

    def act(self, obs: Dict[str, Any], **kwargs) -> torch.Tensor:
        """
        Process the observation and return an action using the MetaWorld expert policy.

        Args:
            obs: Observation from the environment. A dict is unpacked by
                ``_process_observation``; any other value is used as-is.
            kwargs: Additional arguments (unused).

        Returns:
            action: float32 action tensor. If anything fails while computing
                the action, the error is logged and an all-zero action is
                returned instead; ``episode_step`` is not advanced then.
        """
        try:
            processed_obs = self._process_observation(obs)
            # MetaWorld policies expect numpy arrays
            action_numpy = self.policy.get_action(processed_obs)
            # np.array copies, so the tensor never aliases the policy's buffer.
            action_tensor = torch.from_numpy(
                np.array(action_numpy, dtype=np.float32)
            )
        except Exception as e:
            # Deliberate best-effort: keep the rollout alive with a no-op.
            self.logger.error("Error in act method: %s", e, exc_info=True)
            return self._zero_action()

        # Log only every 50th step to avoid spam.
        if self.episode_step % 50 == 0:
            self.logger.debug("Using expert policy action: %s", action_numpy)

        self.episode_step += 1

        if self.episode_step % 50 == 0:
            self.logger.debug(
                "Step %s: Action shape %s", self.episode_step, action_tensor.shape
            )

        return action_tensor

    def _zero_action(self) -> torch.Tensor:
        """Return an all-zero float32 action used when ``act`` fails."""
        if isinstance(self.action_space, gym.spaces.Box):
            # Use the full shape so scalar or multi-dimensional Box spaces
            # also get a correctly shaped fallback.
            return torch.zeros(self.action_space.shape, dtype=torch.float32)
        return torch.zeros(self._FALLBACK_ACTION_DIM, dtype=torch.float32)

    def _process_observation(self, obs: Any) -> np.ndarray:
        """
        Helper method to process observations for the MetaWorld expert policy.

        Args:
            obs: Either a dict wrapping the observation or the observation
                itself.

        Returns:
            The observation as a numpy array. An existing ``np.ndarray`` is
            returned unchanged; anything else is converted to float32. If the
            conversion fails, a zero array is returned (shaped like
            ``observation_space`` when that has a shape, otherwise of length
            ``_FALLBACK_OBS_DIM``).

        Raises:
            ValueError: If ``obs`` is a dict that holds no observation entry.
        """
        if isinstance(obs, dict):
            processed_obs = self._extract_from_dict(obs)
        else:
            # If already a numpy array or similar, use directly
            processed_obs = obs

        if isinstance(processed_obs, np.ndarray):
            return processed_obs

        try:
            return np.array(processed_obs, dtype=np.float32)
        except Exception as e:
            self.logger.error("Failed to convert observation to numpy array: %s", e)

        # Conversion failed: fall back to a dummy observation.
        space_shape = getattr(self.observation_space, "shape", None)
        if space_shape is not None:
            return np.zeros(space_shape, dtype=np.float32)
        return np.zeros(self._FALLBACK_OBS_DIM, dtype=np.float32)

    def _extract_from_dict(self, obs: Dict[str, Any]) -> Any:
        """Pick the observation payload out of a dict-style observation."""
        for key in self._OBSERVATION_KEYS:
            if key in obs:
                return obs[key]

        if self._GOAL_FLAG_KEY in obs:
            self.logger.debug("Goal achieved: %s", obs[self._GOAL_FLAG_KEY])

        # Unknown structure: use the first entry, skipping the goal flag so
        # that the flag is never mistaken for the observation.
        for key, value in obs.items():
            if key != self._GOAL_FLAG_KEY:
                self.logger.debug("Using observation key: %s", key)
                return value

        raise ValueError("Observation dict does not contain an observation entry")

    def reset(self) -> None:
        """
        Reset agent state between episodes.
        """
        self.logger.debug("Resetting agent")
        self.episode_step = 0

    def _build_model(self):
        """
        Build a neural network model for the agent.

        Placeholder: the scripted expert policy needs no learned model, so
        nothing is built and ``None`` is returned. A subclass that learns a
        policy would construct its network here.
        """
        return None