Spaces:
Runtime error
Runtime error
| from typing import ( | |
| Any, | |
| Dict, | |
| List, | |
| Tuple, | |
| ) | |
| import gym | |
| import numpy as np | |
class SuccessCounter(gym.Wrapper):
    """Helper wrapper that keeps count of successes in MetaWorld environments.

    An episode counts as successful if any step within it reported a truthy
    ``info["success"]``. When an episode ends, its outcome (a ``bool``) is
    appended to ``self.successes``; :meth:`pop_successes` drains that list.

    Both ``step`` return conventions are accepted:

    * legacy 4-tuple ``(obs, reward, done, info)``, and
    * newer 5-tuple ``(obs, reward, terminated, truncated, info)``, where the
      episode is considered finished if either flag is set.

    The wrapped env's step/reset results are passed through unchanged.
    """

    def __init__(self, env: gym.Env) -> None:
        super().__init__(env)
        # Outcome of each finished episode, in the order the episodes ended.
        self.successes: List[bool] = []
        # Whether the in-progress episode has reported success so far.
        self.current_success = False

    def step(self, action: Any) -> Tuple[Any, ...]:
        """Step the wrapped env and update the success bookkeeping.

        Args:
            action: Forwarded unchanged to the wrapped env's ``step``.

        Returns:
            The wrapped env's step result as a tuple, either
            ``(obs, reward, done, info)`` or
            ``(obs, reward, terminated, truncated, info)``.

        Raises:
            ValueError: If the wrapped env's step result has neither 4 nor
                5 elements.
        """
        result = tuple(self.env.step(action))
        if len(result) == 5:
            # Newer gym API: an episode ends on termination OR truncation.
            _, _, terminated, truncated, info = result
            done = terminated or truncated
        elif len(result) == 4:
            # Legacy gym API.
            _, _, done, info = result
        else:
            raise ValueError(
                f"Expected env.step to return 4 or 5 values, got {len(result)}"
            )

        if info.get("success", False):
            self.current_success = True
        if done:
            self.successes.append(self.current_success)
        return result

    def pop_successes(self) -> List[bool]:
        """Return the recorded episode outcomes and clear the internal list."""
        res = self.successes
        self.successes = []
        return res

    def reset(self, **kwargs: Any) -> Any:
        """Reset the wrapped env and start tracking a fresh episode.

        If the previous episode never reported ``done``, its outcome is
        discarded rather than recorded.

        Returns:
            Whatever the wrapped env's ``reset`` returns. Under the legacy API
            this is an observation; under the newer one it is ``(obs, info)``.
        """
        self.current_success = False
        return self.env.reset(**kwargs)